diff --git a/.gitignore b/.gitignore index 9757054a50f9..debad77ec2ad 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ *.iml *.iws *.pyc +*.pyo .idea/ .idea_modules/ build/*.jar @@ -62,6 +63,10 @@ ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml +R-unit-tests.log +R/unit-tests.out +python/lib/pyspark.zip +lint-r-report.log # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 769defbac11b..9165872b9fb2 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,4 +1,5 @@ target +cache .gitignore .gitattributes .project @@ -14,10 +15,12 @@ TAGS RELEASE control docs +docker.properties.template fairscheduler.xml.template spark-defaults.conf.template log4j.properties log4j.properties.template +metrics.properties metrics.properties.template slaves slaves.template @@ -25,9 +28,15 @@ spark-env.sh spark-env.cmd spark-env.sh.template log4j-defaults.properties +log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js +d3.min.js +dagre-d3.min.js +graphlib-dot.min.js sorttable.js +vis.min.js +vis.min.css .*avsc .*txt .*json @@ -65,3 +74,24 @@ logs .*scalastyle-output.xml .*dependency-reduced-pom.xml known_translations +json_expectation +local-1422981759269/* +local-1422981780767/* +local-1425081759269/* +local-1426533911241/* +local-1426633911242/* +local-1430917381534/* +local-1430917381535_1 +local-1430917381535_2 +DESCRIPTION +NAMESPACE +test_support/* +.*Rd +help/* +html/* +INDEX +.lintr +gen-java.* +.*avpr +org.apache.spark.sql.sources.DataSourceRegister +.*parquet diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b6c6b050fa33..f10d7e277eea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,16 @@ ## Contributing to Spark -Contributions via GitHub pull requests are gladly accepted from their original -author. Along with any pull requests, please state that the contribution is -your original work and that you license the work to the project under the -project's open source license. Whether or not you state this explicitly, by -submitting any copyrighted material via pull request, email, or other means -you agree to license the material under the project's open source license and -warrant that you have the legal authority to do so. +*Before opening a pull request*, review the +[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +It lists steps that are required before creating a PR. In particular, consider: + +- Is the change important and ready enough to ask the community to spend time reviewing? +- Have you searched for existing, related JIRAs and pull requests? +- Is this a new feature that can stand alone as a package on http://spark-packages.org ? +- Is the change being proposed clearly explained and motivated? -Please see the [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) -for more information. +When you contribute code, you affirm that the contribution is your original work and that you +license the work to the project under the project's open source license. Whether or not you +state this explicitly, by submitting any copyrighted material via pull request, email, or +other means you agree to license the material under the project's open source license and +warrant that you have the legal authority to do so. diff --git a/LICENSE b/LICENSE index 0a42d389e4c3..f9e412cade34 100644 --- a/LICENSE +++ b/LICENSE @@ -643,6 +643,36 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +======================================================================== +For d3 (core/src/main/resources/org/apache/spark/ui/static/d3.min.js): +======================================================================== + +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ======================================================================== For Scala Interpreter classes (all .scala files in repl/src/main/scala @@ -771,6 +801,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java): +======================================================================== +Copyright (C) 2015 Stijn de Gouw + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. ======================================================================== For LimitedInputStream @@ -790,6 +836,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For vis.js (core/src/main/resources/org/apache/spark/ui/static/vis.min.js): +======================================================================== +Copyright (C) 2010-2015 Almende B.V. + +Vis.js is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + + * The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. + +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. ======================================================================== BSD-style licenses @@ -798,7 +906,8 @@ BSD-style licenses The following components are provided under a BSD-style license. See project link for details. (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) + (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) @@ -839,5 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) + (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/R/.gitignore b/R/.gitignore new file mode 100644 index 000000000000..9a5889ba28b2 --- /dev/null +++ b/R/.gitignore @@ -0,0 +1,6 @@ +*.o +*.so +*.Rd +lib +pkg/man +pkg/html diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md new file mode 100644 index 000000000000..931d01549b26 --- /dev/null +++ b/R/DOCUMENTATION.md @@ -0,0 +1,12 @@ +# SparkR Documentation + +SparkR documentation is generated using in-source comments annotated using using +`roxygen2`. After making changes to the documentation, to generate man pages, +you can run the following from an R console in the SparkR home directory + + library(devtools) + devtools::document(pkg="./pkg", roclets=c("rd")) + +You can verify if your changes are good by running + + R CMD check pkg/ diff --git a/R/README.md b/R/README.md new file mode 100644 index 000000000000..005f56da1670 --- /dev/null +++ b/R/README.md @@ -0,0 +1,67 @@ +# R on Spark + +SparkR is an R package that provides a light-weight frontend to use Spark from R. + +### SparkR development + +#### Build Spark + +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run +``` + build/mvn -DskipTests -Psparkr package +``` + +#### Running sparkR + +You can start using SparkR by launching the SparkR shell with + + ./bin/sparkR + +The `sparkR` script automatically creates a SparkContext with Spark by default in +local mode. To specify the Spark master of a cluster for the automatically created +SparkContext, you can run + + ./bin/sparkR --master "local[2]" + +To set other options like driver memory, executor memory etc. you can pass in the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR` + +#### Using SparkR from RStudio + +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +``` +# Set this to where Spark is installed +Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +# This line loads SparkR from the installed directory +.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) +library(SparkR) +sc <- sparkR.init(master="local") +``` + +#### Making changes to SparkR + +The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. + +#### Generating documentation + +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/dataframe.R + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/dataframe.R +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 000000000000..3f889c0ca3d1 --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 000000000000..d2ae160b5002 --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. + +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +set -o pipefail +set -e + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Install the package (this will also generate the Rd files) +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 000000000000..008a5c668bc4 --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 000000000000..59d98c9c7a64 --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + +set -o pipefail +set -e + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +pushd $FWDIR > /dev/null + +# Generate Rd files if devtools is installed +Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' + +# Install SparkR to $LIB_DIR +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +popd > /dev/null diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 000000000000..cce8d9152d32 --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R/target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/.lintr b/R/pkg/.lintr new file mode 100644 index 000000000000..038236fc149e --- /dev/null +++ b/R/pkg/.lintr @@ -0,0 +1,2 @@ +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 000000000000..d0d7201f004a --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,36 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.5.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'schema.R' + 'generics.R' + 'jobj.R' + 'RDD.R' + 'pairRDD.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'backend.R' + 'broadcast.R' + 'client.R' + 'context.R' + 'deserialize.R' + 'functions.R' + 'mllib.R' + 'serialize.R' + 'sparkR.R' + 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 000000000000..3e5c89d779b7 --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,246 @@ +# Imports from base R +importFrom(methods, setGeneric, setMethod, setOldClass) + +# Disable native libraries till we figure out how to package it +# See SPARKR-7839 +#useDynLib(SparkR, stringHashCode) + +# S3 methods exported +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") + +# MLlib integration +exportMethods("glm", + "predict", + "summary") + +# Job group lifecycle management methods +export("setJobGroup", + "clearJobGroup", + "cancelJobGroup") + +exportClasses("DataFrame") + +exportMethods("arrange", + "cache", + "collect", + "columns", + "count", + "crosstab", + "describe", + "dim", + "distinct", + "dropna", + "dtypes", + "except", + "explain", + "fillna", + "filter", + "first", + "group_by", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "join", + "limit", + "merge", + "names", + "ncol", + "nrow", + "orderBy", + "mutate", + "names", + "persist", + "printSchema", + "rbind", + "registerTempTable", + "rename", + "repartition", + "sample", + "sample_frac", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "summarize", + "summary", + "take", + "unionAll", + "unique", + "unpersist", + "where", + "withColumn", + "withColumnRenamed", + "write.df") + +exportClasses("Column") + +exportMethods("abs", + "acos", + "add_months", + "alias", + "approxCountDistinct", + "asc", + "ascii", + "asin", + "atan", + "atan2", + "avg", + "base64", + "between", + "bin", + "bitwiseNOT", + "cast", + "cbrt", + "ceil", + "ceiling", + "concat", + "concat_ws", + "contains", + "conv", + "cos", + "cosh", + "count", + "countDistinct", + "crc32", + "date_add", + "date_format", + "date_sub", + "datediff", + "dayofmonth", + "dayofyear", + "desc", + "endsWith", + "exp", + "explode", + "expm1", + "expr", + "factorial", + "first", + "floor", + "format_number", + "format_string", + "from_unixtime", + "from_utc_timestamp", + "getField", + "getItem", + "greatest", + "hex", + "hour", + "hypot", + "ifelse", + "initcap", + "instr", + "isNaN", + "isNotNull", + "isNull", + "last", + "last_day", + "least", + "length", + "levenshtein", + "like", + "lit", + "locate", + "log", + "log10", + "log1p", + "log2", + "lower", + "lpad", + "ltrim", + "max", + "md5", + "mean", + "min", + "minute", + "month", + "months_between", + "n", + "n_distinct", + "nanvl", + "negate", + "next_day", + "otherwise", + "pmod", + "quarter", + "rand", + "randn", + "regexp_extract", + "regexp_replace", + "reverse", + "rint", + "rlike", + "round", + "rpad", + "rtrim", + "second", + "sha1", + "sha2", + "shiftLeft", + "shiftRight", + "shiftRightUnsigned", + "sign", + "signum", + "sin", + "sinh", + "size", + "soundex", + "sqrt", + "startsWith", + "substr", + "substring_index", + "sum", + "sumDistinct", + "tan", + "tanh", + "toDegrees", + "toRadians", + "to_date", + "to_utc_timestamp", + "translate", + "trim", + "unbase64", + "unhex", + "unix_timestamp", + "upper", + "weekofyear", + "when", + "year") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "loadDF", + "parquetFile", + "read.df", + "sql", + "table", + "tableNames", + "tables", + "uncacheTable") + +export("structField", + "structField.jobj", + "structField.character", + "print.structField", + "structType", + "structType.jobj", + "structType.structField", + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 000000000000..a5162de705f8 --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1810 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. +#' @rdname DataFrame +#' @seealso jsonFile, table +#' @docType class +#' +#' @slot env An R environment that stores bookkeeping states of the DataFrame +#' @slot sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +#' +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @param isCached TRUE if the dataFrame is cached +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @name printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @name schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @name explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @name isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' showDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. +#' +#' @rdname showDF +#' @name showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) + cat(s) + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @name show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' df +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @name dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @name columns +#' @aliases names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @name names +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name names<- +setMethod("names<-", + signature(x = "DataFrame"), + function(x, value) { + if (!is.null(value)) { + sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + dataFrame(sdf) + } + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @name registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlContext, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + invisible(callJMethod(x@sdf, "registerTempTable", tableName)) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @name insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") +#' df2 <- read.df(sqlContext, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache +#' @name cache +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @name persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @name unpersist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @name repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +# toJSON +# +# Convert the rows of a DataFrame into JSON objects and return an RDD where +# each element contains a JSON string. +# +#@param x A SparkSQL DataFrame +# @return A StringRRDD of JSON objects +# @rdname tojson +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# sqlContext <- sparkRSQL.init(sc) +# path <- "path/to/file.json" +# df <- jsonFile(sqlContext, path) +# newRDD <- toJSON(df) +#} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @name saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @name distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' @title Distinct rows in a DataFrame +# +#' @description Returns a new DataFrame containing distinct rows in this DataFrame +#' +#' @rdname unique +#' @name unique +#' @aliases distinct +setMethod("unique", + signature(x = "DataFrame"), + function(x) { + distinct(x) + }) + +#' Sample +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sample +#' @aliases sample_frac +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' collect(sample(df, FALSE, 0.5)) +#' collect(sample(df, TRUE, 0.5)) +#'} +setMethod("sample", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' @rdname sample +#' @name sample_frac +setMethod("sample_frac", + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + sample(x, withReplacement, fraction) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @name count +#' @aliases nrow +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' @title Number of rows for a DataFrame +#' @description Returns number of rows in a DataFrames +#' +#' @name nrow +#' +#' @rdname nrow +#' @aliases count +setMethod("nrow", + signature(x = "DataFrame"), + function(x) { + count(x) + }) + +#' Returns the number of columns in a DataFrame +#' +#' @param x a SparkSQL DataFrame +#' +#' @rdname ncol +#' @name ncol +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' ncol(df) +#' } +setMethod("ncol", + signature(x = "DataFrame"), + function(x) { + length(columns(x)) + }) + +#' Returns the dimentions (number of rows and columns) of a DataFrame +#' @param x a SparkSQL DataFrame +#' +#' @rdname dim +#' @name dim +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' dim(df) +#' } +setMethod("dim", + signature(x = "DataFrame"), + function(x) { + c(count(x), ncol(x)) + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. +#' @rdname collect +#' @name collect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + names <- columns(x) + ncol <- length(names) + if (ncol <= 0) { + # empty data.frame with 0 columns and 0 rows + data.frame() + } else { + # listCols is a list of columns + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + stopifnot(length(listCols) == ncol) + + # An empty data.frame with 0 columns and number of rows as collected + nrow <- length(listCols[[1]]) + if (nrow <= 0) { + df <- data.frame() + } else { + df <- data.frame(row.names = 1 : nrow) + } + + # Append columns one by one + for (colIndex in 1 : ncol) { + # Note: appending a column of list type into a data.frame so that + # data of complex type can be held. But getting a cell from a column + # of list type returns a list instead of a vector. So for columns of + # non-complex type, append them as vector. + col <- listCols[[colIndex]] + if (length(col) <= 0) { + df[[names[colIndex]]] <- col + } else { + # TODO: more robust check on column of primitive types + vec <- do.call(c, col) + if (class(vec) != "list") { + df[[names[colIndex]]] <- vec + } else { + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. + df[[names[colIndex]]] <- col + } + } + } + df + } + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @name limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +#' Take the first NUM rows of a DataFrame and return a the results as a data.frame +#' +#' @rdname take +#' @name take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @name head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @name first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +# toRDD +# +# Converts a Spark DataFrame to an RDD while preserving column names. +# +# @param x A Spark DataFrame +# +# @rdname DataFrame +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# sqlContext <- sparkRSQL.init(sc) +# path <- "path/to/file.json" +# df <- jsonFile(sqlContext, path) +# rdd <- toRDD(df) +# } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @aliases group_by +#' @rdname groupBy +#' @name groupBy +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' @rdname groupBy +#' @name group_by +setMethod("group_by", + signature(x = "DataFrame"), + function(x, ...) { + groupBy(x, ...) + }) + +#' Summarize data across columns +#' +#' Compute aggregates by specifying a list of columns +#' +#' @param x a DataFrame +#' @rdname agg +#' @name agg +#' @aliases summarize +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + +#' @rdname agg +#' @name summarize +setMethod("summarize", + signature(x = "DataFrame"), + function(x, ...) { + agg(x, ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +# @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +# @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +# @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +# @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +# @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +# @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +# @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +#' @name $ +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +#' @rdname select +#' @name $<- +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column" || is.null(value)) + cols <- columns(x) + if (name %in% cols) { + if (is.null(value)) { + cols <- Filter(function(c) { c != name }, cols) + } + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + if (is.null(value)) { + return(x) + } + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +setClassUnion("numericOrcharacter", c("numeric", "character")) + +#' @rdname select +#' @name [[ +setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +#' @name [ +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' @rdname select +#' @name [ +setMethod("[", signature(x = "DataFrame", i = "Column"), + function(x, i, j, ...) { + # It could handle i as "character" but it seems confusing and not required + # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html + filtered <- filter(x, i) + if (!missing(j)) { + filtered[, j] + } else { + filtered + } + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. +#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' df[,c("name", "age")] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' # It can also be subset on rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c) == "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @name selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @name withColumn +#' @aliases mutate +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' Mutate +#' +#' Return a new DataFrame with the specified columns added. +#' +#' @param x A DataFrame +#' @param col a named argument of the form name = col +#' @return A new DataFrame with the new columns added. +#' @rdname withColumn +#' @name mutate +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) +#' names(newDF) # Will contain newCol, newCol2 +#' } +setMethod("mutate", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + stopifnot(length(cols) > 0) + stopifnot(class(cols[[1]]) == "Column") + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] <- alias(cols[[n]], n) + } + } + } + do.call(select, c(x, x$"*", cols)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @name withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +#' Rename +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param newCol A named pair of the form new_column_name = existing_column +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @name rename +#' @aliases withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' newDF <- rename(df, col1 = df$newCol1) +#' } +setMethod("rename", + signature(x = "DataFrame"), + function(x, ...) { + renameCols <- list(...) + stopifnot(length(renameCols) > 0) + stopifnot(class(renameCols[[1]]) == "Column") + newNames <- names(renameCols) + oldNames <- lapply(renameCols, function(col) { + callJMethod(col@jc, "toString") + }) + cols <- lapply(columns(x), function(c) { + if (c %in% oldNames) { + alias(col(c), newNames[[match(c, oldNames)]]) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' Arrange +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname arrange +#' @name arrange +#' @aliases orderby +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' arrange(df, df$col1) +#' arrange(df, "col1") +#' arrange(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("arrange", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname arrange +#' @name orderby +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + arrange(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to filter on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @name filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @name where +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. +#' @rdname join +#' @name join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' @rdname merge +#' @name merge +#' @aliases join +setMethod("merge", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL, ...) { + join(x, y, joinExpr, joinType) + }) + + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' Note that this does not remove duplicate rows across the two DataFrames. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @name unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' @title Union two or more DataFrames +# +#' @description Returns a new DataFrame containing rows of all parameters. +# +#' @rdname rbind +#' @name rbind +#' @aliases unionAll +setMethod("rbind", + signature(... = "DataFrame"), + function(x, ..., deparse.level = 1) { + if (nargs() == 3) { + unionAll(x, ...) + } else { + unionAll(x, Recall(..., deparse.level = 1)) + } + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @name intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' except +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the except operation. +#' @rdname except +#' @name except +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) +#' exceptDF <- except(df, df2) +#' } +#' @rdname except +#' @export +setMethod("except", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + excepted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(excepted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname write.df +#' @name write.df +#' @aliases saveDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' write.df(df, "myfile", "parquet", "overwrite") +#' } +setMethod("write.df", + signature(df = "DataFrame", path = "character"), + function(df, path, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + # nolint start + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + # nolint end + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + +#' @rdname write.df +#' @name saveDF +#' @export +setMethod("saveDF", + signature(df = "DataFrame", path = "character"), + function(df, path, source = NULL, mode = "append", ...){ + write.df(df, path, source, mode, ...) + }) + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @name saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = "character", source = "character", + mode = "character"), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + # nolint start + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + # nolint end + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + +#' describe +#' +#' Computes statistics for numeric columns. +#' If no columns are given, this function computes statistics for all numerical columns. +#' +#' @param x A DataFrame to be computed. +#' @param col A string of name +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname describe +#' @name describe +#' @aliases summary +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' describe(df) +#' describe(df, "col1") +#' describe(df, "col1", "col2") +#' } +setMethod("describe", + signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + colList <- list(col, ...) + sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + dataFrame(sdf) + }) + +#' @rdname describe +#' @name describe +setMethod("describe", + signature(x = "DataFrame"), + function(x) { + colList <- as.list(c(columns(x))) + sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + dataFrame(sdf) + }) + +#' @title Summary +#' +#' @description Computes statistics for numeric columns of the DataFrame +#' +#' @rdname summary +#' @name summary +setMethod("summary", + signature(x = "DataFrame"), + function(x) { + describe(x) + }) + + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @name dropna +#' @aliases na.omit +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @rdname nafunctions +#' @name na.omit +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @name fillna +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have zero as their counts. +#' +#' @rdname statfunctions +#' @name crosstab +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlCtx, "/path/to/file.json") +#' ct = crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 000000000000..051e441d4e06 --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1644 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +# @title S4 class that represents an RDD +# @description RDD can be created using functions like +# \code{parallelize}, \code{textFile} etc. +# @rdname RDD +# @seealso parallelize, textFile +# +# @slot env An R environment that stores bookkeeping states of the RDD +# @slot jrdd Java object reference to the backing JavaRDD +# to an RDD +# @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. + .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("show", "RDD", + function(object) { + cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep="")) + }) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + # nolint start + !(e$isCached || e$isCheckpointed) + # nolint end + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- cleanClosure(func) + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(partIndex, part) { + f <- prev@func + func(partIndex, f(partIndex, part)) + } + .Object@func <- cleanClosure(pipelinedFunc) + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +# @rdname RDD +# @export +# +# @param jrdd Java object reference to the backing JavaRDD +# @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +# stores strings, and "row" if the RDD stores the rows of a DataFrame +# @param isCached TRUE if the RDD is cached +# @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(rdd@func, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +# Persist an RDD +# +# Persist this RDD with the default storage level (MEMORY_ONLY). +# +# @param x The RDD to cache +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 2L) +# cache(rdd) +#} +# @rdname cache-methods +# @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +# Persist an RDD +# +# Persist this RDD with the specified storage level. For details of the +# supported storage levels, refer to +# http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +# +# @param x The RDD to persist +# @param newLevel The new storage level to be assigned +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 2L) +# persist(rdd, "MEMORY_AND_DISK") +#} +# @rdname persist +# @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel = "MEMORY_ONLY") { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +# Unpersist an RDD +# +# Mark the RDD as non-persistent, and remove all blocks for it from memory and +# disk. +# +# @param x The RDD to unpersist +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 2L) +# cache(rdd) # rdd@@env$isCached == TRUE +# unpersist(rdd) # rdd@@env$isCached == FALSE +#} +# @rdname unpersist-methods +# @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +# Checkpoint an RDD +# +# Mark this RDD for checkpointing. It will be saved to a file inside the +# checkpoint directory set with setCheckpointDir() and all references to its +# parent RDDs will be removed. This function must be called before any job has +# been executed on this RDD. It is strongly recommended that this RDD is +# persisted in memory, otherwise saving it on a file will require recomputation. +# +# @param x The RDD to checkpoint +# @examples +#\dontrun{ +# sc <- sparkR.init() +# setCheckpointDir(sc, "checkpoint") +# rdd <- parallelize(sc, 1:10, 2L) +# checkpoint(rdd) +#} +# @rdname checkpoint-methods +# @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +# Gets the number of partitions of an RDD +# +# @param x A RDD. +# @return the number of partitions of rdd as an integer. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 2L) +# numPartitions(rdd) # 2L +#} +# @rdname numPartitions +# @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "partitions") + callJMethod(partitions, "size") + }) + +# Collect elements of an RDD +# +# @description +# \code{collect} returns a list that contains all of the elements in this RDD. +# +# @param x The RDD to collect +# @param ... Other optional arguments to collect +# @param flatten FALSE if the list should not flattened +# @return a list containing elements in the RDD +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 2L) +# collect(rdd) # list from 1 to 10 +# collectPartition(rdd, 0L) # list from 1 to 5 +#} +# @rdname collect-methods +# @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +# @description +# \code{collectPartition} returns a list that contains all of the elements +# in the specified partition of the RDD. +# @param partitionId the partition to collect (starts from 0) +# @rdname collect-methods +# @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +# @description +# \code{collectAsMap} returns a named list as a map that contains all of the elements +# in a key-value pair RDD. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +# collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#} +# @rdname collect-methods +# @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +# Return the number of elements in the RDD. +# +# @param x The RDD to count +# @return number of elements in the RDD. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# count(rdd) # 10 +# length(rdd) # Same as count +#} +# @rdname count +# @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +# Return the number of elements in the RDD +# @export +# @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +# Return the count of each unique value in this RDD as a list of +# (value, count) pairs. +# +# Same as countByValue in Spark. +# +# @param x The RDD to count +# @return list of (value, count) pairs, where count is number of each unique +# value in rdd. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, c(1,2,3,2,1)) +# countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#} +# @rdname countByValue +# @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +# Apply a function to all elements +# +# This function creates a new RDD by applying the given transformation to all +# elements of the given RDD +# +# @param X The RDD to apply the transformation. +# @param FUN the transformation to apply on each element +# @return a new RDD created by the transformation. +# @rdname lapply +# @aliases lapply +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +# collect(multiplyByTwo) # 2,4,6... +#} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(partIndex, part) { + lapply(part, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +# @rdname lapply +# @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +# Flatten results after apply a function to all elements +# +# This function return a new RDD by first applying a function to all +# elements of this RDD, and then flattening the results. +# +# @param X The RDD to apply the transformation. +# @param FUN the transformation to apply on each element +# @return a new RDD created by the transformation. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +# collect(multiplyByTwo) # 2,20,4,40,6,60... +#} +# @rdname flatMap +# @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +# Apply a function to each partition of an RDD +# +# Return a new RDD by applying a function to each partition of this RDD. +# +# @param X The RDD to apply the transformation. +# @param FUN the transformation to apply on each partition. +# @return a new RDD created by the transformation. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +# collect(partitionSum) # 15, 40 +#} +# @rdname lapplyPartition +# @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +# mapPartitions is the same as lapplyPartition. +# +# @rdname lapplyPartition +# @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +# Return a new RDD by applying a function to each partition of this RDD, while +# tracking the index of the original partition. +# +# @param X The RDD to apply the transformation. +# @param FUN the transformation to apply on each partition; takes the partition +# index and a list of elements in the particular partition. +# @return a new RDD created by the transformation. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 5L) +# prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +# partIndex * Reduce("+", part) }) +# collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#} +# @rdname lapplyPartitionsWithIndex +# @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + PipelinedRDD(X, FUN) + }) + +# @rdname lapplyPartitionsWithIndex +# @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +# This function returns a new RDD containing only the elements that satisfy +# a predicate (i.e. returning TRUE in a given logical function). +# The same as `filter()' in Spark. +# +# @param x The RDD to be filtered. +# @param f A unary predicate function. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#} +# @rdname filterRDD +# @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +# @rdname filterRDD +# @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +# Reduce across elements of an RDD. +# +# This function reduces the elements of this RDD using the +# specified commutative and associative binary operator. +# +# @param x The RDD to reduce +# @param func Commutative and associative function to apply on elements +# of the RDD. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# reduce(rdd, "+") # 55 +#} +# @rdname reduce +# @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +# Get the maximum element of an RDD. +# +# @param x The RDD to get the maximum element from +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# maximum(rdd) # 10 +#} +# @rdname maximum +# @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +# Get the minimum element of an RDD. +# +# @param x The RDD to get the minimum element from +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# minimum(rdd) # 1 +#} +# @rdname minimum +# @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +# Add up the elements in an RDD. +# +# @param x The RDD to add up the elements in +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# sumRDD(rdd) # 55 +#} +# @rdname sumRDD +# @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +# Applies a function to all elements in an RDD, and force evaluation. +# +# @param x The RDD to apply the function +# @param func The function to be applied. +# @return invisible NULL. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# foreach(rdd, function(x) { save(x, file=...) }) +#} +# @rdname foreach +# @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +# Applies a function to each partition in an RDD, and force evaluation. +# +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#} +# @rdname foreach +# @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +# Take elements from an RDD. +# +# This function takes the first NUM elements in the RDD and +# returns them in a list. +# +# @param x The RDD to take elements from +# @param num Number of elements to take +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# take(rdd, 2L) # list(1, 2) +#} +# @rdname take +# @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + serializedModeRDD <- getSerializedMode(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList + }) + + +# First +# +# Return the first element of an RDD +# +# @rdname first +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# first(rdd) +# } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +# Removes the duplicates from RDD. +# +# This function returns a new RDD containing the distinct elements in the +# given RDD. The same as `distinct()' in Spark. +# +# @param x The RDD to remove duplicates from. +# @param numPartitions Number of partitions to create. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, c(1,2,2,3,3,3)) +# sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#} +# @rdname distinct +# @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR:::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +# Return an RDD that is a sampled subset of the given RDD. +# +# The same as `sample()' in Spark. (We rename it due to signature +# inconsistencies with the `sample()' function in R's base package.) +# +# @param x The RDD to sample elements from +# @param withReplacement Sampling with replacement or not +# @param fraction The (rough) sample target fraction +# @param seed Randomness seed value +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +# collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#} +# @rdname sampleRDD +# @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(partIndex, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(partIndex) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +# Return a list of the elements that are a sampled subset of the given RDD. +# +# @param x The RDD to sample elements from +# @param withReplacement Sampling with replacement or not +# @param num Number of elements to return +# @param seed Randomness seed value +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:100) +# # exactly 5 elements sampled, which may not be distinct +# takeSample(rdd, TRUE, 5L, 1618L) +# # exactly 5 distinct elements sampled +# takeSample(rdd, FALSE, 5L, 16181618L) +#} +# @rdname takeSample +# @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + base::sample(samples)[1:total] + }) + +# Creates tuples of the elements in this RDD by applying a function. +# +# @param x The RDD. +# @param func The function to be applied. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1, 2, 3)) +# collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#} +# @rdname keyBy +# @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +# Return a new RDD that has exactly numPartitions partitions. +# Can increase or decrease the level of parallelism in this RDD. Internally, +# this uses a shuffle to redistribute data. +# If you are decreasing the number of partitions in this RDD, consider using +# coalesce, which can avoid performing a shuffle. +# +# @param x The RDD. +# @param numPartitions Number of partitions to create. +# @seealso coalesce +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +# numPartitions(rdd) # 4 +# numPartitions(repartition(rdd, 2L)) # 2 +#} +# @rdname repartition +# @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numPartitions, TRUE) + }) + +# Return a new RDD that is reduced into numPartitions partitions. +# +# @param x The RDD. +# @param numPartitions Number of partitions to create. +# @seealso repartition +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +# numPartitions(rdd) # 3 +# numPartitions(coalesce(rdd, 1L)) # 1 +#} +# @rdname coalesce +# @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR:::numPartitions(x)) { + func <- function(partIndex, part) { + set.seed(partIndex) # partIndex as seed + start <- as.integer(base::sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +# Save this RDD as a SequenceFile of serialized objects. +# +# @param x The RDD to save +# @param path The directory where the file is saved +# @seealso objectFile +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:3) +# saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#} +# @rdname saveAsObjectFile +# @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +# Save this RDD as a text file, using string representations of elements. +# +# @param x The RDD to save +# @param path The directory where the partitions of the text file are saved +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:3) +# saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#} +# @rdname saveAsTextFile +# @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +# Sort an RDD by the given key function. +# +# @param x An RDD to be sorted. +# @param func A function used to compute the sort key for each element. +# @param ascending A flag to indicate whether the sorting is ascending or descending. +# @param numPartitions Number of partitions to create. +# @return An RDD where all elements are sorted. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(3, 2, 1)) +# collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#} +# @rdname sortBy +# @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. +# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + part[ord[1:num]] + } else { + part + } + } + + newRdd <- mapPartitions(x, partitionFunc) + + resList <- list() + index <- -1 + jrdd <- getJRDD(newRdd) + numPartitions <- numPartitions(newRdd) + serializedModeRDD <- getSerializedMode(newRdd) + + while (TRUE) { + index <- index + 1 + + if (index >= numPartitions) { + ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending) + resList <- resList[ord[1:num]] + break + } + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + # elems is capped to have at most `num` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = num, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList +} + +# Returns the first N elements from an RDD in ascending order. +# +# @param x An RDD. +# @param num Number of elements to return. +# @return The first N elements from the RDD in ascending order. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +# takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#} +# @rdname takeOrdered +# @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +# Returns the top N elements from an RDD. +# +# @param x An RDD. +# @param num Number of elements to return. +# @return The top N elements from the RDD. +# @rdname top +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +# top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#} +# @rdname top +# @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +# Fold an RDD using a given associative function and a neutral "zero value". +# +# Aggregate the elements of each partition, and then the results for all the +# partitions, using a given associative function and a neutral "zero value". +# +# @param x An RDD. +# @param zeroValue A neutral "zero value". +# @param op An associative function for the folding operation. +# @return The folding result. +# @rdname fold +# @seealso reduce +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +# fold(rdd, 0, "+") # 15 +#} +# @rdname fold +# @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +# Aggregate an RDD using the given combine functions and a neutral "zero value". +# +# Aggregate the elements of each partition, and then the results for all the +# partitions, using given combine functions and a neutral "zero value". +# +# @param x An RDD. +# @param zeroValue A neutral "zero value". +# @param seqOp A function to aggregate the RDD elements. It may return a different +# result type from the type of the RDD elements. +# @param combOp A function to aggregate results of seqOp. +# @return The aggregation result. +# @rdname aggregateRDD +# @seealso reduce +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1, 2, 3, 4)) +# zeroValue <- list(0, 0) +# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +# aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#} +# @rdname aggregateRDD +# @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +# Pipes elements to a forked external process. +# +# The same as 'pipe()' in Spark. +# +# @param x The RDD whose elements are piped to the forked external process. +# @param command The command to fork an external process. +# @param env A named list to set environment variables of the external process. +# @return A new RDD created by piping all elements to a forked external process. +# @rdname pipeRDD +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# collect(pipeRDD(rdd, "more") +# Output: c("1", "2", ..., "10") +#} +# @rdname pipeRDD +# @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim_trailing_func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim_trailing_func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim_trailing_func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +# Return an RDD's name. +# +# @param x The RDD whose name is returned. +# @rdname name +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1,2,3)) +# name(rdd) # NULL (if not set before) +#} +# @rdname name +# @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +# Set an RDD's name. +# +# @param x The RDD whose name is to be set. +# @param name The RDD name to be set. +# @return a new RDD renamed. +# @rdname setName +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(1,2,3)) +# setName(rdd, "myRDD") +# name(rdd) # "myRDD" +#} +# @rdname setName +# @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +# Zip an RDD with generated unique Long IDs. +# +# Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +# n is the number of partitions. So there may exist gaps, but this +# method won't trigger a spark job, which is different from +# zipWithIndex. +# +# @param x An RDD to be zipped. +# @return An RDD with zipped items. +# @seealso zipWithIndex +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +# collect(zipWithUniqueId(rdd)) +# # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#} +# @rdname zipWithUniqueId +# @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(partIndex, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + partIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +# Zip an RDD with its element indices. +# +# The ordering is first based on the partition index and then the +# ordering of items within each partition. So the first item in +# the first partition gets index 0, and the last item in the last +# partition receives the largest index. +# +# This method needs to trigger a Spark job when this RDD contains +# more than one partition. +# +# @param x An RDD to be zipped. +# @return An RDD with zipped items. +# @seealso zipWithUniqueId +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +# collect(zipWithIndex(rdd)) +# # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#} +# @rdname zipWithIndex +# @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(partIndex, part) { + if (partIndex == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[partIndex]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +# Coalesce all elements within each partition of an RDD into a list. +# +# @param x An RDD. +# @return An RDD created by coalescing all elements within +# each partition into a list. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, as.list(1:4), 2L) +# collect(glom(rdd)) +# # list(list(1, 2), list(3, 4)) +#} +# @rdname glom +# @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +# Return the union RDD of two RDDs. +# The same as union() in Spark. +# +# @param x An RDD. +# @param y An RDD. +# @return a new RDD created by performing the simple union (witout removing +# duplicates) of two input RDDs. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:3) +# unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#} +# @rdname unionRDD +# @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. + if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +# Zip an RDD with another RDD. +# +# Zips this RDD with another one, returning key-value pairs with the +# first element in each RDD second element in each RDD, etc. Assumes +# that the two RDDs have the same number of partitions and the same +# number of elements in each partition (e.g. one was made through +# a map on the other). +# +# @param x An RDD to be zipped. +# @param other Another RDD to be zipped. +# @return An RDD zipped from the two RDDs. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, 0:4) +# rdd2 <- parallelize(sc, 1000:1004) +# collect(zipRDD(rdd1, rdd2)) +# # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#} +# @rdname zipRDD +# @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) + + mergePartitions(rdd, TRUE) + }) + +# Cartesian product of this RDD and another one. +# +# Return the Cartesian product of this RDD and another one, +# that is, the RDD of all pairs of elements (a, b) where a +# is in this and b is in other. +# +# @param x An RDD. +# @param other An RDD. +# @return A new RDD which is the Cartesian product of these two RDDs. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:2) +# sortByKey(cartesian(rdd, rdd)) +# # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#} +# @rdname cartesian +# @aliases cartesian,RDD,RDD-method +setMethod("cartesian", + signature(x = "RDD", other = "RDD"), + function(x, other) { + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) + + mergePartitions(rdd, FALSE) + }) + +# Subtract an RDD with another RDD. +# +# Return an RDD with the elements from this that are not in other. +# +# @param x An RDD. +# @param other An RDD. +# @param numPartitions Number of the partitions in the result RDD. +# @return An RDD with the elements from this that are not in other. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +# rdd2 <- parallelize(sc, list(2, 4)) +# collect(subtract(rdd1, rdd2)) +# # list(1, 1, 3) +#} +# @rdname subtract +# @aliases subtract,RDD +setMethod("subtract", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR:::numPartitions(x)) { + mapFunction <- function(e) { list(e, NA) } + rdd1 <- map(x, mapFunction) + rdd2 <- map(other, mapFunction) + keys(subtractByKey(rdd1, rdd2, numPartitions)) + }) + +# Intersection of this RDD and another one. +# +# Return the intersection of this RDD and another one. +# The output will not contain any duplicate elements, +# even if the input RDDs did. Performs a hash partition +# across the cluster. +# Note that this method performs a shuffle internally. +# +# @param x An RDD. +# @param other An RDD. +# @param numPartitions The number of partitions in the result RDD. +# @return An RDD which is the intersection of these two RDDs. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +# rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +# collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +# # list(1, 2, 3) +#} +# @rdname intersection +# @aliases intersection,RDD +setMethod("intersection", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR:::numPartitions(x)) { + rdd1 <- map(x, function(v) { list(v, NA) }) + rdd2 <- map(other, function(v) { list(v, NA) }) + + filterFunction <- function(elem) { + iters <- elem[[2]] + all(as.vector( + lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical")) + } + + keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) + }) + +# Zips an RDD's partitions with one (or more) RDD(s). +# Same as zipPartitions in Spark. +# +# @param ... RDDs to be zipped. +# @param func A function to transform zipped partitions. +# @return A new RDD by applying a function to the zipped partitions. +# Assumes that all the RDDs have the *same number of partitions*, but +# does *not* require them to have the same number of elements in each partition. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +# rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +# rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +# collect(zipPartitions(rdd1, rdd2, rdd3, +# func = function(x, y, z) { list(list(x, y, z))} )) +# # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#} +# @rdname zipRDD +# @aliases zipPartitions,RDD +setMethod("zipPartitions", + "RDD", + function(..., func) { + rrdds <- list(...) + if (length(rrdds) == 1) { + return(rrdds[[1]]) + } + nPart <- sapply(rrdds, numPartitions) + if (length(unique(nPart)) != 1) { + stop("Can only zipPartitions RDDs which have the same number of partitions.") + } + + rrdds <- lapply(rrdds, function(rdd) { + mapPartitionsWithIndex(rdd, function(partIndex, part) { + print(length(part)) + list(list(partIndex, part)) + }) + }) + union.rdd <- Reduce(unionRDD, rrdds) + zipped.rdd <- values(groupByKey(union.rdd, numPartitions = nPart[1])) + res <- mapPartitions(zipped.rdd, function(plist) { + do.call(func, plist[[1]]) + }) + res + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 000000000000..110117a18ccb --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,513 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + structField(names[[i]], types[[i]], TRUE) + }) + do.call(structType, fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param sqlContext A SQLContext +#' @param data An RDD or list or data.frame +#' @param schema a list of column names or named list (StructType), optional +#' @return an DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- createDataFrame(sqlContext, rdd) +#' } + +# TODO(davies): support sampling and infer type from NA +createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { + if (is.data.frame(data)) { + # get the names of columns, they will be put into RDD + if (is.null(schema)) { + schema <- names(data) + } + n <- nrow(data) + m <- ncol(data) + # get rid of factor type + dropFactor <- function(x) { + if (is.factor(x)) { + as.character(x) + } else { + x + } + } + data <- lapply(1:n, function(i) { + lapply(1:m, function(j) { dropFactor(data[i,j]) }) + }) + } + if (is.list(data)) { + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) + rdd <- parallelize(sc, data) + } else if (inherits(data, "RDD")) { + rdd <- data + } else { + stop(paste("unexpected type:", class(data))) + } + + if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { + row <- first(rdd) + names <- if (is.null(schema)) { + names(row) + } else { + as.list(schema) + } + if (is.null(names)) { + names <- lapply(1:length(row), function(x) { + paste("_", as.character(x), sep = "") + }) + } + + # SPAKR-SQL does not support '.' in column name, so replace it with '_' + # TODO(davies): remove this once SPARK-2775 is fixed + names <- lapply(names, function(n) { + nn <- gsub("[.]", "_", n) + if (nn != n) { + warning(paste("Use", nn, "instead of", n, " as column name")) + } + nn + }) + + types <- lapply(row, infer_type) + fields <- lapply(1:length(row), function(i) { + structField(names[[i]], types[[i]], TRUE) + }) + schema <- do.call(structType, fields) + } + + stopifnot(class(schema) == "structType") + # schemaString <- tojson(schema) + + jrdd <- getJRDD(lapply(rdd, function(x) x), "row") + srdd <- callJMethod(jrdd, "rdd") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", + srdd, schema$jobj, sqlContext) + dataFrame(sdf) +} + +# toDF +# +# Converts an RDD to a DataFrame by infer the types. +# +# @param x An RDD +# +# @rdname DataFrame +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# sqlContext <- sparkRSQL.init(sc) +# rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +# df <- toDF(rdd) +# } + +setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) + +setMethod("toDF", signature(x = "RDD"), + function(x, ...) { + sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("no SQL context available") + } + createDataFrame(sqlContext, x, ...) + }) + +#' Create a DataFrame from a JSON file. +#' +#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' It goes through the entire dataset once to determine the schema. +#' +#' @param sqlContext SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' } + +jsonFile <- function(sqlContext, path) { + # Allow the user to have a more flexible definiton of the text file path + path <- normalizePath(path) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + sdf <- callJMethod(sqlContext, "jsonFile", path) + dataFrame(sdf) +} + + +# JSON RDD +# +# Loads an RDD storing one JSON object per string as a DataFrame. +# +# @param sqlContext SQLContext to use +# @param rdd An RDD of JSON string +# @param schema A StructType object to use as schema +# @param samplingRatio The ratio of simpling used to infer the schema +# @return A DataFrame +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# sqlContext <- sparkRSQL.init(sc) +# rdd <- texFile(sc, "path/to/json") +# df <- jsonRDD(sqlContext, rdd) +# } + +# TODO: support schema +jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { + rdd <- serializeToString(rdd) + if (is.null(schema)) { + sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + dataFrame(sdf) + } else { + stop("not implemented") + } +} + + +#' Create a DataFrame from a Parquet file. +#' +#' Loads a Parquet file, returning the result as a DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param ... Path(s) of parquet file(s) to read. +#' @return DataFrame +#' @export + +# TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(sqlContext, ...) { + # Allow the user to have a more flexible definiton of the text file path + paths <- lapply(list(...), normalizePath) + sdf <- callJMethod(sqlContext, "parquetFile", paths) + dataFrame(sdf) +} + +#' SQL Query +#' +#' Executes a SQL query using Spark, returning the result as a DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param sqlQuery A character vector containing the SQL query +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' registerTempTable(df, "table") +#' new_df <- sql(sqlContext, "SELECT * FROM table") +#' } + +sql <- function(sqlContext, sqlQuery) { + sdf <- callJMethod(sqlContext, "sql", sqlQuery) + dataFrame(sdf) +} + +#' Create a DataFrame from a SparkSQL Table +#' +#' Returns the specified Table as a DataFrame. The Table must have already been registered +#' in the SQLContext. +#' +#' @param sqlContext SQLContext to use +#' @param tableName The SparkSQL Table to convert to a DataFrame. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' registerTempTable(df, "table") +#' new_df <- table(sqlContext, "table") +#' } + +table <- function(sqlContext, tableName) { + sdf <- callJMethod(sqlContext, "table", tableName) + dataFrame(sdf) +} + + +#' Tables +#' +#' Returns a DataFrame containing names of tables in the given database. +#' +#' @param sqlContext SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' tables(sqlContext, "hive") +#' } + +tables <- function(sqlContext, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlContext, "tables") + } else { + callJMethod(sqlContext, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlContext SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' tableNames(sqlContext, "hive") +#' } + +tableNames <- function(sqlContext, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlContext, "tableNames") + } else { + callJMethod(sqlContext, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlContext SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlContext, "table") +#' } + +cacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlContext SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlContext, "table") +#' } + +uncacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlContext SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlContext) +#' } + +clearCache <- function(sqlContext) { + callJMethod(sqlContext, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlContext SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlContext, "table") +#' } + +dropTempTable <- function(sqlContext, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlContext, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlContext SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, "path/to/file.json", source = "json") +#' } + +read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + schema$jobj, options) + } else { + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + } + dataFrame(sdf) +} + +#' @aliases loadDF +#' @export + +loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + read.df(sqlContext, path, source, schema, ...) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlContext SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 000000000000..49162838b8d1 --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. +# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + if (returnStatus != 0) { + stop(readString(conn)) + } + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 000000000000..2403925b267c --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +# @title S4 class that represents a Broadcast variable +# @description Broadcast variables can be created using the broadcast +# function from a \code{SparkContext}. +# @rdname broadcast-class +# @seealso broadcast +# +# @param id Id of the backing Spark broadcast variable +# @export +setClass("Broadcast", slots = list(id = "character")) + +# @rdname broadcast-class +# @param value Value of the broadcast variable +# @param jBroadcastRef reference to the backing Java broadcast object +# @param objName name of broadcasted object +# @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +# @description +# \code{value} can be used to get the value of a broadcast variable inside +# a distributed function. +# +# @param bcast The broadcast variable to get +# @rdname broadcast +# @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +# Internal function to set values of a broadcast variable. +# +# This function is used internally by Spark to set the value of a broadcast +# variable on workers. Not intended for use outside the package. +# +# @rdname broadcast-internal +# @seealso broadcast, value + +# @param bcastId The id of broadcast variable to set +# @param value The value to be set +# @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +# Helper function to clear the list of broadcast variables we know about +# Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 000000000000..c811d1dac3bd --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +determineSparkSubmitBin <- function() { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName <- "spark-submit" + } else { + sparkSubmitBinName <- "spark-submit.cmd" + } + sparkSubmitBinName +} + +generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + if (jars != "") { + jars <- paste("--jars", jars) + } + + if (!identical(packages, "")) { + packages <- paste("--packages", packages) + } + + combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") + combinedArgs +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + sparkSubmitBinName <- determineSparkSubmitBin() + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 000000000000..4805096f3f9c --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,234 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R schema.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns +#' @rdname column +#' +#' @slot jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +#' @name show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or", #, "!" = "unary_$bang" + "^" = "pow" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + if (op == "^") { + jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2) + } else { + callJMethod(e1@jc, operators[[op]], e2) + } + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +#' +#' @rdname alias +#' @name alias +#' @family colum_func +#' @export +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' substr +#' +#' An expression that returns a substring. +#' +#' @rdname substr +#' @name substr +#' @family colum_func +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname between +#' @name between +#' @family colum_func +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + +#' Casts the column to a different data type. +#' +#' @rdname cast +#' @name cast +#' @family colum_func +#' +#' @examples \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Match a column with given values. +#' +#' @rdname match +#' @name %in% +#' @aliases %in% +#' @return a matched values as a result of comparing with given values. +#' @export +#' @examples +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) +#' } +setMethod("%in%", + signature(x = "Column"), + function(x, table) { + table <- listToSeq(as.list(table)) + jc <- callJMethod(x@jc, "in", table) + return(column(jc)) + }) + +#' otherwise +#' +#' If values in the specified column are null, returns the value. +#' Can be used in conjunction with `when` to specify a default value for expressions. +#' +#' @rdname otherwise +#' @name otherwise +#' @family colum_func +#' @export +setMethod("otherwise", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJMethod(x@jc, "otherwise", value) + column(jc) + }) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 000000000000..720990e1c608 --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinPartitions <- function(sc, minPartitions) { + if (is.null(minPartitions)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minPartitions <- min(defaultParallelism, 2) + } + as.integer(minPartitions) +} + +# Create an RDD from a text file. +# +# This function reads a text file from HDFS, a local file system (available on all +# nodes), or any Hadoop-supported file system URI, and creates an +# RDD of strings from it. +# +# @param sc SparkContext to use +# @param path Path of file to read. A vector of multiple paths is allowed. +# @param minPartitions Minimum number of partitions to be created. If NULL, the default +# value is chosen based on available parallelism. +# @return RDD where each item is of type \code{character} +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# lines <- textFile(sc, "myfile.txt") +#} +textFile <- function(sc, path, minPartitions = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinPartitions(sc, minPartitions)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +# Load an RDD saved as a SequenceFile containing serialized objects. +# +# The file to be loaded should be one that was previously generated by calling +# saveAsObjectFile() of the RDD class. +# +# @param sc SparkContext to use +# @param path Path of file to read. A vector of multiple paths is allowed. +# @param minPartitions Minimum number of partitions to be created. If NULL, the default +# value is chosen based on available parallelism. +# @return RDD containing serialized R objects. +# @seealso saveAsObjectFile +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- objectFile(sc, "myfile") +#} +objectFile <- function(sc, path, minPartitions = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinPartitions(sc, minPartitions)) + # Assume the RDD contains serialized R objects. + RDD(jrdd, "byte") +} + +# Create an RDD from a homogeneous list or vector. +# +# This function creates an RDD from a local homogeneous list in R. The elements +# in the list are split into \code{numSlices} slices and distributed to nodes +# in the cluster. +# +# @param sc SparkContext to use +# @param coll collection to parallelize +# @param numSlices number of partitions to create in the RDD +# @return an RDD created from this collection +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10, 2) +# # The RDD should contain 10 elements +# length(rdd) +#} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +# Include this specified package on all workers +# +# This function can be used to include a package on all workers before the +# user's code is executed. This is useful in scenarios where other R package +# functions are used in a function passed to functions like \code{lapply}. +# NOTE: The package is assumed to be installed on every node in the Spark +# cluster. +# +# @param sc SparkContext to use +# @param pkg Package name +# +# @export +# @examples +#\dontrun{ +# library(Matrix) +# +# sc <- sparkR.init() +# # Include the matrix library we will be using +# includePackage(sc, Matrix) +# +# generateSparse <- function(x) { +# sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +# } +# +# rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +# collect(rdd) +#} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +# @title Broadcast a variable to all workers +# +# @description +# Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +# object for reading it in distributed functions. +# +# @param sc Spark Context to use +# @param object Object to be broadcast +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:2, 2L) +# +# # Large Matrix object that we want to broadcast +# randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +# randomMatBr <- broadcast(sc, randomMat) +# +# # Use the broadcast variable inside the function +# useBroadcast <- function(x) { +# sum(value(randomMatBr) * x) +# } +# sumRDD <- lapply(rdd, useBroadcast) +#} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +# @title Set the checkpoint directory +# +# Set the directory under which RDDs are going to be checkpointed. The +# directory must be a HDFS path if running on a cluster. +# +# @param sc Spark Context to use +# @param dirName Directory path +# @export +# @examples +#\dontrun{ +# sc <- sparkR.init() +# setCheckpointDir(sc, "~/checkpoint") +# rdd <- parallelize(sc, 1:2, 2L) +# checkpoint(rdd) +#} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 000000000000..6cf628e3007d --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,179 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. + +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Float -> double +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "a" = readArray(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +readArray <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +# Read a list. Types of each element may be different. +# Null objects are read as NA. +readList <- function(con) { + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + elem <- readObject(con) + if (is.null(elem)) { + elem <- NA + } + l[[i]] <- elem + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readMultipleObjects <- function(inputCon) { + # readMultipleObjects will read multiple continuous objects from + # a DataOutputStream. There is no preceding field telling the count + # of the objects, so the number of objects varies, we try to read + # all objects in a loop until the end of the stream. + data <- list() + while(TRUE) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { + break + } + data[[length(data) + 1L]] <- readTypedObject(inputCon, type) + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. + rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readObject(rawObj) +} diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 000000000000..d848730e7043 --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,1988 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#' @include generics.R column.R +NULL + +#' Creates a \code{Column} of literal value. +#' +#' The passed in object is returned directly if it is already a \linkS4class{Column}. +#' If the object is a Scala Symbol, it is converted into a \linkS4class{Column} also. +#' Otherwise, a new \linkS4class{Column} is created to represent the literal value. +#' +#' @family normal_funcs +#' @rdname lit +#' @name lit +#' @export +setMethod("lit", signature("ANY"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lit", + ifelse(class(x) == "Column", x@jc, x)) + column(jc) + }) + +#' abs +#' +#' Computes the absolute value. +#' +#' @rdname abs +#' @name abs +#' @family normal_funcs +#' @export +#' @examples \dontrun{abs(df$c)} +setMethod("abs", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "abs", x@jc) + column(jc) + }) + +#' acos +#' +#' Computes the cosine inverse of the given value; the returned angle is in the range +#' 0.0 through pi. +#' +#' @rdname acos +#' @name acos +#' @family math_funcs +#' @export +#' @examples \dontrun{acos(df$c)} +setMethod("acos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "acos", x@jc) + column(jc) + }) + +#' approxCountDistinct +#' +#' Aggregate function: returns the approximate number of distinct items in a group. +#' +#' @rdname approxCountDistinct +#' @name approxCountDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{approxCountDistinct(df$c)} +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc) + column(jc) + }) + +#' ascii +#' +#' Computes the numeric value of the first character of the string column, and returns the +#' result as a int column. +#' +#' @rdname ascii +#' @name ascii +#' @family string_funcs +#' @export +#' @examples \dontrun{\dontrun{ascii(df$c)}} +setMethod("ascii", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ascii", x@jc) + column(jc) + }) + +#' asin +#' +#' Computes the sine inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. +#' +#' @rdname asin +#' @name asin +#' @family math_funcs +#' @export +#' @examples \dontrun{asin(df$c)} +setMethod("asin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "asin", x@jc) + column(jc) + }) + +#' atan +#' +#' Computes the tangent inverse of the given value. +#' +#' @rdname atan +#' @name atan +#' @family math_funcs +#' @export +#' @examples \dontrun{atan(df$c)} +setMethod("atan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "atan", x@jc) + column(jc) + }) + +#' avg +#' +#' Aggregate function: returns the average of the values in a group. +#' +#' @rdname avg +#' @name avg +#' @family agg_funcs +#' @export +#' @examples \dontrun{avg(df$c)} +setMethod("avg", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "avg", x@jc) + column(jc) + }) + +#' base64 +#' +#' Computes the BASE64 encoding of a binary column and returns it as a string column. +#' This is the reverse of unbase64. +#' +#' @rdname base64 +#' @name base64 +#' @family string_funcs +#' @export +#' @examples \dontrun{base64(df$c)} +setMethod("base64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "base64", x@jc) + column(jc) + }) + +#' bin +#' +#' An expression that returns the string representation of the binary value of the given long +#' column. For example, bin("12") returns "1100". +#' +#' @rdname bin +#' @name bin +#' @family math_funcs +#' @export +#' @examples \dontrun{bin(df$c)} +setMethod("bin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bin", x@jc) + column(jc) + }) + +#' bitwiseNOT +#' +#' Computes bitwise NOT. +#' +#' @rdname bitwiseNOT +#' @name bitwiseNOT +#' @family normal_funcs +#' @export +#' @examples \dontrun{bitwiseNOT(df$c)} +setMethod("bitwiseNOT", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bitwiseNOT", x@jc) + column(jc) + }) + +#' cbrt +#' +#' Computes the cube-root of the given value. +#' +#' @rdname cbrt +#' @name cbrt +#' @family math_funcs +#' @export +#' @examples \dontrun{cbrt(df$c)} +setMethod("cbrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cbrt", x@jc) + column(jc) + }) + +#' ceil +#' +#' Computes the ceiling of the given value. +#' +#' @rdname ceil +#' @name ceil +#' @family math_funcs +#' @export +#' @examples \dontrun{ceil(df$c)} +setMethod("ceil", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ceil", x@jc) + column(jc) + }) + +#' cos +#' +#' Computes the cosine of the given value. +#' +#' @rdname cos +#' @name cos +#' @family math_funcs +#' @export +#' @examples \dontrun{cos(df$c)} +setMethod("cos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cos", x@jc) + column(jc) + }) + +#' cosh +#' +#' Computes the hyperbolic cosine of the given value. +#' +#' @rdname cosh +#' @name cosh +#' @family math_funcs +#' @export +#' @examples \dontrun{cosh(df$c)} +setMethod("cosh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cosh", x@jc) + column(jc) + }) + +#' count +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @rdname count +#' @name count +#' @family agg_funcs +#' @export +#' @examples \dontrun{count(df$c)} +setMethod("count", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "count", x@jc) + column(jc) + }) + +#' crc32 +#' +#' Calculates the cyclic redundancy check value (CRC32) of a binary column and +#' returns the value as a bigint. +#' +#' @rdname crc32 +#' @name crc32 +#' @family misc_funcs +#' @export +#' @examples \dontrun{crc32(df$c)} +setMethod("crc32", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "crc32", x@jc) + column(jc) + }) + +#' dayofmonth +#' +#' Extracts the day of the month as an integer from a given date/timestamp/string. +#' +#' @rdname dayofmonth +#' @name dayofmonth +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofmonth(df$c)} +setMethod("dayofmonth", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofmonth", x@jc) + column(jc) + }) + +#' dayofyear +#' +#' Extracts the day of the year as an integer from a given date/timestamp/string. +#' +#' @rdname dayofyear +#' @name dayofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofyear(df$c)} +setMethod("dayofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofyear", x@jc) + column(jc) + }) + +#' exp +#' +#' Computes the exponential of the given value. +#' +#' @rdname exp +#' @name exp +#' @family math_funcs +#' @export +#' @examples \dontrun{exp(df$c)} +setMethod("exp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "exp", x@jc) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' expm1 +#' +#' Computes the exponential of the given value minus one. +#' +#' @rdname expm1 +#' @name expm1 +#' @family math_funcs +#' @export +#' @examples \dontrun{expm1(df$c)} +setMethod("expm1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expm1", x@jc) + column(jc) + }) + +#' factorial +#' +#' Computes the factorial of the given value. +#' +#' @rdname factorial +#' @name factorial +#' @family math_funcs +#' @export +#' @examples \dontrun{factorial(df$c)} +setMethod("factorial", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "factorial", x@jc) + column(jc) + }) + +#' first +#' +#' Aggregate function: returns the first value in a group. +#' +#' @rdname first +#' @name first +#' @family agg_funcs +#' @export +#' @examples \dontrun{first(df$c)} +setMethod("first", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + column(jc) + }) + +#' floor +#' +#' Computes the floor of the given value. +#' +#' @rdname floor +#' @name floor +#' @family math_funcs +#' @export +#' @examples \dontrun{floor(df$c)} +setMethod("floor", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "floor", x@jc) + column(jc) + }) + +#' hex +#' +#' Computes hex value of the given column. +#' +#' @rdname hex +#' @name hex +#' @family math_funcs +#' @export +#' @examples \dontrun{hex(df$c)} +setMethod("hex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hex", x@jc) + column(jc) + }) + +#' hour +#' +#' Extracts the hours as an integer from a given date/timestamp/string. +#' +#' @rdname hour +#' @name hour +#' @family datetime_funcs +#' @export +#' @examples \dontrun{hour(df$c)} +setMethod("hour", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hour", x@jc) + column(jc) + }) + +#' initcap +#' +#' Returns a new string column by converting the first letter of each word to uppercase. +#' Words are delimited by whitespace. +#' +#' For example, "hello world" will become "Hello World". +#' +#' @rdname initcap +#' @name initcap +#' @family string_funcs +#' @export +#' @examples \dontrun{initcap(df$c)} +setMethod("initcap", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "initcap", x@jc) + column(jc) + }) + +#' isNaN +#' +#' Return true iff the column is NaN. +#' +#' @rdname isNaN +#' @name isNaN +#' @family normal_funcs +#' @export +#' @examples \dontrun{isNaN(df$c)} +setMethod("isNaN", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + column(jc) + }) + +#' last +#' +#' Aggregate function: returns the last value in a group. +#' +#' @rdname last +#' @name last +#' @family agg_funcs +#' @export +#' @examples \dontrun{last(df$c)} +setMethod("last", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + column(jc) + }) + +#' last_day +#' +#' Given a date column, returns the last day of the month which the given date belongs to. +#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the +#' month in July 2015. +#' +#' @rdname last_day +#' @name last_day +#' @family datetime_funcs +#' @export +#' @examples \dontrun{last_day(df$c)} +setMethod("last_day", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last_day", x@jc) + column(jc) + }) + +#' length +#' +#' Computes the length of a given string or binary column. +#' +#' @rdname length +#' @name length +#' @family string_funcs +#' @export +#' @examples \dontrun{length(df$c)} +setMethod("length", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "length", x@jc) + column(jc) + }) + +#' log +#' +#' Computes the natural logarithm of the given value. +#' +#' @rdname log +#' @name log +#' @family math_funcs +#' @export +#' @examples \dontrun{log(df$c)} +setMethod("log", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log", x@jc) + column(jc) + }) + +#' log10 +#' +#' Computes the logarithm of the given value in base 10. +#' +#' @rdname log10 +#' @name log10 +#' @family math_funcs +#' @export +#' @examples \dontrun{log10(df$c)} +setMethod("log10", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log10", x@jc) + column(jc) + }) + +#' log1p +#' +#' Computes the natural logarithm of the given value plus one. +#' +#' @rdname log1p +#' @name log1p +#' @family math_funcs +#' @export +#' @examples \dontrun{log1p(df$c)} +setMethod("log1p", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log1p", x@jc) + column(jc) + }) + +#' log2 +#' +#' Computes the logarithm of the given column in base 2. +#' +#' @rdname log2 +#' @name log2 +#' @family math_funcs +#' @export +#' @examples \dontrun{log2(df$c)} +setMethod("log2", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log2", x@jc) + column(jc) + }) + +#' lower +#' +#' Converts a string column to lower case. +#' +#' @rdname lower +#' @name lower +#' @family string_funcs +#' @export +#' @examples \dontrun{lower(df$c)} +setMethod("lower", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "lower", x@jc) + column(jc) + }) + +#' ltrim +#' +#' Trim the spaces from left end for the specified string value. +#' +#' @rdname ltrim +#' @name ltrim +#' @family string_funcs +#' @export +#' @examples \dontrun{ltrim(df$c)} +setMethod("ltrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ltrim", x@jc) + column(jc) + }) + +#' max +#' +#' Aggregate function: returns the maximum value of the expression in a group. +#' +#' @rdname max +#' @name max +#' @family agg_funcs +#' @export +#' @examples \dontrun{max(df$c)} +setMethod("max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "max", x@jc) + column(jc) + }) + +#' md5 +#' +#' Calculates the MD5 digest of a binary column and returns the value +#' as a 32 character hex string. +#' +#' @rdname md5 +#' @name md5 +#' @family misc_funcs +#' @export +#' @examples \dontrun{md5(df$c)} +setMethod("md5", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "md5", x@jc) + column(jc) + }) + +#' mean +#' +#' Aggregate function: returns the average of the values in a group. +#' Alias for avg. +#' +#' @rdname mean +#' @name mean +#' @family agg_funcs +#' @export +#' @examples \dontrun{mean(df$c)} +setMethod("mean", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "mean", x@jc) + column(jc) + }) + +#' min +#' +#' Aggregate function: returns the minimum value of the expression in a group. +#' +#' @rdname min +#' @name min +#' @family agg_funcs +#' @export +#' @examples \dontrun{min(df$c)} +setMethod("min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "min", x@jc) + column(jc) + }) + +#' minute +#' +#' Extracts the minutes as an integer from a given date/timestamp/string. +#' +#' @rdname minute +#' @name minute +#' @family datetime_funcs +#' @export +#' @examples \dontrun{minute(df$c)} +setMethod("minute", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "minute", x@jc) + column(jc) + }) + +#' month +#' +#' Extracts the month as an integer from a given date/timestamp/string. +#' +#' @rdname month +#' @name month +#' @family datetime_funcs +#' @export +#' @examples \dontrun{month(df$c)} +setMethod("month", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "month", x@jc) + column(jc) + }) + +#' negate +#' +#' Unary minus, i.e. negate the expression. +#' +#' @rdname negate +#' @name negate +#' @family normal_funcs +#' @export +#' @examples \dontrun{negate(df$c)} +setMethod("negate", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "negate", x@jc) + column(jc) + }) + +#' quarter +#' +#' Extracts the quarter as an integer from a given date/timestamp/string. +#' +#' @rdname quarter +#' @name quarter +#' @family datetime_funcs +#' @export +#' @examples \dontrun{quarter(df$c)} +setMethod("quarter", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "quarter", x@jc) + column(jc) + }) + +#' reverse +#' +#' Reverses the string column and returns it as a new string column. +#' +#' @rdname reverse +#' @name reverse +#' @family string_funcs +#' @export +#' @examples \dontrun{reverse(df$c)} +setMethod("reverse", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "reverse", x@jc) + column(jc) + }) + +#' rint +#' +#' Returns the double value that is closest in value to the argument and +#' is equal to a mathematical integer. +#' +#' @rdname rint +#' @name rint +#' @family math_funcs +#' @export +#' @examples \dontrun{rint(df$c)} +setMethod("rint", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rint", x@jc) + column(jc) + }) + +#' round +#' +#' Returns the value of the column `e` rounded to 0 decimal places. +#' +#' @rdname round +#' @name round +#' @family math_funcs +#' @export +#' @examples \dontrun{round(df$c)} +setMethod("round", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "round", x@jc) + column(jc) + }) + +#' rtrim +#' +#' Trim the spaces from right end for the specified string value. +#' +#' @rdname rtrim +#' @name rtrim +#' @family string_funcs +#' @export +#' @examples \dontrun{rtrim(df$c)} +setMethod("rtrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rtrim", x@jc) + column(jc) + }) + +#' second +#' +#' Extracts the seconds as an integer from a given date/timestamp/string. +#' +#' @rdname second +#' @name second +#' @family datetime_funcs +#' @export +#' @examples \dontrun{second(df$c)} +setMethod("second", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "second", x@jc) + column(jc) + }) + +#' sha1 +#' +#' Calculates the SHA-1 digest of a binary column and returns the value +#' as a 40 character hex string. +#' +#' @rdname sha1 +#' @name sha1 +#' @family misc_funcs +#' @export +#' @examples \dontrun{sha1(df$c)} +setMethod("sha1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha1", x@jc) + column(jc) + }) + +#' signum +#' +#' Computes the signum of the given value. +#' +#' @rdname signum +#' @name signum +#' @family math_funcs +#' @export +#' @examples \dontrun{signum(df$c)} +setMethod("signum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "signum", x@jc) + column(jc) + }) + +#' sin +#' +#' Computes the sine of the given value. +#' +#' @rdname sin +#' @name sin +#' @family math_funcs +#' @export +#' @examples \dontrun{sin(df$c)} +setMethod("sin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sin", x@jc) + column(jc) + }) + +#' sinh +#' +#' Computes the hyperbolic sine of the given value. +#' +#' @rdname sinh +#' @name sinh +#' @family math_funcs +#' @export +#' @examples \dontrun{sinh(df$c)} +setMethod("sinh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sinh", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' soundex +#' +#' Return the soundex code for the specified expression. +#' +#' @rdname soundex +#' @name soundex +#' @family string_funcs +#' @export +#' @examples \dontrun{soundex(df$c)} +setMethod("soundex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "soundex", x@jc) + column(jc) + }) + +#' sqrt +#' +#' Computes the square root of the specified float value. +#' +#' @rdname sqrt +#' @name sqrt +#' @family math_funcs +#' @export +#' @examples \dontrun{sqrt(df$c)} +setMethod("sqrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sqrt", x@jc) + column(jc) + }) + +#' sum +#' +#' Aggregate function: returns the sum of all values in the expression. +#' +#' @rdname sum +#' @name sum +#' @family agg_funcs +#' @export +#' @examples \dontrun{sum(df$c)} +setMethod("sum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sum", x@jc) + column(jc) + }) + +#' sumDistinct +#' +#' Aggregate function: returns the sum of distinct values in the expression. +#' +#' @rdname sumDistinct +#' @name sumDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{sumDistinct(df$c)} +setMethod("sumDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sumDistinct", x@jc) + column(jc) + }) + +#' tan +#' +#' Computes the tangent of the given value. +#' +#' @rdname tan +#' @name tan +#' @family math_funcs +#' @export +#' @examples \dontrun{tan(df$c)} +setMethod("tan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tan", x@jc) + column(jc) + }) + +#' tanh +#' +#' Computes the hyperbolic tangent of the given value. +#' +#' @rdname tanh +#' @name tanh +#' @family math_funcs +#' @export +#' @examples \dontrun{tanh(df$c)} +setMethod("tanh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tanh", x@jc) + column(jc) + }) + +#' toDegrees +#' +#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. +#' +#' @rdname toDegrees +#' @name toDegrees +#' @family math_funcs +#' @export +#' @examples \dontrun{toDegrees(df$c)} +setMethod("toDegrees", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toDegrees", x@jc) + column(jc) + }) + +#' toRadians +#' +#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. +#' +#' @rdname toRadians +#' @name toRadians +#' @family math_funcs +#' @export +#' @examples \dontrun{toRadians(df$c)} +setMethod("toRadians", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toRadians", x@jc) + column(jc) + }) + +#' to_date +#' +#' Converts the column into DateType. +#' +#' @rdname to_date +#' @name to_date +#' @family datetime_funcs +#' @export +#' @examples \dontrun{to_date(df$c)} +setMethod("to_date", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_date", x@jc) + column(jc) + }) + +#' trim +#' +#' Trim the spaces from both ends for the specified string column. +#' +#' @rdname trim +#' @name trim +#' @family string_funcs +#' @export +#' @examples \dontrun{trim(df$c)} +setMethod("trim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "trim", x@jc) + column(jc) + }) + +#' unbase64 +#' +#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' This is the reverse of base64. +#' +#' @rdname unbase64 +#' @name unbase64 +#' @family string_funcs +#' @export +#' @examples \dontrun{unbase64(df$c)} +setMethod("unbase64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unbase64", x@jc) + column(jc) + }) + +#' unhex +#' +#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' and converts to the byte representation of number. +#' +#' @rdname unhex +#' @name unhex +#' @family math_funcs +#' @export +#' @examples \dontrun{unhex(df$c)} +setMethod("unhex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unhex", x@jc) + column(jc) + }) + +#' upper +#' +#' Converts a string column to upper case. +#' +#' @rdname upper +#' @name upper +#' @family string_funcs +#' @export +#' @examples \dontrun{upper(df$c)} +setMethod("upper", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "upper", x@jc) + column(jc) + }) + +#' weekofyear +#' +#' Extracts the week number as an integer from a given date/timestamp/string. +#' +#' @rdname weekofyear +#' @name weekofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{weekofyear(df$c)} +setMethod("weekofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "weekofyear", x@jc) + column(jc) + }) + +#' year +#' +#' Extracts the year as an integer from a given date/timestamp/string. +#' +#' @rdname year +#' @name year +#' @family datetime_funcs +#' @export +#' @examples \dontrun{year(df$c)} +setMethod("year", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "year", x@jc) + column(jc) + }) + +#' atan2 +#' +#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to +#' polar coordinates (r, theta). +#' +#' @rdname atan2 +#' @name atan2 +#' @family math_funcs +#' @export +#' @examples \dontrun{atan2(df$c, x)} +setMethod("atan2", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "atan2", y@jc, x) + column(jc) + }) + +#' datediff +#' +#' Returns the number of days from `start` to `end`. +#' +#' @rdname datediff +#' @name datediff +#' @family datetime_funcs +#' @export +#' @examples \dontrun{datediff(df$c, x)} +setMethod("datediff", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "datediff", y@jc, x) + column(jc) + }) + +#' hypot +#' +#' Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. +#' +#' @rdname hypot +#' @name hypot +#' @family math_funcs +#' @export +#' @examples \dontrun{hypot(df$c, x)} +setMethod("hypot", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "hypot", y@jc, x) + column(jc) + }) + +#' levenshtein +#' +#' Computes the Levenshtein distance of the two given string columns. +#' +#' @rdname levenshtein +#' @name levenshtein +#' @family string_funcs +#' @export +#' @examples \dontrun{levenshtein(df$c, x)} +setMethod("levenshtein", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "levenshtein", y@jc, x) + column(jc) + }) + +#' months_between +#' +#' Returns number of months between dates `date1` and `date2`. +#' +#' @rdname months_between +#' @name months_between +#' @family datetime_funcs +#' @export +#' @examples \dontrun{months_between(df$c, x)} +setMethod("months_between", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x) + column(jc) + }) + +#' nanvl +#' +#' Returns col1 if it is not NaN, or col2 if col1 is NaN. +#' hhBoth inputs should be floating point columns (DoubleType or FloatType). +#' +#' @rdname nanvl +#' @name nanvl +#' @family normal_funcs +#' @export +#' @examples \dontrun{nanvl(df$c, x)} +setMethod("nanvl", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "nanvl", y@jc, x) + column(jc) + }) + +#' pmod +#' +#' Returns the positive value of dividend mod divisor. +#' +#' @rdname pmod +#' @name pmod +#' @docType methods +#' @family math_funcs +#' @export +#' @examples \dontrun{pmod(df$c, x)} +setMethod("pmod", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "pmod", y@jc, x) + column(jc) + }) + + +#' Approx Count Distinct +#' +#' @family agg_funcs +#' @rdname approxCountDistinct +#' @name approxCountDistinct +#' @return the approximate number of distinct items in a group. +#' @export +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' @family agg_funcs +#' @rdname countDistinct +#' @name countDistinct +#' @return the number of distinct items in a group. +#' @export +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + + +#' concat +#' +#' Concatenates multiple input string columns together into a single string column. +#' +#' @family string_funcs +#' @rdname concat +#' @name concat +#' @export +setMethod("concat", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols)) + column(jc) + }) + +#' greatest +#' +#' Returns the greatest value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null if all parameters are null. +#' +#' @family normal_funcs +#' @rdname greatest +#' @name greatest +#' @export +setMethod("greatest", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols)) + column(jc) + }) + +#' least +#' +#' Returns the least value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' +#' @family normal_funcs +#' @rdname least +#' @name least +#' @export +setMethod("least", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols)) + column(jc) + }) + +#' ceiling +#' +#' Computes the ceiling of the given value. +#' +#' @family math_funcs +#' @rdname ceil +#' @name ceil +#' @aliases ceil +#' @export +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' sign +#' +#' Computes the signum of the given value. +#' +#' @family math_funcs +#' @rdname signum +#' @name signum +#' @aliases signum +#' @export +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' n_distinct +#' +#' Aggregate function: returns the number of distinct items in a group. +#' +#' @family agg_funcs +#' @rdname countDistinct +#' @name countDistinct +#' @aliases countDistinct +#' @export +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' n +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @family agg_funcs +#' @rdname count +#' @name count +#' @aliases count +#' @export +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) + +#' date_format +#' +#' Converts a date/timestamp/string to a value of string in the format specified by the date +#' format given by the second argument. +#' +#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All +#' pattern letters of \code{java.text.SimpleDateFormat} can be used. +#' +#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' specialized implementation. +#' +#' @family datetime_funcs +#' @rdname date_format +#' @name date_format +#' @export +setMethod("date_format", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) + column(jc) + }) + +#' from_utc_timestamp +#' +#' Assumes given timestamp is UTC and converts to given timezone. +#' +#' @family datetime_funcs +#' @rdname from_utc_timestamp +#' @name from_utc_timestamp +#' @export +setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) + column(jc) + }) + +#' instr +#' +#' Locate the position of the first occurrence of substr column in the given string. +#' Returns null if either of the arguments are null. +#' +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @family string_funcs +#' @rdname instr +#' @name instr +#' @export +setMethod("instr", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) + column(jc) + }) + +#' next_day +#' +#' Given a date column, returns the first date which is later than the value of the date column +#' that is on the specified day of the week. +#' +#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first +#' Sunday after 2015-07-27. +#' +#' Day of the week parameter is case insensitive, and accepts: +#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". +#' +#' @family datetime_funcs +#' @rdname next_day +#' @name next_day +#' @export +setMethod("next_day", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) + column(jc) + }) + +#' to_utc_timestamp +#' +#' Assumes given timestamp is in given timezone and converts to UTC. +#' +#' @family datetime_funcs +#' @rdname to_utc_timestamp +#' @name to_utc_timestamp +#' @export +setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) + column(jc) + }) + +#' add_months +#' +#' Returns the date that is numMonths after startDate. +#' +#' @name add_months +#' @family datetime_funcs +#' @rdname add_months +#' @name add_months +#' @export +setMethod("add_months", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) + column(jc) + }) + +#' date_add +#' +#' Returns the date that is `days` days after `start` +#' +#' @family datetime_funcs +#' @rdname date_add +#' @name date_add +#' @export +setMethod("date_add", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) + column(jc) + }) + +#' date_sub +#' +#' Returns the date that is `days` days before `start` +#' +#' @family datetime_funcs +#' @rdname date_sub +#' @name date_sub +#' @export +setMethod("date_sub", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) + column(jc) + }) + +#' format_number +#' +#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' and returns the result as a string column. +#' +#' If d is 0, the result has no decimal point or fractional part. +#' If d < 0, the result will be null.' +#' +#' @family string_funcs +#' @rdname format_number +#' @name format_number +#' @export +setMethod("format_number", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "format_number", + y@jc, as.integer(x)) + column(jc) + }) + +#' sha2 +#' +#' Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. +#' +#' @param y column to compute SHA-2 on. +#' @param x one of 224, 256, 384, or 512. +#' @family misc_funcs +#' @rdname sha2 +#' @name sha2 +#' @export +setMethod("sha2", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) + column(jc) + }) + +#' shiftLeft +#' +#' Shift the the given value numBits left. If the given value is a long value, this function +#' will return a long value else it will return an integer value. +#' +#' @family math_funcs +#' @rdname shiftLeft +#' @name shiftLeft +#' @export +setMethod("shiftLeft", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftLeft", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRight +#' +#' Shift the the given value numBits right. If the given value is a long value, it will return +#' a long value else it will return an integer value. +#' +#' @family math_funcs +#' @rdname shiftRight +#' @name shiftRight +#' @export +setMethod("shiftRight", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRight", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRightUnsigned +#' +#' Unsigned shift the the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. +#' +#' @family math_funcs +#' @rdname shiftRightUnsigned +#' @name shiftRightUnsigned +#' @export +setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRightUnsigned", + y@jc, as.integer(x)) + column(jc) + }) + +#' concat_ws +#' +#' Concatenates multiple input string columns together into a single string column, +#' using the given separator. +#' +#' @family string_funcs +#' @rdname concat_ws +#' @name concat_ws +#' @export +setMethod("concat_ws", signature(sep = "character", x = "Column"), + function(sep, x, ...) { + jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) + column(jc) + }) + +#' conv +#' +#' Convert a number in a string column from one base to another. +#' +#' @family math_funcs +#' @rdname conv +#' @name conv +#' @export +setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), + function(x, fromBase, toBase) { + fromBase <- as.integer(fromBase) + toBase <- as.integer(toBase) + jc <- callJStatic("org.apache.spark.sql.functions", + "conv", + x@jc, fromBase, toBase) + column(jc) + }) + +#' expr +#' +#' Parses the expression string into the column that it represents, similar to +#' DataFrame.selectExpr +#' +#' @family normal_funcs +#' @rdname expr +#' @name expr +#' @export +setMethod("expr", signature(x = "character"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) + column(jc) + }) + +#' format_string +#' +#' Formats the arguments in printf-style and returns the result as a string column. +#' +#' @family string_funcs +#' @rdname format_string +#' @name format_string +#' @export +setMethod("format_string", signature(format = "character", x = "Column"), + function(format, x, ...) { + jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jc <- callJStatic("org.apache.spark.sql.functions", + "format_string", + format, jcols) + column(jc) + }) + +#' from_unixtime +#' +#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string +#' representing the timestamp of that moment in the current system time zone in the given +#' format. +#' +#' @family datetime_funcs +#' @rdname from_unixtime +#' @name from_unixtime +#' @export +setMethod("from_unixtime", signature(x = "Column"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", + "from_unixtime", + x@jc, format) + column(jc) + }) + +#' locate +#' +#' Locate the position of the first occurrence of substr. +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @family string_funcs +#' @rdname locate +#' @name locate +#' @export +setMethod("locate", signature(substr = "character", str = "Column"), + function(substr, str, pos = 0) { + jc <- callJStatic("org.apache.spark.sql.functions", + "locate", + substr, str@jc, as.integer(pos)) + column(jc) + }) + +#' lpad +#' +#' Left-pad the string column with +#' +#' @family string_funcs +#' @rdname lpad +#' @name lpad +#' @export +setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' rand +#' +#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' +#' @family normal_funcs +#' @rdname rand +#' @name rand +#' @export +setMethod("rand", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand") + column(jc) + }) +#' @family normal_funcs +#' @rdname rand +#' @name rand +#' @export +setMethod("rand", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed)) + column(jc) + }) + +#' randn +#' +#' Generate a column with i.i.d. samples from the standard normal distribution. +#' +#' @family normal_funcs +#' @rdname randn +#' @name randn +#' @export +setMethod("randn", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn") + column(jc) + }) +#' @family normal_funcs +#' @rdname randn +#' @name randn +#' @export +setMethod("randn", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed)) + column(jc) + }) + +#' regexp_extract +#' +#' Extract a specific(idx) group identified by a java regex, from the specified string column. +#' +#' @family string_funcs +#' @rdname regexp_extract +#' @name regexp_extract +#' @export +setMethod("regexp_extract", + signature(x = "Column", pattern = "character", idx = "numeric"), + function(x, pattern, idx) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_extract", + x@jc, pattern, as.integer(idx)) + column(jc) + }) + +#' regexp_replace +#' +#' Replace all substrings of the specified string value that match regexp with rep. +#' +#' @family string_funcs +#' @rdname regexp_replace +#' @name regexp_replace +#' @export +setMethod("regexp_replace", + signature(x = "Column", pattern = "character", replacement = "character"), + function(x, pattern, replacement) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_replace", + x@jc, pattern, replacement) + column(jc) + }) + +#' rpad +#' +#' Right-padded with pad to a length of len. +#' +#' @family string_funcs +#' @rdname rpad +#' @name rpad +#' @export +setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "rpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' substring_index +#' +#' Returns the substring from string str before count occurrences of the delimiter delim. +#' If count is positive, everything the left of the final delimiter (counting from left) is +#' returned. If count is negative, every to the right of the final delimiter (counting from the +#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' +#' @family string_funcs +#' @rdname substring_index +#' @name substring_index +#' @export +setMethod("substring_index", + signature(x = "Column", delim = "character", count = "numeric"), + function(x, delim, count) { + jc <- callJStatic("org.apache.spark.sql.functions", + "substring_index", + x@jc, delim, as.integer(count)) + column(jc) + }) + +#' translate +#' +#' Translate any character in the src by a character in replaceString. +#' The characters in replaceString is corresponding to the characters in matchingString. +#' The translate will happen when any character in the string matching with the character +#' in the matchingString. +#' +#' @family string_funcs +#' @rdname translate +#' @name translate +#' @export +setMethod("translate", + signature(x = "Column", matchingString = "character", replaceString = "character"), + function(x, matchingString, replaceString) { + jc <- callJStatic("org.apache.spark.sql.functions", + "translate", x@jc, matchingString, replaceString) + column(jc) + }) + +#' unix_timestamp +#' +#' Gets current Unix timestamp in seconds. +#' +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "missing", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") + column(jc) + }) +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "Column", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) + column(jc) + }) +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "Column", format = "character"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) + column(jc) + }) +#' when +#' +#' Evaluates a list of conditions and returns one of multiple possible result expressions. +#' For unmatched expressions null is returned. +#' +#' @family normal_funcs +#' @rdname when +#' @name when +#' @export +setMethod("when", signature(condition = "Column", value = "ANY"), + function(condition, value) { + condition <- condition@jc + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) + column(jc) + }) + +#' ifelse +#' +#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' Otherwise \code{no} is returned for unmatched conditions. +#' +#' @family normal_funcs +#' @rdname ifelse +#' @name ifelse +#' @export +setMethod("ifelse", + signature(test = "Column", yes = "ANY", no = "ANY"), + function(test, yes, no) { + test <- test@jc + yes <- ifelse(class(yes) == "Column", yes@jc, yes) + no <- ifelse(class(no) == "Column", no@jc, no) + jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", + "when", + test, yes), + "otherwise", no) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 000000000000..a829d46c1894 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,977 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +# @rdname aggregateRDD +# @seealso reduce +# @export +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +# @rdname cache-methods +# @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +# @rdname coalesce +# @seealso repartition +# @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +# @rdname checkpoint-methods +# @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +# @rdname collect-methods +# @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +# @rdname collect-methods +# @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +# @rdname collect-methods +# @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +# @rdname count +# @export +setGeneric("count", function(x) { standardGeneric("count") }) + +# @rdname countByValue +# @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +# @rdname statfunctions +# @export +setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) + +# @rdname distinct +# @export +setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) + +# @rdname filterRDD +# @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +# @rdname first +# @export +setGeneric("first", function(x) { standardGeneric("first") }) + +# @rdname flatMap +# @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +# @rdname fold +# @seealso reduce +# @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +# @rdname foreach +# @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +# @rdname foreach +# @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +# @rdname glom +# @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +# @rdname keyBy +# @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +# @rdname lapplyPartition +# @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +# @rdname lapplyPartitionsWithIndex +# @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +# @rdname lapply +# @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +# @rdname lapplyPartition +# @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +# @rdname lapplyPartitionsWithIndex +# @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +# @rdname maximum +# @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +# @rdname minimum +# @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +# @rdname sumRDD +# @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +# @rdname name +# @export +setGeneric("name", function(x) { standardGeneric("name") }) + +# @rdname numPartitions +# @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +# @rdname persist +# @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +# @rdname pipeRDD +# @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +# @rdname reduce +# @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +# @rdname repartition +# @seealso coalesce +# @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +# @rdname sampleRDD +# @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +# @rdname saveAsObjectFile +# @seealso objectFile +# @export +setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) + +# @rdname saveAsTextFile +# @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +# @rdname setName +# @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +# @rdname sortBy +# @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1) { + standardGeneric("sortBy") + }) + +# @rdname take +# @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +# @rdname takeOrdered +# @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +# @rdname takeSample +# @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +# @rdname top +# @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +# @rdname unionRDD +# @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +# @rdname unpersist-methods +# @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +# @rdname zipRDD +# @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +# @rdname zipRDD +# @export +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, + signature = "...") + +# @rdname zipWithIndex +# @seealso zipWithUniqueId +# @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +# @rdname zipWithUniqueId +# @seealso zipWithIndex +# @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +# @rdname cartesian +# @export +setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) + +# @rdname countByKey +# @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +# @rdname flatMapValues +# @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +# @rdname intersection +# @export +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) + +# @rdname keys +# @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +# @rdname lookup +# @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +# @rdname mapValues +# @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +# @rdname sampleByKey +# @export +setGeneric("sampleByKey", + function(x, withReplacement, fractions, seed) { + standardGeneric("sampleByKey") + }) + +# @rdname values +# @export +setGeneric("values", function(x) { standardGeneric("values") }) + + +############ Shuffle Functions ############ + +# @rdname aggregateByKey +# @seealso foldByKey, combineByKey +# @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +# @rdname cogroup +# @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +# @rdname combineByKey +# @seealso groupByKey, reduceByKey +# @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +# @rdname foldByKey +# @seealso aggregateByKey, combineByKey +# @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +# @rdname join-methods +# @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +# @rdname groupByKey +# @seealso reduceByKey +# @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +# @rdname join-methods +# @export +setGeneric("join", function(x, y, ...) { standardGeneric("join") }) + +# @rdname join-methods +# @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +# @rdname partitionBy +# @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +# @rdname reduceByKey +# @seealso groupByKey +# @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +# @rdname reduceByKeyLocally +# @seealso reduceByKey +# @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +# @rdname join-methods +# @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +# @rdname sortByKey +# @export +setGeneric("sortByKey", + function(x, ascending = TRUE, numPartitions = 1) { + standardGeneric("sortByKey") + }) + +# @rdname subtract +# @export +setGeneric("subtract", + function(x, other, numPartitions = 1) { + standardGeneric("subtract") + }) + +# @rdname subtractByKey +# @export +setGeneric("subtractByKey", + function(x, other, numPartitions = 1) { + standardGeneric("subtractByKey") + }) + + +################### Broadcast Variable Methods ################# + +# @rdname broadcast +# @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname agg +#' @export +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +#' @rdname arrange +#' @export +setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname describe +#' @export +setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) + +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname except +#' @export +setGeneric("except", function(x, y) { standardGeneric("except") }) + +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname groupBy +#' @export +setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) + +#' @rdname groupBy +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' rdname merge +#' @export +setGeneric("merge") + +#' @rdname withColumn +#' @export +setGeneric("mutate", function(x, ...) {standardGeneric("mutate") }) + +#' @rdname arrange +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("rename", function(x, ...) { standardGeneric("rename") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sample +#' @export +setGeneric("sample", + function(x, withReplacement, fraction, seed) { + standardGeneric("sample") + }) + +#' @rdname sample +#' @export +setGeneric("sample_frac", + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname write.df +#' @export +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) + +#' @rdname write.df +#' @export +setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname agg +#' @export +setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) + +#' @rdname summary +#' @export +setGeneric("summary", function(x, ...) { standardGeneric("summary") }) + +# @rdname tojson +# @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("between", function(x, bounds) { standardGeneric("between") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("when", function(condition, value) { standardGeneric("when") }) + +#' @rdname column +#' @export +setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) + + +###################### Expression Function Methods ########################## + +#' @rdname add_months +#' @export +setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) + +#' @rdname approxCountDistinct +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname ascii +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname avg +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname base64 +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname bin +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname bitwiseNOT +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname cbrt +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname ceil +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname concat +#' @export +setGeneric("concat", function(x, ...) { standardGeneric("concat") }) + +#' @rdname concat_ws +#' @export +setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) + +#' @rdname conv +#' @export +setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) + +#' @rdname countDistinct +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname crc32 +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname datediff +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname date_add +#' @export +setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) + +#' @rdname date_format +#' @export +setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) + +#' @rdname date_sub +#' @export +setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) + +#' @rdname dayofmonth +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname dayofyear +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname explode +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname expr +#' @export +setGeneric("expr", function(x) { standardGeneric("expr") }) + +#' @rdname from_utc_timestamp +#' @export +setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) + +#' @rdname format_number +#' @export +setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) + +#' @rdname format_string +#' @export +setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) + +#' @rdname from_unixtime +#' @export +setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) + +#' @rdname greatest +#' @export +setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) + +#' @rdname hex +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname hour +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname hypot +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + +#' @rdname initcap +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname instr +#' @export +setGeneric("instr", function(y, x) { standardGeneric("instr") }) + +#' @rdname isNaN +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + +#' @rdname last +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname last_day +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname least +#' @export +setGeneric("least", function(x, ...) { standardGeneric("least") }) + +#' @rdname levenshtein +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname lit +#' @export +setGeneric("lit", function(x) { standardGeneric("lit") }) + +#' @rdname locate +#' @export +setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) + +#' @rdname lower +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname lpad +#' @export +setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) + +#' @rdname ltrim +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname md5 +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname minute +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname month +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname months_between +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname count +#' @export +setGeneric("n", function(x) { standardGeneric("n") }) + +#' @rdname nanvl +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname negate +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + +#' @rdname next_day +#' @export +setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) + +#' @rdname countDistinct +#' @export +setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) + +#' @rdname pmod +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname quarter +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname rand +#' @export +setGeneric("rand", function(seed) { standardGeneric("rand") }) + +#' @rdname randn +#' @export +setGeneric("randn", function(seed) { standardGeneric("randn") }) + +#' @rdname regexp_extract +#' @export +setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) + +#' @rdname regexp_replace +#' @export +setGeneric("regexp_replace", + function(x, pattern, replacement) { standardGeneric("regexp_replace") }) + +#' @rdname reverse +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname rint +#' @export +setGeneric("rint", function(x, ...) { standardGeneric("rint") }) + +#' @rdname rpad +#' @export +setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) + +#' @rdname rtrim +#' @export +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) + +#' @rdname second +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname sha1 +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname sha2 +#' @export +setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) + +#' @rdname shiftLeft +#' @export +setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) + +#' @rdname shiftRight +#' @export +setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) + +#' @rdname shiftRightUnsigned +#' @export +setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) + +#' @rdname signum +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname size +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname soundex +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname substring_index +#' @export +setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) + +#' @rdname sumDistinct +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname toDegrees +#' @export +setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) + +#' @rdname toRadians +#' @export +setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) + +#' @rdname to_date +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname to_utc_timestamp +#' @export +setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) + +#' @rdname translate +#' @export +setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) + +#' @rdname trim +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname unbase64 +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname unhex +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname unix_timestamp +#' @export +setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) + +#' @rdname upper +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname weekofyear +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname year +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + +#' @rdname glm +#' @export +setGeneric("glm") + +#' @rdname rbind +#' @export +setGeneric("rbind", signature = "...") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 000000000000..576ac72f40fc --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,138 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R schema.R column.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols <- list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] <- alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + +#' @rdname agg +#' @aliases agg +setMethod("summarize", + signature(x = "GroupedData"), + function(x, ...) { + agg(x, ...) + }) + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 000000000000..0838a7bb35e0 --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +#' @include generics.R +NULL + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 000000000000..cea3d760d05f --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,99 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '+', '-', and '.'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' +#' @param object A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname predict +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param x A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. +#' @rdname summary +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(x = "PipelineModel"), + function(x, ...) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", x@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", x@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 000000000000..199c3fd6ab1b --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,908 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) +#' @include generics.R jobj.R RDD.R +NULL + +############ Actions and Transformations ############ + +# Look up elements of a key in an RDD +# +# @description +# \code{lookup} returns a list of values in this RDD for key key. +# +# @param x The RDD to collect +# @param key The key to look up for +# @return a list of values in this RDD for key key +# @examples +#\dontrun{ +# sc <- sparkR.init() +# pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +# rdd <- parallelize(sc, pairs) +# lookup(rdd, 1) # list(1, 3) +#} +# @rdname lookup +# @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +# Count the number of elements for each key, and return the result to the +# master as lists of (key, count) pairs. +# +# Same as countByKey in Spark. +# +# @param x The RDD to count keys. +# @return list of (key, count) pairs, where count is number of each key in rdd. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +# countByKey(rdd) # ("a", 2L), ("b", 1L) +#} +# @rdname countByKey +# @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +# Return an RDD with the keys of each tuple. +# +# @param x The RDD from which the keys of each tuple is returned. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +# collect(keys(rdd)) # list(1, 3) +#} +# @rdname keys +# @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +# Return an RDD with the values of each tuple. +# +# @param x The RDD from which the values of each tuple is returned. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +# collect(values(rdd)) # list(2, 4) +#} +# @rdname values +# @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +# Applies a function to all values of the elements, without modifying the keys. +# +# The same as `mapValues()' in Spark. +# +# @param X The RDD to apply the transformation. +# @param FUN the transformation to apply on the value of each element. +# @return a new RDD created by the transformation. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:10) +# makePairs <- lapply(rdd, function(x) { list(x, x) }) +# collect(mapValues(makePairs, function(x) { x * 2) }) +# Output: list(list(1,2), list(2,4), list(3,6), ...) +#} +# @rdname mapValues +# @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +# Pass each value in the key-value pair RDD through a flatMap function without +# changing the keys; this also retains the original RDD's partitioning. +# +# The same as 'flatMapValues()' in Spark. +# +# @param X The RDD to apply the transformation. +# @param FUN the transformation to apply on the value of each element. +# @return a new RDD created by the transformation. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +# collect(flatMapValues(rdd, function(x) { x })) +# Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#} +# @rdname flatMapValues +# @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +# Partition an RDD by key +# +# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +# For each element of this RDD, the partitioner is used to compute a hash +# function and the RDD is partitioned using this hash value. +# +# @param x The RDD to partition. Should be an RDD where each element is +# list(K, V) or c(K, V). +# @param numPartitions Number of partitions to create. +# @param ... Other optional arguments to partitionBy. +# +# @param partitionFunc The partition function to use. Uses a default hashCode +# function if not provided +# @return An RDD partitioned using the specified partitioner. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +# rdd <- parallelize(sc, pairs) +# parts <- partitionBy(rdd, 2L) +# collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#} +# @rdname partitionBy +# @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], + # where the key is the target partition number, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + numToInt(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. + rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + numToInt(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +# Group values by key +# +# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +# and group values for each key in the RDD into a single sequence. +# +# @param x The RDD to group. Should be an RDD where each element is +# list(K, V) or c(K, V). +# @param numPartitions Number of partitions to create. +# @return An RDD where each element is list(K, list(V)) +# @seealso reduceByKey +# @examples +#\dontrun{ +# sc <- sparkR.init() +# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +# rdd <- parallelize(sc, pairs) +# parts <- groupByKey(rdd, 2L) +# grouped <- collect(parts) +# grouped[[1]] # Should be a list(1, list(2, 4)) +#} +# @rdname groupByKey +# @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +# Merge values by key +# +# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +# and merges the values for each key using an associative reduce function. +# +# @param x The RDD to reduce by key. Should be an RDD where each element is +# list(K, V) or c(K, V). +# @param combineFunc The associative reduce function to use. +# @param numPartitions Number of partitions to create. +# @return An RDD where each element is list(K, V') where V' is the merged +# value +# @examples +#\dontrun{ +# sc <- sparkR.init() +# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +# rdd <- parallelize(sc, pairs) +# parts <- reduceByKey(rdd, "+", 2L) +# reduced <- collect(parts) +# reduced[[1]] # Should be a list(1, 6) +#} +# @rdname reduceByKey +# @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) + lapplyPartition(shuffled, reduceVals) + }) + +# Merge values by key locally +# +# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +# and merges the values for each key using an associative reduce function, but return the +# results immediately to the driver as an R list. +# +# @param x The RDD to reduce by key. Should be an RDD where each element is +# list(K, V) or c(K, V). +# @param combineFunc The associative reduce function to use. +# @return A list of elements of type list(K, V') where V' is the merged value for each key +# @seealso reduceByKey +# @examples +#\dontrun{ +# sc <- sparkR.init() +# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +# rdd <- parallelize(sc, pairs) +# reduced <- reduceByKeyLocally(rdd, "+") +# reduced # list(list(1, 6), list(1.1, 3)) +#} +# @rdname reduceByKeyLocally +# @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +# Combine values by key +# +# Generic function to combine the elements for each key using a custom set of +# aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +# for a "combined type" C. Note that V and C can be different -- for example, one +# might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +# Users provide three functions: +# \itemize{ +# \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +# \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +# \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +# two lists). +# } +# +# @param x The RDD to combine. Should be an RDD where each element is +# list(K, V) or c(K, V). +# @param createCombiner Create a combiner (C) given a value (V) +# @param mergeValue Merge the given value (V) with an existing combiner (C) +# @param mergeCombiners Merge two combiners and return a new combiner +# @param numPartitions Number of partitions to create. +# @return An RDD where each element is list(K, C) where C is the combined type +# +# @seealso groupByKey, reduceByKey +# @examples +#\dontrun{ +# sc <- sparkR.init() +# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +# rdd <- parallelize(sc, pairs) +# parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +# combined <- collect(parts) +# combined[[1]] # Should be a list(1, 6) +#} +# @rdname combineByKey +# @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "numeric"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +# Aggregate a pair RDD by each key. +# +# Aggregate the values of each key in an RDD, using given combine functions +# and a neutral "zero value". This function can return a different result type, +# U, than the type of the values in this RDD, V. Thus, we need one operation +# for merging a V into a U and one operation for merging two U's, The former +# operation is used for merging values within a partition, and the latter is +# used for merging values between partitions. To avoid memory allocation, both +# of these functions are allowed to modify and return their first argument +# instead of creating a new U. +# +# @param x An RDD. +# @param zeroValue A neutral "zero value". +# @param seqOp A function to aggregate the values of each key. It may return +# a different result type from the type of the values. +# @param combOp A function to aggregate results of seqOp. +# @return An RDD containing the aggregation result. +# @seealso foldByKey, combineByKey +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +# zeroValue <- list(0, 0) +# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +# aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +# # list(list(1, list(3, 2)), list(2, list(7, 2))) +#} +# @rdname aggregateByKey +# @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "numeric"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +# Fold a pair RDD by each key. +# +# Aggregate the values of each key in an RDD, using an associative function "func" +# and a neutral "zero value" which may be added to the result an arbitrary +# number of times, and must not change the result (e.g., 0 for addition, or +# 1 for multiplication.). +# +# @param x An RDD. +# @param zeroValue A neutral "zero value". +# @param func An associative function for folding values of each key. +# @return An RDD containing the aggregation result. +# @seealso aggregateByKey, combineByKey +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +# foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#} +# @rdname foldByKey +# @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "numeric"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +# Join two RDDs +# +# @description +# \code{join} This function joins two RDDs where every element is of the form list(K, V). +# The key types of the two RDDs should be the same. +# +# @param x An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param y An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param numPartitions Number of partitions to create. +# @return a new RDD containing all pairs of elements with matching keys in +# two input RDDs. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +# join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#} +# @rdname join-methods +# @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), + doJoin) + }) + +# Left outer join two RDDs +# +# @description +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. +# +# @param x An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param y An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param numPartitions Number of partitions to create. +# @return For each element (k, v) in x, the resulting RDD will either contain +# all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +# if no elements in rdd2 have key k. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +# leftOuterJoin(rdd1, rdd2, 2L) +# # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#} +# @rdname join-methods +# @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +# Right outer join two RDDs +# +# @description +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. +# +# @param x An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param y An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param numPartitions Number of partitions to create. +# @return For each element (k, w) in y, the resulting RDD will either contain +# all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +# if no elements in x have key k. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +# rightOuterJoin(rdd1, rdd2, 2L) +# # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#} +# @rdname join-methods +# @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +# Full outer join two RDDs +# +# @description +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. +# +# @param x An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param y An RDD to be joined. Should be an RDD where each element is +# list(K, V). +# @param numPartitions Number of partitions to create. +# @return For each element (k, v) in x and (k, w) in y, the resulting RDD +# will contain all pairs (k, (v, w)) for both (k, v) in x and +# (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +# in x/y have key k. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +# fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +# # list(1, list(3, 1)), +# # list(2, list(NULL, 4))) +# # list(3, list(3, NULL)), +#} +# @rdname join-methods +# @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +# For each key k in several RDDs, return a resulting RDD that +# whose values are a list of values for the key in all RDDs. +# +# @param ... Several RDDs. +# @param numPartitions Number of partitions to create. +# @return a new RDD containing all pairs of elements with values in a list +# in all RDDs. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +# cogroup(rdd1, rdd2, numPartitions = 2L) +# # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#} +# @rdname cogroup +# @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +# Sort a (k, v) pair RDD by k. +# +# @param x A (k, v) pair RDD to be sorted. +# @param ascending A flag to indicate whether the sorting is ascending or descending. +# @param numPartitions Number of partitions to create. +# @return An RDD where all (k, v) pair elements are sorted. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +# collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#} +# @rdname sortByKey +# @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + +# Subtract a pair RDD with another pair RDD. +# +# Return an RDD with the pairs from x whose keys are not in other. +# +# @param x An RDD. +# @param other An RDD. +# @param numPartitions Number of the partitions in the result RDD. +# @return An RDD with the pairs from x whose keys are not in other. +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +# list("b", 5), list("a", 2))) +# rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +# collect(subtractByKey(rdd1, rdd2)) +# # list(list("b", 4), list("b", 5)) +#} +# @rdname subtractByKey +# @aliases subtractByKey,RDD +setMethod("subtractByKey", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR:::numPartitions(x)) { + filterFunction <- function(elem) { + iters <- elem[[2]] + (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) + } + + flatMapValues(filterRDD(cogroup(x, + other, + numPartitions = numPartitions), + filterFunction), + function (v) { v[[1]] }) + }) + +# Return a subset of this RDD sampled by key. +# +# @description +# \code{sampleByKey} Create a sample of this RDD using variable sampling rates +# for different keys as specified by fractions, a key to sampling rate map. +# +# @param x The RDD to sample elements by key, where each element is +# list(K, V) or c(K, V). +# @param withReplacement Sampling with replacement or not +# @param fraction The (rough) sample target fraction +# @param seed Randomness seed value +# @examples +#\dontrun{ +# sc <- sparkR.init() +# rdd <- parallelize(sc, 1:3000) +# pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +# else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +# fractions <- list(a = 0.2, b = 0.1, c = 0.3) +# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +# 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +# 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +# 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +# lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +# lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +# lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +# lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +# lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +# lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +# fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +# fractions <- list(a = 0.2, b = 0.1) +# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#} +# @rdname sampleByKey +# @aliases sampleByKey,RDD-method +setMethod("sampleByKey", + signature(x = "RDD", withReplacement = "logical", + fractions = "vector", seed = "integer"), + function(x, withReplacement, fractions, seed) { + + for (elem in fractions) { + if (elem < 0.0) { + stop(paste("Negative fraction value ", fractions[which(fractions == elem)])) + } + } + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(partIndex, part) { + set.seed(bitwXor(seed, partIndex)) + res <- vector("list", length(part)) + len <- 0 + + # mixing because the initial seeds are close to each other + runif(10) + + for (elem in part) { + if (elem[[1]] %in% names(fractions)) { + frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) + if (withReplacement) { + count <- rpois(1, frac) + if (count > 0) { + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < frac) { + len <- len + 1 + res[[len]] <- elem + } + } + } else { + stop("KeyError: \"", elem[[1]], "\"") + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. (duplicated from sampleRDD) + if (len > 0) { + res[1:len] + } else { + list() + } + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R new file mode 100644 index 000000000000..79c744ef29c2 --- /dev/null +++ b/R/pkg/R/schema.R @@ -0,0 +1,166 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A set of S3 classes and methods that support the SparkSQL `StructType` and `StructField +# datatypes. These are used to create and interact with DataFrame schemas. + +#' structType +#' +#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' use with createDataFrame and toDF. +#' +#' @param x a structField object (created with the field() function) +#' @param ... additional structField objects +#' @return a structType object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' schema <- structType(structField("a", "integer"), structField("b", "string")) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } +structType <- function(x, ...) { + UseMethod("structType", x) +} + +structType.jobj <- function(x) { + obj <- structure(list(), class = "structType") + obj$jobj <- x + obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) } + obj +} + +structType.structField <- function(x, ...) { + fields <- list(x, ...) + if (!all(sapply(fields, inherits, "structField"))) { + stop("All arguments must be structField objects.") + } + sfObjList <- lapply(fields, function(field) { + field$jobj + }) + stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructType", + listToSeq(sfObjList)) + structType(stObj) +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + cat("StructType\n", + sapply(x$fields(), + function(field) { + paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") + }), + sep = "") +} + +#' structField +#' +#' Create a structField object that contains the metadata for a single field in a schema. +#' +#' @param x The name of the field +#' @param type The data type of the field +#' @param nullable A logical vector indicating whether or not the field is nullable +#' @return a structField object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' field1 <- structField("a", "integer", TRUE) +#' field2 <- structField("b", "string", TRUE) +#' schema <- structType(field1, field2) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } + +structField <- function(x, ...) { + UseMethod("structField", x) +} + +structField.jobj <- function(x) { + obj <- structure(list(), class = "structField") + obj$jobj <- x + obj$name <- function() { callJMethod(x, "name") } + obj$dataType <- function() { callJMethod(x, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(x, "nullable") } + obj +} + +structField.character <- function(x, type, nullable = TRUE) { + if (class(x) != "character") { + stop("Field name must be a string.") + } + if (class(type) != "character") { + stop("Field type must be a string.") + } + if (class(nullable) != "logical") { + stop("nullable must be either TRUE or FALSE") + } + options <- c("byte", + "integer", + "float", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + dataType <- if (type %in% options) { + type + } else { + stop(paste("Unsupported type for Dataframe:", type)) + } + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructField", + x, + dataType, + nullable) + structField(sfObj) +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat("StructField(name = \"", x$name(), + "\", type = \"", x$dataType.toString(), + "\", nullable = ", x$nullable(), + ")", + sep = "") +} diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 000000000000..e3676f57f907 --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,200 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # lists and pairlists + if (type %in% c("integer", "character", "logical", "double", "numeric")) { + if (is.na(object)) { + object <- NULL + type <- "NULL" + } + } + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + utfVal <- enc2utf8(value) + writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) + writeBin(utfVal, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeGenericList(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = "t", + POSIXct = "t", + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeGenericList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 000000000000..e83104f11642 --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,320 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, + error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkPackages Character string vector of packages from spark-packages.org +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkPackages = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + uriSep <- "//" + } else { + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + packages = sparkPackages) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open="rb") + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, + error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = .sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc = NULL) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + sc) + assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) + sqlContext +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc = NULL) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, + error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} + +#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a +#' different value or cleared. +#' +#' @param sc existing spark context +#' @param groupid the ID to be assigned to job groups +#' @param description description for the the job group ID +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE) +#'} + +setJobGroup <- function(sc, groupId, description, interruptOnCancel) { + callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) +} + +#' Clear current job group ID and its description +#' +#' @param sc existing spark context +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' clearJobGroup(sc) +#'} + +clearJobGroup <- function(sc) { + callJMethod(sc, "clearJobGroup") +} + +#' Cancel active jobs for the specified group +#' +#' @param sc existing spark context +#' @param groupId the ID of job group to be cancelled +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' cancelJobGroup(sc, "myJobGroup") +#'} + +cancelJobGroup <- function(sc, groupId) { + callJMethod(sc, "cancelJobGroup", groupId) +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 000000000000..4f9f4d9cad2a --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,599 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0 : (arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. + if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + # TODO: SPARK-7839 means we might not have the native library available + if (is.loaded("stringHashCode")) { + .Call("stringHashCode", key) + } else { + n <- nchar(key) + if (n == 0) { + 0L + } else { + asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) + hashC <- 0 + for (k in 1:length(asciiVals)) { + hashC <- mult31AndAdd(hashC, asciiVals[k]) + } + as.integer(hashC) + } + } + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Helper function used to wrap a 'numeric' value to integer bounds. +# Useful for implementing C-like integer arithmetic +wrapInt <- function(value) { + if (value > .Machine$integer.max) { + value <- value - 2 * .Machine$integer.max - 2 + } else if (value < -1 * .Machine$integer.max) { + value <- 2 * .Machine$integer.max + value + 2 + } + value +} + +# Multiply `val` by 31 and add `addVal` to the result. Ensures that +# integer-overflows are handled at every step. +mult31AndAdd <- function(val, addVal) { + vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + Reduce(function(a, b) { + wrapInt(as.numeric(a) + as.numeric(b)) + }, + vec) +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. +serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { + # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { + # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { + # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { + # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { + # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { + # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { + # Only include SparkR internals. + + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { + # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { + # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { + # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} + +# Append partition lengths to each partition in two input RDDs if needed. +# param +# x An RDD. +# Other An RDD. +# return value +# A list of two result RDDs. +appendPartitionLengths <- function(x, other) { + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # know the boundary of elements from x and other. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + len <- length(part) + part[[len + 1]] <- len + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + list (x, other) +} + +# Perform zip or cartesian between elements from two RDDs in each partition +# param +# rdd An RDD. +# zip A boolean flag indicating this call is for zip operation or not. +# return value +# A result RDD. +mergePartitions <- function(rdd, zip) { + serializerMode <- getSerializedMode(rdd) + partitionFunc <- function(partIndex, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. + if (zip && lengthOfKeys != lengthOfValues) { + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + } else { + keys <- list() + } + if (lengthOfValues > 1) { + values <- part[ (lengthOfKeys + 1) : (len - 1) ] + } else { + values <- list() + } + + if (!zip) { + return(mergeCompactLists(keys, values)) + } + } else { + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { list(k, v) }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(rdd, partitionFunc) +} diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 000000000000..2a8a8213d084 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") + .libPaths(c(packageDir, .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 000000000000..7189f1a26093 --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + # Make sure SparkR package is the last loaded one + old <- getOption("defaultPackages") + options(defaultPackages = c(old, "SparkR")) + + sc <- SparkR::sparkR.init() + assign("sc", sc, envir=.GlobalEnv) + sqlContext <- SparkR::sparkRSQL.init(sc) + sparkVer <- SparkR:::callJMethod(sc, "version") + assign("sqlContext", sqlContext, envir=.GlobalEnv) + cat("\n Welcome to") + cat("\n") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + if (nchar(sparkVer) == 0) { + cat("\n") + } else { + cat(" version ", sparkVer, "\n") + } + cat(" /_/", "\n") + cat("\n") + + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") +} diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 000000000000..1d5c2af631aa Binary files /dev/null and b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar differ diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/jarTest.R new file mode 100644 index 000000000000..d68bb20950b0 --- /dev/null +++ b/R/pkg/inst/tests/jarTest.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +library(SparkR) + +sc <- sparkR.init() + +helloTest <- SparkR:::callJStatic("sparkR.test.hello", + "helloWorld", + "Dave") + +basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", + "addStuff", + 2L, + 2L) + +sparkR.stop() +output <- c(helloTest, basicFunction) +writeLines(output) diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/packageInAJarTest.R new file mode 100644 index 000000000000..207a37a0cb47 --- /dev/null +++ b/R/pkg/inst/tests/packageInAJarTest.R @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +library(SparkR) +library(sparkPackageTest) + +sc <- sparkR.init() + +run1 <- myfunc(5L) + +run2 <- myfunc(-4L) + +sparkR.stop() + +if(run1 != 6) quit(save = "no", status = 1) + +if(run2 != -3) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R new file mode 100644 index 000000000000..009db85da2be --- /dev/null +++ b/R/pkg/inst/tests/test_Serde.R @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("SerDe functionality") + +sc <- sparkR.init() + +test_that("SerDe of primitive types", { + x <- callJStatic("SparkRHandler", "echo", 1L) + expect_equal(x, 1L) + expect_equal(class(x), "integer") + + x <- callJStatic("SparkRHandler", "echo", 1) + expect_equal(x, 1) + expect_equal(class(x), "numeric") + + x <- callJStatic("SparkRHandler", "echo", TRUE) + expect_true(x) + expect_equal(class(x), "logical") + + x <- callJStatic("SparkRHandler", "echo", "abc") + expect_equal(x, "abc") + expect_equal(class(x), "character") +}) + +test_that("SerDe of list of primitive types", { + x <- list(1L, 2L, 3L) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "integer") + + x <- list(1, 2, 3) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "numeric") + + x <- list(TRUE, FALSE) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "logical") + + x <- list("a", "b", "c") + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "character") + + # Empty list + x <- list() + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) + +test_that("SerDe of list of lists", { + x <- list(list(1L, 2L, 3L), list(1, 2, 3), + list(TRUE, FALSE), list("a", "b", "c")) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + + # List of empty lists + x <- list(list(), list()) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 000000000000..f2452ed97d2e --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_equal(count(rdd), 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 000000000000..f054ac9a87d6 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_equal(getSerializedMode(union.rdd), "byte") + + rdd <- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_equal(getSerializedMode(union.rdd), "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("zipPartitions() on RDDs", { + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 + rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 + rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + func = function(x, y, z) { list(list(x, y, z))} )) + expect_equal(actual, + list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipPartitions(rdd, rdd, + func = function(x, y) { list(paste(x, y, sep = "\n")) })) + expected <- list(paste(mockFile, mockFile, sep = "\n")) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipPartitions(rdd1, rdd, + func = function(x, y) { list(x + nchar(y)) })) + expected <- list(0:1 + nchar(mockFile)) + expect_equal(actual, expected) + + rdd <- map(rdd, function(x) { x }) + actual <- collect(zipPartitions(rdd, rdd1, + func = function(x, y) { list(y + nchar(x)) })) + expect_equal(actual, expected) + + unlink(fileName) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 000000000000..bb86a5c922bd --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(SparkR:::value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R new file mode 100644 index 000000000000..8a20991f89af --- /dev/null +++ b/R/pkg/inst/tests/test_client.R @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in client.R") + +test_that("adding spark-testing-base as a package works", { + args <- generateSparkSubmitArgs("", "", "", "", + "holdenk:spark-testing-base:1.3.0_0.0.5") + expect_equal(gsub("[[:space:]]", "", args), + gsub("[[:space:]]", "", + "--packages holdenk:spark-testing-base:1.3.0_0.0.5")) +}) + +test_that("no package specified doesn't add packages flag", { + args <- generateSparkSubmitArgs("", "", "", "", "") + expect_equal(gsub("[[:space:]]", "", args), + "") +}) + +test_that("multiple packages don't produce a warning", { + expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 000000000000..513bbc8e6205 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R new file mode 100644 index 000000000000..cc1faeabffe3 --- /dev/null +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +context("include an external JAR in SparkContext") + +runScript <- function() { + sparkHome <- Sys.getenv("SPARK_HOME") + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + submitPath <- file.path(sparkHome, "bin/spark-submit") + res <- system2(command = submitPath, + args = c(jarPath, scriptPath), + stdout = TRUE) + tail(res, 2) +} + +test_that("sparkJars tag in SparkContext", { + testOutput <- runScript() + helloTest <- testOutput[1] + expect_equal(helloTest, "Hello, Dave") + basicFunction <- testOutput[2] + expect_equal(basicFunction, "4") +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 000000000000..8152b448d087 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 000000000000..f272de78ad4a --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 000000000000..2552127cc547 --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_is(rdd, "RDD") + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 000000000000..71aed2bb9d6a --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,793 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_equal(first(rdd), 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_equal(first(newrdd), 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. + filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(partIndex, part) { + part <- as.list(unlist(part) * partIndex + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(partIndex, part) { + part <- as.list(unlist(part) * partIndex) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + tempDir <- tempfile(pattern = "checkpoint") + setCheckpointDir(sc, tempDir) + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink(tempDir) +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(partIndex, part) { list(partIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x * x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >= 0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("cartesian() on RDDs", { + rdd <- parallelize(sc, 1:3) + actual <- collect(cartesian(rdd, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(1, 1), list(1, 2), list(1, 3), + list(2, 1), list(2, 2), list(2, 3), + list(3, 1), list(3, 2), list(3, 3))) + + # test case where one RDD is empty + emptyRdd <- parallelize(sc, list()) + actual <- collect(cartesian(rdd, emptyRdd)) + expect_equal(actual, list()) + + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + actual <- collect(cartesian(rdd, rdd)) + expected <- list( + list("Spark is awesome.", "Spark is pretty."), + list("Spark is awesome.", "Spark is awesome."), + list("Spark is pretty.", "Spark is pretty."), + list("Spark is pretty.", "Spark is awesome.")) + expect_equal(sortKeyValueList(actual), expected) + + rdd1 <- parallelize(sc, 0:1) + actual <- collect(cartesian(rdd1, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(0, "Spark is pretty."), + list(0, "Spark is awesome."), + list(1, "Spark is pretty."), + list(1, "Spark is awesome."))) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(cartesian(rdd, rdd1)) + expect_equal(sortKeyValueList(actual), expected) + + unlink(fileName) +}) + +test_that("subtract() on RDDs", { + l <- list(1, 1, 2, 2, 3, 4) + rdd1 <- parallelize(sc, l) + + # subtract by itself + actual <- collect(subtract(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtract by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + l) + + rdd2 <- parallelize(sc, list(2, 4)) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + list(1, 1, 3)) + + l <- list("a", "a", "b", "b", "c", "d") + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list("b", "d")) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="character"))), + list("a", "a", "c")) +}) + +test_that("subtractByKey() on pairwise RDDs", { + l <- list(list("a", 1), list("b", 4), + list("b", 5), list("a", 2)) + rdd1 <- parallelize(sc, l) + + # subtractByKey by itself + actual <- collect(subtractByKey(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtractByKey by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(l)) + + rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list("b", 4), list("b", 5))) + + l <- list(list(1, 1), list(2, 4), + list(2, 5), list(1, 2)) + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list(2, 4), list(2, 5))) +}) + +test_that("intersection() on RDDs", { + # intersection with self + actual <- collect(intersection(rdd, rdd)) + expect_equal(sort(as.integer(actual)), nums) + + # intersection with an empty RDD + emptyRdd <- parallelize(sc, list()) + actual <- collect(intersection(rdd, emptyRdd)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) + rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) + actual <- collect(intersection(rdd1, rdd2)) + expect_equal(sort(as.integer(actual)), 1:3) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element + l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) + +test_that("show()", { + rdd <- parallelize(sc, list(1:10)) + expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") +}) + +test_that("sampleByKey() on pairwise RDDs", { + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) + fractions <- list(a = 0.2, b = 0.1) + sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE) + expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE) + expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE) + expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE) + expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE) + expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE) + + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) }) + fractions <- list(`2` = 0.2, `3` = 0.1) + sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE) + expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE) + expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE) + expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE) + expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) + expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 000000000000..adf0b91d25fe --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,221 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for characters", { + stringKeyRDD <- parallelize(sc, + list(list("max", 1L), list("min", 2L), + list("other", 3L), list("max", 4L)), 2L) + reduced <- combineByKey(stringKeyRDD, + function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 000000000000..ee48a3dc0cc0 --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,1199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +# For test nafunctions, like dropna(), fillna(),... +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = "array", elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = "array", elementType = "integer", containsNull = TRUE)) + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_equal(class(testStruct), "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("structType and structField", { + testField <- structField("a", "string") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(nrow(df), 10) + expect_equal(ncol(df), 2) + expect_equal(dim(df), c(10, 2)) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- createDataFrame(sqlContext, rdd, schema) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), + age=c(19, 23, 18), + height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) +}) + +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- toDF(rdd, schema) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlContext, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlContext, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlContext, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlContext, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_equal(count(rdd), 3) + df <- jsonRDD(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlContext, rdd2) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlContext, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlContext, jsonPath) + registerTempTable(df, "table1") + expect_equal(length(tableNames(sqlContext)), 1) + df <- tables(sqlContext) + expect_equal(count(df), 1) + dropTempTable(sqlContext, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) + dropTempTable(sqlContext, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(sqlContext, jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") + dropTempTable(sqlContext, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") + dropTempTable(sqlContext, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlContext, "table1") + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) + dropTempTable(sqlContext, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlContext, jsonPath) + testRDD <- toRDD(df) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlContext, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlContext, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlContext, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_is(objectIn, "RDD") + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlContext, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_is(testRDD, "RDD") + collected <- collect(testRDD) + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlContext, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) + + # collect() returns data correctly from a DataFrame with 0 row + df0 <- limit(df, 0) + rdf <- collect(df0) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 0) + expect_equal(ncol(rdf), 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlContext, jsonPath) + dfLimited <- limit(df, 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlContext, jsonPath) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations result in an RDD with the correct values", { + df <- jsonFile(sqlContext, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlContext, jsonPath) + testSchema <- schema(df) + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") + + testTypes <- dtypes(df) + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") + + testCols <- columns(df) + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") + + testNames <- names(df) + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlContext, jsonPath) + testHead <- head(df) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) + + testHead2 <- head(df, 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) + + testFirst <- first(df) + expect_equal(nrow(testFirst), 1) + + # head() and first() return the correct data on + # a DataFrame with 0 row + df0 <- limit(df, 0) + + testHead <- head(df0) + expect_equal(nrow(testHead), 0) + expect_equal(ncol(testHead), 2) + + testFirst <- first(df0) + expect_equal(nrow(testFirst), 0) + expect_equal(ncol(testFirst), 2) +}) + +test_that("distinct() and unique on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlContext, jsonPathWithDup) + uniques <- distinct(df) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) + + uniques2 <- unique(df) + expect_is(uniques2, "DataFrame") + expect_equal(count(uniques2), 3) +}) + +test_that("sample on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + sampled <- sample(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_is(sampled, "DataFrame") + sampled2 <- sample(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) + + # Also test sample_frac + sampled3 <- sample_frac(df, FALSE, 0.1) + expect_true(count(sampled3) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") + + expect_is(df[,1], "DataFrame") + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_is(df2, "DataFrame") + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) +}) + +test_that("select with column", { + df <- jsonFile(sqlContext, jsonPath) + df1 <- select(df, "name") + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) + + df2 <- select(df, df$age) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) + + df3 <- select(df, lit("x")) + expect_equal(columns(df3), c("x")) + expect_equal(count(df3), 3) + expect_equal(collect(select(df3, "x"))[[1, 1]], "x") +}) + +test_that("subsetting", { + # jsonFile returns columns in random order + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + filtered <- df[df$age > 20,] + expect_equal(count(filtered), 1) + expect_equal(columns(filtered), c("name", "age")) + expect_equal(collect(filtered)$name, "Andy") + + df2 <- df[df$age == 19, 1] + expect_is(df2, "DataFrame") + expect_equal(count(df2), 1) + expect_equal(columns(df2), c("name")) + expect_equal(collect(df2)$name, "Justin") + + df3 <- df[df$age > 20, 2] + expect_equal(count(df3), 1) + expect_equal(columns(df3), c("age")) + + df4 <- df[df$age %in% c(19, 30), 1:2] + expect_equal(count(df4), 2) + expect_equal(columns(df4), c("name", "age")) + + df5 <- df[df$age %in% c(19), c(1,2)] + expect_equal(count(df5), 1) + expect_equal(columns(df5), c("name", "age")) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_equal(names(selected), "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_equal(count(selected2), 3) +}) + +test_that("expr() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) +}) + +test_that("column calculation", { + df <- jsonFile(sqlContext, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_equal(names(d), c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) +}) + +test_that("read.df() from json file", { + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Check if we can apply a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Run the same with loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) +}) + +test_that("write.df() as parquet file", { + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", mode="overwrite") + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + df2 <- sql(hiveCtx, "select * from json") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) + + df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") +}) +# +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end + expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) + expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) + expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) + expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") + expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) + expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") + expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) +}) + +test_that("string operators", { + df <- jsonFile(sqlContext, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") + expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") + expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) + expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], "30.00") + expect_equal(collect(select(df, sha1(df$name)))[2, 1], + "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") + expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], + "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") + expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") + expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") + expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") + expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") + + l2 <- list(list(a = "aaads")) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) + expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + + l3 <- list(list(a = "a.b.c.d")) + df3 <- createDataFrame(sqlContext, l3) + expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") + expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") + expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") +}) + +test_that("date functions on a DataFrame", { + .originalTimeZone <- Sys.getenv("TZ") + Sys.setenv(TZ = "UTC") + l <- list(list(a = 1L, b = as.Date("2012-12-13")), + list(a = 2L, b = as.Date("2013-12-14")), + list(a = 3L, b = as.Date("2014-12-15"))) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) + expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) + expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) + expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) + expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) + expect_equal(collect(select(df, last_day(df$b)))[, 1], + c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) + expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], + c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) + expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) + expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], + c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) + expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], + c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) + expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], + c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) + + l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), + list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) + expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) + expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + + l3 <- list(list(a = 1000), list(a = -1000)) + df3 <- createDataFrame(sqlContext, l3) + result31 <- collect(select(df3, from_unixtime(df3$a))) + expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), + c(1, 2)) + result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) + expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) + Sys.setenv(TZ = .originalTimeZone) +}) + +test_that("greatest() and least() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) + expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) +}) + +test_that("when(), otherwise() and ifelse() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) +}) + +test_that("group by", { + df <- jsonFile(sqlContext, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_equal(1, count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_equal(1, count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_is(gd, "GroupedData") + df2 <- count(gd) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) + + # Also test group_by, summarize, mean + gd1 <- group_by(df, "name") + expect_is(gd1, "GroupedData") + df_summarized <- summarize(gd, mean_age = mean(df$age)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) + + df3 <- agg(gd, age = "sum") + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) +}) + +test_that("arrange() and orderBy() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + sorted <- arrange(df, df$age) + expect_equal(collect(sorted)[1,2], "Michael") + + sorted2 <- arrange(df, "name") + expect_equal(collect(sorted2)[2,"age"], 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_equal(collect(sorted3)[2, "age"], 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + filtered <- filter(df, "age > 20") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) +}) + +test_that("join() and merge() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlContext, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_equal(count(joined), 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_equal(count(joined2), 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_equal(count(joined3), 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_equal(count(joined4), 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + + merged <- select(merge(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(merged), c("newAge", "name", "test")) + expect_equal(count(merged), 4) + expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlContext, jsonPath) + testRDD <- toJSON(df) + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlContext, jsonPath) + s <- capture.output(showDF(df)) + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) +}) + +test_that("isLocal()", { + df <- jsonFile(sqlContext, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- read.df(sqlContext, jsonPath2, "json") + + unioned <- arrange(unionAll(df, df2), df$age) + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") + + unioned2 <- arrange(rbind(unioned, df, df2), df$age) + expect_is(unioned2, "DataFrame") + expect_equal(count(unioned2), 12) + expect_equal(first(unioned2)$name, "Michael") + + excepted <- arrange(except(df, df2), desc(df$age)) + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") + + intersected <- arrange(intersect(df, df2), df$age) + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlContext, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") +}) + +test_that("mutate(), rename() and names()", { + df <- jsonFile(sqlContext, jsonPath) + newDF <- mutate(df, newAge = df$age + 2) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + + newDF2 <- rename(df, newerAge = df$age) + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") + + names(newDF2) <- c("newerName", "evenNewerAge") + expect_equal(length(names(newDF2)), 2) + expect_equal(names(newDF2)[1], "newerName") +}) + +test_that("write.df() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlContext, jsonPath) + write.df(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlContext, parquetPath) + expect_is(parquetDF, "DataFrame") + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlContext, jsonPath) + write.df(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.df(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df) * 2) +}) + +test_that("describe() and summarize() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + stats <- describe(df, "age") + expect_equal(collect(stats)[1, "summary"], "count") + expect_equal(collect(stats)[2, "age"], "24.5") + expect_equal(collect(stats)[3, "age"], "5.5") + stats <- describe(df) + expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[5, "age"], "30") + + stats2 <- summary(df) + expect_equal(collect(stats2)[4, "name"], "Andy") + expect_equal(collect(stats2)[5, "age"], "30") +}) + +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_identical(expected, actual) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_identical(expected, actual) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_identical(expected, actual) +}) + +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b),] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + +unlink(parquetPath) +unlink(jsonPath) +unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 000000000000..c2c724cdc762 --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) +}) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 000000000000..a9cf83dbdbdb --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_is(rdd, "RDD") + expect_true(count(rdd) > 0) + expect_equal(count(rdd), 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1L) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1L) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) + + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_equal(count(rdd), 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 000000000000..12df4cf4f65b --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_equal(getSerializedMode(text.rdd), "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_equal(getSerializedMode(ser.rdd), "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. + env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + # TODO(shivaram): length(ls(env)) is 4 here for some reason and `lapply` is included in `env`. + # Disabling this test till we debug this. + # + # expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t <- 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 000000000000..3584b418a71a --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 000000000000..0c3b0d1f4be2 --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +# Get current system time +currentTimeSecs <- function() { + as.numeric(Sys.time()) +} + +# Get elapsed time +elapsedSecs <- function() { + proc.time()[3] +} + +# Constants +specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L) + +# Timing R process boot +bootTime <- currentTimeSecs() +bootElap <- elapsedSecs() + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Timing init envs for computing +initElap <- elapsedSecs() + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + SparkR:::setBroadcastValue(bcastId, value) + } +} + +# Timing broadcast +broadcastElap <- elapsedSecs() + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. +numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readMultipleObjects(inputCon) + } + # Timing reading input data for execution + inputElap <- elapsedSecs() + + output <- computeFunc(partition, data) + # Timing computing + computeElap <- elapsedSecs() + + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) + } + # Timing output + outputElap <- elapsedSecs() + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readMultipleObjects(inputCon) + } + # Timing reading input data for execution + inputElap <- elapsedSecs() + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + # Timing computing + computeElap <- elapsedSecs() + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + # Timing output + outputElap <- elapsedSecs() + } +} else { + inputElap <- broadcastElap + computeElap <- broadcastElap + outputElap <- broadcastElap +} + +# Report timing +SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA) +SparkR:::writeDouble(outputCon, bootTime) +SparkR:::writeDouble(outputCon, initElap - bootElap) # init +SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast +SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input +SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute +SparkR:::writeDouble(outputCon, outputElap - computeElap) # output + +# End of output +SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM) + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src-native/Makefile b/R/pkg/src-native/Makefile new file mode 100644 index 000000000000..a55a56fe80e1 --- /dev/null +++ b/R/pkg/src-native/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src-native/Makefile.win b/R/pkg/src-native/Makefile.win new file mode 100644 index 000000000000..aa486d822837 --- /dev/null +++ b/R/pkg/src-native/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src-native/string_hash_code.c b/R/pkg/src-native/string_hash_code.c new file mode 100644 index 000000000000..e3274b9a0c54 --- /dev/null +++ b/R/pkg/src-native/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 000000000000..4f8a1ed2d83e --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 000000000000..e82ad0ba2cd0 --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/README.md b/README.md index af0233957819..380422ca00db 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,8 +75,8 @@ can be run using: ./dev/run-tests -Please see the guidance on how to -[run all automated tests](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-AutomatedTesting). +Please see the guidance on how to +[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). ## A Note About Hadoop Versions @@ -85,7 +85,7 @@ storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version) +["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) for detailed guidance on building for a particular distribution of Hadoop, including building for particular Hive and Hive Thriftserver distributions. See also ["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) diff --git a/assembly/pom.xml b/assembly/pom.xml index 1bb5a671f539..e9c6d26ccddc 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -36,10 +36,6 @@ scala-${scala.binary.version} spark-assembly-${project.version}-hadoop${hadoop.version}.jar ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} - spark - /usr/share/spark - root - 744 @@ -96,6 +92,27 @@ true + + + org.apache.maven.plugins + maven-antrun-plugin + + + package + + run + + + + + + + + + + + + org.apache.maven.plugins @@ -198,7 +215,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist @@ -217,123 +233,6 @@ - - deb - - - - org.codehaus.mojo - buildnumber-maven-plugin - 1.2 - - - validate - - create - - - 8 - - - - - - org.vafer - jdeb - 0.11 - - - package - - jdeb - - - ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb - false - gzip - - - ${spark.jar} - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/jars - - - - ${basedir}/src/deb/RELEASE - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path} - - - - ${basedir}/../conf - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/conf - 744 - - - - ${basedir}/../bin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/bin - ${deb.bin.filemode} - - - - ${basedir}/../sbin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/sbin - 744 - - - - ${basedir}/../python - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/python - 744 - - - - - - - - - - - - kinesis-asl - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - diff --git a/assembly/src/deb/RELEASE b/assembly/src/deb/RELEASE deleted file mode 100644 index aad50ee73aa4..000000000000 --- a/assembly/src/deb/RELEASE +++ /dev/null @@ -1,2 +0,0 @@ -compute-classpath.sh uses the existence of this file to decide whether to put the assembly jar on the -classpath or instead to use classfiles in the source tree. \ No newline at end of file diff --git a/assembly/src/deb/control/control b/assembly/src/deb/control/control deleted file mode 100644 index a6b4471d485f..000000000000 --- a/assembly/src/deb/control/control +++ /dev/null @@ -1,8 +0,0 @@ -Package: [[deb.pkg.name]] -Version: [[version]]-[[buildNumber]] -Section: misc -Priority: extra -Architecture: all -Maintainer: Matei Zaharia -Description: [[name]] -Distribution: development diff --git a/bagel/pom.xml b/bagel/pom.xml index 510e92640eff..ed5c37e595a9 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 853ef0ed2986..edbecdae9209 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02..fb10d734ac74 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd deleted file mode 100644 index 088f993954d9..000000000000 --- a/bin/compute-classpath.cmd +++ /dev/null @@ -1,124 +0,0 @@ -@echo off - -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -rem script and the ExecutorRunner in standalone cluster mode. - -rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting -rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we -rem need to set it here because we use !datanucleus_jars! below. -if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion -setlocal enabledelayedexpansion -:skip_delayed_expansion - -set SCALA_VERSION=2.10 - -rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" - -rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% - -if not "x%SPARK_CONF_DIR%"=="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% -) else ( - set CLASSPATH=%CLASSPATH%;%FWDIR%conf -) - -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) else ( - for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) - -set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% - -rem When Hive support is needed, Datanucleus jars must be included on the classpath. -rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -rem built with Hive, so look for them there. -if exist "%FWDIR%RELEASE" ( - set datanucleus_dir=%FWDIR%lib -) else ( - set datanucleus_dir=%FWDIR%lib_managed\jars -) -set "datanucleus_jars=" -for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( - set datanucleus_jars=!datanucleus_jars!;%%d -) -set CLASSPATH=%CLASSPATH%;%datanucleus_jars% - -set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes - -set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes - -if "x%SPARK_TESTING%"=="x1" ( - rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH - rem so that local compilation takes precedence over assembled jar - set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% -) - -rem Add hadoop conf dir - else FileSystem.*, etc fail -rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -rem the configurtion files. -if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir - set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% -:no_hadoop_conf_dir - -if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir - set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% -:no_yarn_conf_dir - -rem To allow for distributions to append needed libraries to the classpath (e.g. when -rem using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and -rem append it to tbe final classpath. -if not "x%$SPARK_DIST_CLASSPATH%"=="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_DIST_CLASSPATH% -) - -rem A bit of a hack to allow calling this script within run2.cmd without seeing output -if "%DONT_PRINT_CLASSPATH%"=="1" goto exit - -echo %CLASSPATH% - -:exit diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh deleted file mode 100755 index a8c344b1ca59..000000000000 --- a/bin/compute-classpath.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -# script and the ExecutorRunner in standalone cluster mode. - -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -. "$FWDIR"/bin/load-spark-env.sh - -if [ -n "$SPARK_CLASSPATH" ]; then - CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH" -else - CLASSPATH="$SPARK_SUBMIT_CLASSPATH" -fi - -# Build up classpath -if [ -n "$SPARK_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR" -else - CLASSPATH="$CLASSPATH:$FWDIR/conf" -fi - -ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" - -if [ -n "$JAVA_HOME" ]; then - JAR_CMD="$JAVA_HOME/bin/jar" -else - JAR_CMD="jar" -fi - -# A developer option to prepend more recently compiled Spark classes -if [ -n "$SPARK_PREPEND_CLASSES" ]; then - echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ - "classes ahead of assembly." >&2 - # Spark classes - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" - # Jars for shaded deps in their original form (copied here during build) - CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" -fi - -# Use spark-assembly jar from either RELEASE or assembly directory -if [ -f "$FWDIR/RELEASE" ]; then - assembly_folder="$FWDIR"/lib -else - assembly_folder="$ASSEMBLY_DIR" -fi - -num_jars=0 - -for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark assembly in $assembly_folder" 1>&2 - echo "You need to build Spark before running this program." 1>&2 - exit 1 - fi - ASSEMBLY_JAR="$f" - num_jars=$((num_jars+1)) -done - -if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2 - ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 -fi - -# Verify that versions of java used to build the jars and run Spark are compatible -jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) -if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 -fi - -CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR" - -# When Hive support is needed, Datanucleus jars must be included on the classpath. -# Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -# Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -# built with Hive, so first check if the datanucleus jars exist, and then ensure the current Spark -# assembly is built for Hive, before actually populating the CLASSPATH with the jars. -# Note that this check order is faster (by up to half a second) in the case where Hive is not used. -if [ -f "$FWDIR/RELEASE" ]; then - datanucleus_dir="$FWDIR"/lib -else - datanucleus_dir="$FWDIR"/lib_managed/jars -fi - -datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar$")" -datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)" - -if [ -n "$datanucleus_jars" ]; then - hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) - if [ -n "$hive_files" ]; then - echo "Spark assembly has been built with Hive, including Datanucleus jars on classpath" 1>&2 - CLASSPATH="$CLASSPATH:$datanucleus_jars" - fi -fi - -# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 -if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" -fi - -# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! -# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -# the configurtion files. -if [ -n "$HADOOP_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR" -fi -if [ -n "$YARN_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" -fi - -# To allow for distributions to append needed libraries to the classpath (e.g. when -# using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and -# append it to tbe final classpath. -if [ -n "$SPARK_DIST_CLASSPATH" ]; then - CLASSPATH="$CLASSPATH:$SPARK_DIST_CLASSPATH" -fi - -echo "$CLASSPATH" diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd new file mode 100644 index 000000000000..36d932c453b6 --- /dev/null +++ b/bin/load-spark-env.cmd @@ -0,0 +1,59 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. +rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's +rem conf/ subdirectory. + +if [%SPARK_ENV_LOADED%] == [] ( + set SPARK_ENV_LOADED=1 + + if not [%SPARK_CONF_DIR%] == [] ( + set user_conf_dir=%SPARK_CONF_DIR% + ) else ( + set user_conf_dir=%~dp0..\..\conf + ) + + call :LoadSparkEnv +) + +rem Setting SPARK_SCALA_VERSION if not already set. + +set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 +set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 + +if [%SPARK_SCALA_VERSION%] == [] ( + + if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + exit 1 + ) + if exist %ASSEMBLY_DIR2% ( + set SPARK_SCALA_VERSION=2.11 + ) else ( + set SPARK_SCALA_VERSION=2.10 + ) +) +exit /b 0 + +:LoadSparkEnv +if exist "%user_conf_dir%\spark-env.cmd" ( + call "%user_conf_dir%\spark-env.cmd" +) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 356b3d49b2ff..95779e9ddbb1 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,6 +20,7 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 @@ -43,7 +44,7 @@ if [ -z "$SPARK_SCALA_VERSION" ]; then ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" - + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 @@ -54,5 +55,5 @@ if [ -z "$SPARK_SCALA_VERSION" ]; then export SPARK_SCALA_VERSION="2.11" else export SPARK_SCALA_VERSION="2.10" - fi + fi fi diff --git a/bin/pyspark b/bin/pyspark index 0b4f695dd06d..8f2a3b5a7717 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,36 +17,10 @@ # limitations under the License. # -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" - -source "$FWDIR/bin/utils.sh" - -source "$FWDIR"/bin/load-spark-env.sh - -function usage() { - echo "Usage: ./bin/pyspark [options]" 1>&2 - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 -} - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - -# Exit if the user hasn't compiled Spark -if [ ! -f "$FWDIR/RELEASE" ]; then - # Exit if the user hasn't compiled Spark - ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null - if [[ $? != 0 ]]; then - echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2 - echo "You need to build Spark before running this program" 1>&2 - exit 1 - fi -fi +source "$SPARK_HOME"/bin/load-spark-env.sh +export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. @@ -95,42 +69,17 @@ export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" -export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py" - -# Build up arguments list manually to preserve quotes and backslashes. -# We export Spark submit arguments as an environment variable because shell.py must run as a -# PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" -PYSPARK_SUBMIT_ARGS="" -whitespace="[[:space:]]" -for i in "${SUBMISSION_OPTS[@]}"; do - if [[ $i =~ \" ]]; then i=$(echo $i | sed 's/\"/\\\"/g'); fi - if [[ $i =~ $whitespace ]]; then i=\"$i\"; fi - PYSPARK_SUBMIT_ARGS="$PYSPARK_SUBMIT_ARGS $i" -done -export PYSPARK_SUBMIT_ARGS +export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py" # For pyspark tests if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + export PYTHONHASHSEED=0 + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi -# If a python file is provided, directly run spark-submit. -if [[ "$1" =~ \.py$ ]]; then - echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 - echo -e "Use ./bin/spark-submit \n" 1>&2 - primary="$1" - shift - gatherSparkSubmitOpts "$@" - exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" -else - exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS -fi +export PYSPARK_DRIVER_PYTHON +export PYSPARK_DRIVER_PYTHON_OPTS +exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a542ec80b49d..3c6169983e76 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -17,59 +17,22 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SCALA_VERSION=2.10 - rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Export this as SPARK_HOME -set SPARK_HOME=%FWDIR% - -rem Test whether the user has built Spark -if exist "%FWDIR%RELEASE" goto skip_build_test -set FOUND_JAR=0 -for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set FOUND_JAR=1 -) -if [%FOUND_JAR%] == [0] ( - echo Failed to find Spark assembly JAR. - echo You need to build Spark before running this program. - goto exit -) -:skip_build_test +set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd +set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. -if [%PYSPARK_PYTHON%] == [] set PYSPARK_PYTHON=python +if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( + set PYSPARK_DRIVER_PYTHON=python + if not [%PYSPARK_PYTHON%] == [] set PYSPARK_DRIVER_PYTHON=%PYSPARK_PYTHON% +) -set PYTHONPATH=%FWDIR%python;%PYTHONPATH% -set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% -set PYTHONSTARTUP=%FWDIR%python\pyspark\shell.py -set PYSPARK_SUBMIT_ARGS=%* - -echo Running %PYSPARK_PYTHON% with PYTHONPATH=%PYTHONPATH% - -rem Check whether the argument is a file -for /f %%i in ('echo %1^| findstr /R "\.py"') do ( - set PYTHON_FILE=%%i -) - -if [%PYTHON_FILE%] == [] ( - if [%IPYTHON%] == [1] ( - ipython %IPYTHON_OPTS% - ) else ( - %PYSPARK_PYTHON% - ) -) else ( - echo. - echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0. - echo Use ./bin/spark-submit ^ - echo. - "%FWDIR%\bin\spark-submit.cmd" %PYSPARK_SUBMIT_ARGS% -) +set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -:exit +call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* diff --git a/bin/run-example b/bin/run-example index c567acf9a6b5..798e2caeb88c 100755 --- a/bin/run-example +++ b/bin/run-example @@ -42,7 +42,7 @@ fi JAR_COUNT=0 -for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do +for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do if [[ ! -e "$f" ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -54,7 +54,7 @@ done if [ "$JAR_COUNT" -gt "1" ]; then echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2 - ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2 + ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2 echo "Please remove all but one jar." 1>&2 exit 1 fi @@ -67,7 +67,7 @@ if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -"$FWDIR"/bin/spark-submit \ +exec "$FWDIR"/bin/spark-submit \ --master $EXAMPLE_MASTER \ --class $EXAMPLE_CLASS \ "$SPARK_EXAMPLES_JAR" \ diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b49d0dcb4ff2..c3e0221fb62e 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -25,8 +25,7 @@ set FWDIR=%~dp0..\ rem Export this as SPARK_HOME set SPARK_HOME=%FWDIR% -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if not "x%1"=="x" goto arg_given diff --git a/bin/spark-class b/bin/spark-class index 2f0441bb3c1c..2b59e5df5736 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -17,87 +17,10 @@ # limitations under the License. # -# NOTE: Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! - -cygwin=false -case "`uname`" in - CYGWIN*) cygwin=true;; -esac - # Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" -export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" - -. "$FWDIR"/bin/load-spark-env.sh +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -if [ -z "$1" ]; then - echo "Usage: spark-class []" 1>&2 - exit 1 -fi - -if [ -n "$SPARK_MEM" ]; then - echo -e "Warning: SPARK_MEM is deprecated, please use a more specific config option" 1>&2 - echo -e "(e.g., spark.executor.memory or spark.driver.memory)." 1>&2 -fi - -# Use SPARK_MEM or 512m as the default memory, to be overridden by specific options -DEFAULT_MEM=${SPARK_MEM:-512m} - -SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" - -# Add java opts and memory settings for master, worker, history server, executors, and repl. -case "$1" in - # Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. - 'org.apache.spark.deploy.master.Master') - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS" - OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} - ;; - 'org.apache.spark.deploy.worker.Worker') - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS" - OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} - ;; - 'org.apache.spark.deploy.history.HistoryServer') - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_HISTORY_OPTS" - OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} - ;; - - # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. - 'org.apache.spark.executor.CoarseGrainedExecutorBackend') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" - OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} - ;; - 'org.apache.spark.executor.MesosExecutorBackend') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" - OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} - export PYTHONPATH="$FWDIR/python:$PYTHONPATH" - export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" - ;; - - # Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + - # SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. - 'org.apache.spark.deploy.SparkSubmit') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS" - OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} - if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then - if [[ $OSTYPE == darwin* ]]; then - export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH" - else - export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH" - fi - fi - if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then - OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY" - fi - ;; - - *) - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" - OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} - ;; -esac +. "$SPARK_HOME"/bin/load-spark-env.sh # Find the java binary if [ -n "${JAVA_HOME}" ]; then @@ -110,83 +33,45 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') -# Set JAVA_OPTS to be able to load native libraries and to set heap size -if [ "$JAVA_VERSION" -ge 18 ]; then - JAVA_OPTS="$OUR_JAVA_OPTS" +# Find assembly jar +SPARK_ASSEMBLY_JAR= +if [ -f "$SPARK_HOME/RELEASE" ]; then + ASSEMBLY_DIR="$SPARK_HOME/lib" else - JAVA_OPTS="-XX:MaxPermSize=128m $OUR_JAVA_OPTS" -fi -JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" - -# Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`" -fi - -# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! - -TOOLS_DIR="$FWDIR"/tools -SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then - # Use the JAR from the SBT build - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`" -fi -if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then - # Use the JAR from the Maven build - # TODO: this also needs to become an assembly! - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar`" + ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" fi -# Compute classpath using external script -classpath_output=$("$FWDIR"/bin/compute-classpath.sh) -if [[ "$?" != "0" ]]; then - echo "$classpath_output" +num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then + echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 + echo "You need to build Spark before running this program." 1>&2 exit 1 -else - CLASSPATH="$classpath_output" fi - -if [[ "$1" =~ org.apache.spark.tools.* ]]; then - if test -z "$SPARK_TOOLS_JAR"; then - echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 - echo "You need to run \"build/sbt tools/package\" before running $1." 1>&2 - exit 1 - fi - CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" +ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" +if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 fi -if $cygwin; then - CLASSPATH="`cygpath -wp "$CLASSPATH"`" - if [ "$1" == "org.apache.spark.tools.JavaAPICompletenessChecker" ]; then - export SPARK_TOOLS_JAR="`cygpath -w "$SPARK_TOOLS_JAR"`" - fi -fi -export CLASSPATH +SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" -# In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. -# Here we must parse the properties file for relevant "spark.driver.*" configs before launching -# the driver JVM itself. Instead of handling this complexity in Bash, we launch a separate JVM -# to prepare the launch environment of this driver JVM. +LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" -if [ -n "$SPARK_SUBMIT_BOOTSTRAP_DRIVER" ]; then - # This is used only if the properties file actually contains these special configs - # Export the environment variables needed by SparkSubmitDriverBootstrapper - export RUNNER - export CLASSPATH - export JAVA_OPTS - export OUR_JAVA_MEM - export SPARK_CLASS=1 - shift # Ignore main class (org.apache.spark.deploy.SparkSubmit) and use our own - exec "$RUNNER" org.apache.spark.deploy.SparkSubmitDriverBootstrapper "$@" -else - # Note: The format of this command is closely echoed in SparkSubmitDriverBootstrapper.scala - if [ -n "$SPARK_PRINT_LAUNCH_COMMAND" ]; then - echo -n "Spark Command: " 1>&2 - echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" 1>&2 - echo -e "========================================\n" 1>&2 - fi - exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" +# Add the launcher build dir to the classpath if requested. +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" fi +export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" + +# The launcher library will print arguments separated by a NULL character, to allow arguments with +# characters that would be otherwise interpreted by the shell. Read that in a while loop, populating +# an array that will be used to exec the final command. +CMD=() +while IFS= read -d '' -r ARG; do + CMD+=("$ARG") +done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") +exec "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index da46543647ef..db09fa27e51a 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -17,135 +17,54 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! - -setlocal enabledelayedexpansion - -set SCALA_VERSION=2.10 - rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Export this as SPARK_HOME -set SPARK_HOME=%FWDIR% +set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given -if not "x%1"=="x" goto arg_given +if "x%1"=="x" ( echo Usage: spark-class ^ [^] - goto exit -:arg_given - -if not "x%SPARK_MEM%"=="x" ( - echo Warning: SPARK_MEM is deprecated, please use a more specific config option - echo e.g., spark.executor.memory or spark.driver.memory. + exit /b 1 ) -rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options -set OUR_JAVA_MEM=%SPARK_MEM% -if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m - -set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true - -rem Add java opts and memory settings for master, worker, history server, executors, and repl. -rem Master, Worker and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. -if "%1"=="org.apache.spark.deploy.master.Master" ( - set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS% - if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% -) else if "%1"=="org.apache.spark.deploy.worker.Worker" ( - set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS% - if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% -) else if "%1"=="org.apache.spark.deploy.history.HistoryServer" ( - set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_HISTORY_OPTS% - if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% - -rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. -) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% - if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% -) else if "%1"=="org.apache.spark.executor.MesosExecutorBackend" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% - if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% +rem Find assembly jar +set SPARK_ASSEMBLY_JAR=0 -rem Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + -rem SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. -rem The repl also uses SPARK_REPL_OPTS. -) else if "%1"=="org.apache.spark.deploy.SparkSubmit" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_SUBMIT_OPTS% %SPARK_REPL_OPTS% - if not "x%SPARK_SUBMIT_LIBRARY_PATH%"=="x" ( - set OUR_JAVA_OPTS=!OUR_JAVA_OPTS! -Djava.library.path=%SPARK_SUBMIT_LIBRARY_PATH% - ) else if not "x%SPARK_LIBRARY_PATH%"=="x" ( - set OUR_JAVA_OPTS=!OUR_JAVA_OPTS! -Djava.library.path=%SPARK_LIBRARY_PATH% - ) - if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% - if not "x%SPARK_SUBMIT_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_SUBMIT_DRIVER_MEMORY% +if exist "%SPARK_HOME%\RELEASE" ( + set ASSEMBLY_DIR=%SPARK_HOME%\lib ) else ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% - if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% + set ASSEMBLY_DIR=%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION% ) -rem Set JAVA_OPTS to be able to load native libraries and to set heap size -for /f "tokens=3" %%i in ('java -version 2^>^&1 ^| find "version"') do set jversion=%%i -for /f "tokens=1 delims=_" %%i in ("%jversion:~1,-1%") do set jversion=%%i -if "%jversion%" geq "1.8.0" ( - set JAVA_OPTS=%OUR_JAVA_OPTS% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% -) else ( - set JAVA_OPTS=-XX:MaxPermSize=128m %OUR_JAVA_OPTS% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% -) -rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! - -rem Test whether the user has built Spark -if exist "%FWDIR%RELEASE" goto skip_build_test -set FOUND_JAR=0 -for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set FOUND_JAR=1 +for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( + set SPARK_ASSEMBLY_JAR=%%d ) -if "%FOUND_JAR%"=="0" ( +if "%SPARK_ASSEMBLY_JAR%"=="0" ( echo Failed to find Spark assembly JAR. echo You need to build Spark before running this program. - goto exit + exit /b 1 ) -:skip_build_test -set TOOLS_DIR=%FWDIR%tools -set SPARK_TOOLS_JAR= -for %%d in ("%TOOLS_DIR%\target\scala-%SCALA_VERSION%\spark-tools*assembly*.jar") do ( - set SPARK_TOOLS_JAR=%%d +set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% + +rem Add the launcher build dir to the classpath if requested. +if not "x%SPARK_PREPEND_CLASSES%"=="x" ( + set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% ) -rem Compute classpath using external script -set DONT_PRINT_CLASSPATH=1 -call "%FWDIR%bin\compute-classpath.cmd" -set DONT_PRINT_CLASSPATH=0 -set CLASSPATH=%CLASSPATH%;%SPARK_TOOLS_JAR% +set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java -rem In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. -rem Here we must parse the properties file for relevant "spark.driver.*" configs before launching -rem the driver JVM itself. Instead of handling this complexity here, we launch a separate JVM -rem to prepare the launch environment of this driver JVM. - -rem In this case, leave out the main class (org.apache.spark.deploy.SparkSubmit) and use our own. -rem Leaving out the first argument is surprisingly difficult to do in Windows. Note that this must -rem be done here because the Windows "shift" command does not work in a conditional block. -set BOOTSTRAP_ARGS= -shift -:start_parse -if "%~1" == "" goto end_parse -set BOOTSTRAP_ARGS=%BOOTSTRAP_ARGS% %~1 -shift -goto start_parse -:end_parse - -if not [%SPARK_SUBMIT_BOOTSTRAP_DRIVER%] == [] ( - set SPARK_CLASS=1 - "%RUNNER%" org.apache.spark.deploy.SparkSubmitDriverBootstrapper %BOOTSTRAP_ARGS% -) else ( - "%RUNNER%" -cp "%CLASSPATH%" %JAVA_OPTS% %* +rem The launcher library prints the command to be executed in a single line suitable for being +rem executed by the batch interpreter. So read all the output of the launcher into a variable. +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( + set SPARK_CMD=%%i ) -:exit +del %LAUNCHER_OUTPUT% +%SPARK_CMD% diff --git a/bin/spark-shell b/bin/spark-shell index cca5aa067612..00ab7afd118b 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -28,25 +28,11 @@ esac # Enter posix mode for bash set -o posix -## Global script variables -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage() { - echo "Usage: ./bin/spark-shell [options]" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 -} - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - -source "$FWDIR"/bin/utils.sh -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" +export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, -# so we need to add the "-Dscala.usejavacp=true" flag mnually. We +# so we need to add the "-Dscala.usejavacp=true" flag manually. We # do this specifically for the Spark shell because the scala REPL # has its own class loader, and any additional classpath specified # through spark.driver.extraClassPath is not automatically propagated. @@ -61,11 +47,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd old mode 100755 new mode 100644 diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 1d1a40da315e..b9b0f510d7f5 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,24 +18,18 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] -echo "%*" | findstr " --help -h" >nul -if %ERRORLEVEL% equ 0 ( - call :usage - exit /b 0 +rem SPARK-4161: scala does not assume use of the java classpath, +rem so we need to add the "-Dscala.usejavacp=true" flag manually. We +rem do this specifically for the Spark shell because the scala REPL +rem has its own class loader, and any additional classpath specified +rem through spark.driver.extraClassPath is not automatically propagated. +if "x%SPARK_SUBMIT_OPTS%"=="x" ( + set SPARK_SUBMIT_OPTS=-Dscala.usejavacp=true + goto run_shell ) +set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" -call %SPARK_HOME%\bin\windows-utils.cmd %* -if %ERRORLEVEL% equ 1 ( - call :usage - exit /b 1 -) - -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %SUBMISSION_OPTS% spark-shell %APPLICATION_OPTS% - -exit /b 0 - -:usage -echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -%SPARK_HOME%\bin\spark-submit --help 2>&1 | findstr /V "Usage" 1>&2 -exit /b 0 +:run_shell +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/bin/spark-sql b/bin/spark-sql index 3b6cc420fea8..4ea7bc6e39c0 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,41 +17,6 @@ # limitations under the License. # -# -# Shell script for starting the Spark SQL CLI - -# Enter posix mode for bash -set -o posix - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. -CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" - -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage { - echo "Usage: ./bin/spark-sql [options] [cli option]" - pattern="usage" - pattern+="\|Spark assembly has been built with Hive" - pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" - pattern+="\|Spark Command: " - pattern+="\|--help" - pattern+="\|=======" - - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - echo - echo "CLI options:" - "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 -} - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage - exit 0 -fi - -source "$FWDIR"/bin/utils.sh -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" - -exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" +export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" +exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 3e5cbdbb2439..255378b0f077 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,58 +17,9 @@ # limitations under the License. # -# NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! +SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -ORIG_ARGS=("$@") - -# Set COLUMNS for progress bar -export COLUMNS=`tput cols` - -while (($#)); do - if [ "$1" = "--deploy-mode" ]; then - SPARK_SUBMIT_DEPLOY_MODE=$2 - elif [ "$1" = "--properties-file" ]; then - SPARK_SUBMIT_PROPERTIES_FILE=$2 - elif [ "$1" = "--driver-memory" ]; then - export SPARK_SUBMIT_DRIVER_MEMORY=$2 - elif [ "$1" = "--driver-library-path" ]; then - export SPARK_SUBMIT_LIBRARY_PATH=$2 - elif [ "$1" = "--driver-class-path" ]; then - export SPARK_SUBMIT_CLASSPATH=$2 - elif [ "$1" = "--driver-java-options" ]; then - export SPARK_SUBMIT_OPTS=$2 - elif [ "$1" = "--master" ]; then - export MASTER=$2 - fi - shift -done - -if [ -z "$SPARK_CONF_DIR" ]; then - export SPARK_CONF_DIR="$SPARK_HOME/conf" -fi -DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf" -if [ "$MASTER" == "yarn-cluster" ]; then - SPARK_SUBMIT_DEPLOY_MODE=cluster -fi -export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"} -export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"} - -# For client mode, the driver will be launched in the same JVM that launches -# SparkSubmit, so we may need to read the properties file for any extra class -# paths, library paths, java options and memory early on. Otherwise, it will -# be too late by the time the driver JVM has started. - -if [[ "$SPARK_SUBMIT_DEPLOY_MODE" == "client" && -f "$SPARK_SUBMIT_PROPERTIES_FILE" ]]; then - # Parse the properties file only if the special configs exist - contains_special_configs=$( - grep -e "spark.driver.extra*\|spark.driver.memory" "$SPARK_SUBMIT_PROPERTIES_FILE" | \ - grep -v "^[[:space:]]*#" - ) - if [ -n "$contains_special_configs" ]; then - export SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - fi -fi - -exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" +# disable randomized hash for string in Python 3.3+ +export PYTHONHASHSEED=0 +exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 12244a9cb04f..651376e52692 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -17,62 +17,11 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! +rem This is the entry point for running Spark submit. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -set SPARK_HOME=%~dp0.. -set ORIG_ARGS=%* +rem disable randomized hash for string in Python 3.3+ +set PYTHONHASHSEED=0 -rem Reset the values of all variables used -set SPARK_SUBMIT_DEPLOY_MODE=client - -if not defined %SPARK_CONF_DIR% ( - set SPARK_CONF_DIR=%SPARK_HOME%\conf -) -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf -set SPARK_SUBMIT_DRIVER_MEMORY= -set SPARK_SUBMIT_LIBRARY_PATH= -set SPARK_SUBMIT_CLASSPATH= -set SPARK_SUBMIT_OPTS= -set SPARK_SUBMIT_BOOTSTRAP_DRIVER= - -:loop -if [%1] == [] goto continue - if [%1] == [--deploy-mode] ( - set SPARK_SUBMIT_DEPLOY_MODE=%2 - ) else if [%1] == [--properties-file] ( - set SPARK_SUBMIT_PROPERTIES_FILE=%2 - ) else if [%1] == [--driver-memory] ( - set SPARK_SUBMIT_DRIVER_MEMORY=%2 - ) else if [%1] == [--driver-library-path] ( - set SPARK_SUBMIT_LIBRARY_PATH=%2 - ) else if [%1] == [--driver-class-path] ( - set SPARK_SUBMIT_CLASSPATH=%2 - ) else if [%1] == [--driver-java-options] ( - set SPARK_SUBMIT_OPTS=%2 - ) else if [%1] == [--master] ( - set MASTER=%2 - ) - shift -goto loop -:continue - -if [%MASTER%] == [yarn-cluster] ( - set SPARK_SUBMIT_DEPLOY_MODE=cluster -) - -rem For client mode, the driver will be launched in the same JVM that launches -rem SparkSubmit, so we may need to read the properties file for any extra class -rem paths, library paths, java options and memory early on. Otherwise, it will -rem be too late by the time the driver JVM has started. - -if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( - if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( - rem Parse the properties file only if the special configs exist - for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ - %SPARK_SUBMIT_PROPERTIES_FILE%') do ( - set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - ) - ) -) - -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% +set CLASS=org.apache.spark.deploy.SparkSubmit +%~dp0spark-class2.cmd %CLASS% %* diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 000000000000..464c29f36942 --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,23 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +source "$SPARK_HOME"/bin/load-spark-env.sh +export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 000000000000..d7b60183ca8e --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 000000000000..e47f22c7300b --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. + +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/bin/utils.sh b/bin/utils.sh deleted file mode 100755 index 22ea2b9a6d58..000000000000 --- a/bin/utils.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Gather all spark-submit options into SUBMISSION_OPTS -function gatherSparkSubmitOpts() { - - if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then - echo "Function for printing usage of $0 is not set." 1>&2 - echo "Please set usage function to shell variable 'SUBMIT_USAGE_FUNCTION' in $0" 1>&2 - exit 1 - fi - - # NOTE: If you add or remove spark-sumbmit options, - # modify NOT ONLY this script but also SparkSubmitArgument.scala - SUBMISSION_OPTS=() - APPLICATION_OPTS=() - while (($#)); do - case "$1" in - --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ - --conf | --properties-file | --driver-memory | --driver-java-options | \ - --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ - --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) - if [[ $# -lt 2 ]]; then - "$SUBMIT_USAGE_FUNCTION" - exit 1; - fi - SUBMISSION_OPTS+=("$1"); shift - SUBMISSION_OPTS+=("$1"); shift - ;; - - --verbose | -v | --supervise) - SUBMISSION_OPTS+=("$1"); shift - ;; - - *) - APPLICATION_OPTS+=("$1"); shift - ;; - esac - done - - export SUBMISSION_OPTS - export APPLICATION_OPTS -} diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd deleted file mode 100644 index 1082a952dac9..000000000000 --- a/bin/windows-utils.cmd +++ /dev/null @@ -1,59 +0,0 @@ -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem Gather all spark-submit options into SUBMISSION_OPTS - -set SUBMISSION_OPTS= -set APPLICATION_OPTS= - -rem NOTE: If you add or remove spark-sumbmit options, -rem modify NOT ONLY this script but also SparkSubmitArgument.scala - -:OptsLoop -if "x%1"=="x" ( - goto :OptsLoopEnd -) - -SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--py-files\> \<--files\>" -SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--driver-java-options\>" -SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>" -SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>" -SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\>" - -echo %1 | findstr %opts% >nul -if %ERRORLEVEL% equ 0 ( - if "x%2"=="x" ( - echo "%1" requires an argument. >&2 - exit /b 1 - ) - set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1 %2 - shift - shift - goto :OptsLoop -) -echo %1 | findstr "\<--verbose\> \<-v\> \<--supervise\>" >nul -if %ERRORLEVEL% equ 0 ( - set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1 - shift - goto :OptsLoop -) -set APPLICATION_OPTS=%APPLICATION_OPTS% %1 -shift -goto :OptsLoop - -:OptsLoopEnd -exit /b 0 diff --git a/build/mvn b/build/mvn index a87c5a26230c..ec0380afad31 100755 --- a/build/mvn +++ b/build/mvn @@ -21,6 +21,8 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # Preserve the calling directory _CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" # Installs any application tarball given a URL, the expected tarball name, # and, optionally, a checkable binary path to determine if the binary has @@ -34,14 +36,14 @@ install_app() { local binary="${_DIR}/$3" # setup `curl` and `wget` silent options if we're running on Jenkins - local curl_opts="" + local curl_opts="-L" local wget_opts="" if [ -n "$AMPLAB_JENKINS" ]; then - curl_opts="-s" - wget_opts="--quiet" + curl_opts="-s ${curl_opts}" + wget_opts="--quiet ${wget_opts}" else - curl_opts="--progress-bar" - wget_opts="--progress=bar:force" + curl_opts="--progress-bar ${curl_opts}" + wget_opts="--progress=bar:force ${wget_opts}" fi if [ -z "$3" -o ! -f "$binary" ]; then @@ -49,11 +51,11 @@ install_app() { # check if we have curl installed # download application [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ - echo "exec: curl ${curl_opts} ${remote_tarball}" && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ - echo "exec: wget ${wget_opts} ${remote_tarball}" && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit [ ! -f "${local_tarball}" ] && \ @@ -67,17 +69,20 @@ install_app() { # Install maven under the build/ folder install_mvn() { + local MVN_VERSION="3.3.3" + install_app \ - "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ - "apache-maven-3.2.5-bin.tar.gz" \ - "apache-maven-3.2.5/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "apache-maven-${MVN_VERSION}-bin.tar.gz" \ + "apache-maven-${MVN_VERSION}/bin/mvn" + + MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn" } # Install zinc under the build/ folder install_zinc() { local zinc_path="zinc-0.3.5.3/bin/zinc" - [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ "http://downloads.typesafe.com/zinc/0.3.5.3" \ "zinc-0.3.5.3.tgz" \ @@ -103,28 +108,23 @@ install_scala() { SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" } -# Determines if a given application is already installed. If not, will attempt -# to install -## Arg1 - application name -## Arg2 - Alternate path to local install under build/ dir -check_and_install_app() { - # create the local environment variable in uppercase - local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN" - # some black magic to set the generated app variable (i.e. MVN_BIN) into the - # environment - eval "${app_bin}=`which $1 2>/dev/null`" - - if [ -z "`which $1 2>/dev/null`" ]; then - install_$1 - fi -} - # Setup healthy defaults for the Zinc port if none were provided from # the environment ZINC_PORT=${ZINC_PORT:-"3030"} -# Check and install all applications necessary to build Spark -check_and_install_app "mvn" +# Check for the `--force` flag dictating that `mvn` should be downloaded +# regardless of whether the system already has a `mvn` install +if [ "$1" == "--force" ]; then + FORCE_MVN=1 + shift +fi + +# Install Maven if necessary +MVN_BIN="$(command -v mvn)" + +if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then + install_mvn +fi # Install the proper version of Scala and Zinc for the build install_zinc @@ -135,15 +135,18 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it -if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then - ${ZINC_BIN} -shutdown +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status -port ${ZINC_PORT}`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} + ${ZINC_BIN} -shutdown -port ${ZINC_PORT} ${ZINC_BIN} -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi # Set any `mvn` options if not already present -export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"} +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} + +echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} "$@" +${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" diff --git a/build/sbt b/build/sbt index 28ebb64f7197..cc3203d79bcc 100755 --- a/build/sbt +++ b/build/sbt @@ -125,4 +125,32 @@ loadConfigFile() { [[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@" [[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@" +exit_status=127 +saved_stty="" + +restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit $exit_status +} + +saveSttySettings() { + saved_stty=$(stty -g 2>/dev/null) + if [[ ! $? ]]; then + saved_stty="" + fi +} + +saveSttySettings +trap onExit INT + run "$@" + +exit_status=$? +onExit diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 5e0c640fa591..615f84839465 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -38,8 +38,7 @@ dlog () { acquire_sbt_jar () { SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` - URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar - URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar JAR=build/sbt-launch-${SBT_VERSION}.jar sbt_jar=$JAR @@ -51,9 +50,11 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ + mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + wget --quiet ${URL1} -O "${JAR_DL}" &&\ + mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 @@ -81,7 +82,7 @@ execRunner () { echo "" } - exec "$@" + "$@" } addJava () { diff --git a/conf/docker.properties.template b/conf/docker.properties.template new file mode 100644 index 000000000000..26e3bfd9c5b9 --- /dev/null +++ b/conf/docker.properties.template @@ -0,0 +1,3 @@ +spark.mesos.executor.docker.image: +spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro +spark.mesos.executor.home: /opt/spark diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 89eec7d4b7f6..74c5cea94403 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -6,7 +6,13 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 464c14457e53..7f17bc7eea4f 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wild card "*" can be used as an instance name, in +# and "applications". A wildcard "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. The minimum allowed polling period # is 1 second. -# 3. Wild card properties can be overridden by more specific properties. +# 3. Wildcard properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,6 +47,13 @@ # instance master and applications. MetricsServlet may not be configured by self. # +## List of available common sources and their properties. + +# org.apache.spark.metrics.source.JvmSource +# Note: Currently, JvmSource is the only available common source +# to add additionaly to an instance, to enable this, +# set the "class" option to its fully qulified class name (see examples below) + ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink @@ -122,6 +129,15 @@ #worker.sink.csv.unit=minutes +# Enable Slf4jSink for all instances by class name +#*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink + +# Polling period for Slf4JSink +#*.sink.slf4j.period=1 + +#*.sink.slf4j.unit=minutes + + # Enable jvm source for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 0886b0276fb9..c05fe381a36a 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -15,14 +15,14 @@ # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program # - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data -# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos +# - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. @@ -38,7 +38,9 @@ # - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") +# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/pom.xml b/core/pom.xml index 2c115683fce6..4f79d71bf85f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -34,6 +34,11 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + com.google.guava guava @@ -41,40 +46,19 @@ com.twitter chill_${scala.binary.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - + + + org.apache.spark + spark-launcher_${scala.binary.version} + ${project.version} org.apache.spark @@ -86,6 +70,11 @@ spark-network-shuffle_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + net.java.dev.jets3t jets3t @@ -132,6 +121,13 @@ jetty-servlet compile + + + org.eclipse.jetty.orbit + javax.servlet + ${orbit.version} + org.apache.commons @@ -207,6 +203,14 @@ json4s-jackson_${scala.binary.version} 3.2.10 + + com.sun.jersey + jersey-server + + + com.sun.jersey + jersey-core + org.apache.mesos mesos @@ -236,15 +240,33 @@ io.dropwizard.metrics metrics-graphite + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.module + jackson-module-scala_${scala.binary.version} + org.apache.derby derby test + + org.apache.ivy + ivy + + + oro + + oro + ${oro.version} + org.tachyonproject tachyon-client - 0.5.0 + 0.7.1 org.apache.hadoop @@ -255,62 +277,50 @@ curator-recipes - org.eclipse.jetty - jetty-jsp - - - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server + org.tachyonproject + tachyon-underfs-glusterfs - org.eclipse.jetty - jetty-servlet - - - junit - junit - - - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test + - org.mockito - mockito-all + xml-apis + xml-apis test - org.scalacheck - scalacheck_${scala.binary.version} + org.hamcrest + hamcrest-core test - org.easymock - easymockclassextension + org.hamcrest + hamcrest-library test - asm - asm + org.mockito + mockito-core + test + + + org.scalacheck + scalacheck_${scala.binary.version} test @@ -324,9 +334,20 @@ test - org.spark-project + org.apache.curator + curator-test + test + + + net.razorvine pyrolite - 2.0.1 + 4.4 + + + net.razorvine + serpent + + net.sf.py4j @@ -338,35 +359,6 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-antrun-plugin - - - generate-resources - - run - - - - - - - - - - - maven-clean-plugin - - - - ${basedir}/../python/build - - - true - - org.apache.maven.plugins maven-dependency-plugin @@ -387,7 +379,7 @@ true true - guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server + guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server,jetty-security true @@ -395,24 +387,56 @@ - - - - src/main/resources - - - ../python - - pyspark/*.py - - - - ../python/build - - py4j/*.py - - - + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 646496f31350..fa9acf0a15b8 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -17,23 +17,7 @@ package org.apache.spark; -import org.apache.spark.scheduler.SparkListener; -import org.apache.spark.scheduler.SparkListenerApplicationEnd; -import org.apache.spark.scheduler.SparkListenerApplicationStart; -import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; -import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; -import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorAdded; -import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorRemoved; -import org.apache.spark.scheduler.SparkListenerJobEnd; -import org.apache.spark.scheduler.SparkListenerJobStart; -import org.apache.spark.scheduler.SparkListenerStageCompleted; -import org.apache.spark.scheduler.SparkListenerStageSubmitted; -import org.apache.spark.scheduler.SparkListenerTaskEnd; -import org.apache.spark.scheduler.SparkListenerTaskGettingResult; -import org.apache.spark.scheduler.SparkListenerTaskStart; -import org.apache.spark.scheduler.SparkListenerUnpersistRDD; +import org.apache.spark.scheduler.*; /** * Java clients should extend this class instead of implementing @@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } @Override public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + } diff --git a/core/src/main/java/org/apache/spark/JobExecutionStatus.java b/core/src/main/java/org/apache/spark/JobExecutionStatus.java index 6e161313702b..0287fb79f8dd 100644 --- a/core/src/main/java/org/apache/spark/JobExecutionStatus.java +++ b/core/src/main/java/org/apache/spark/JobExecutionStatus.java @@ -17,9 +17,15 @@ package org.apache.spark; +import org.apache.spark.util.EnumUtil; + public enum JobExecutionStatus { RUNNING, SUCCEEDED, FAILED, - UNKNOWN + UNKNOWN; + + public static JobExecutionStatus fromString(String str) { + return EnumUtil.parseIgnoreCase(JobExecutionStatus.class, str); + } } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index fbc566695905..1214d05ba606 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { onEvent(executorRemoved); } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + } diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java index 2090efd3b999..d4c42b38ac22 100644 --- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -23,11 +23,13 @@ // See // http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html abstract class JavaSparkContextVarargsWorkaround { - public JavaRDD union(JavaRDD... rdds) { + + @SafeVarargs + public final JavaRDD union(JavaRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList> rest = new ArrayList>(rdds.length - 1); + List> rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } @@ -38,18 +40,19 @@ public JavaDoubleRDD union(JavaDoubleRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList rest = new ArrayList(rdds.length - 1); + List rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } return union(rdds[0], rest); } - public JavaPairRDD union(JavaPairRDD... rdds) { + @SafeVarargs + public final JavaPairRDD union(JavaPairRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList> rest = new ArrayList>(rdds.length - 1); + List> rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } @@ -57,7 +60,7 @@ public JavaPairRDD union(JavaPairRDD... rdds) { } // These methods take separate "first" and "rest" elements to avoid having the same type erasure - abstract public JavaRDD union(JavaRDD first, List> rest); - abstract public JavaDoubleRDD union(JavaDoubleRDD first, List rest); - abstract public JavaPairRDD union(JavaPairRDD first, List> rest); + public abstract JavaRDD union(JavaRDD first, List> rest); + public abstract JavaDoubleRDD union(JavaDoubleRDD first, List rest); + public abstract JavaPairRDD union(JavaPairRDD first, List> rest); } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java new file mode 100644 index 000000000000..38e410c5debe --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A zero-argument function that returns an R. + */ +public interface Function0 extends Serializable { + public R call() throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java new file mode 100644 index 000000000000..0e58bb4f7101 --- /dev/null +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.reflect.ClassTag; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.Platform; + +/** + * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + * Our shuffle write path doesn't actually use this serializer (since we end up calling the + * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + * around this, we pass a dummy no-op serializer. + */ +@Private +public final class DummySerializerInstance extends SerializerInstance { + + public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); + + private DummySerializerInstance() { } + + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + // Need to implement this because DiskObjectWriter uses it to flush the compression stream + try { + s.flush(); + } catch (IOException e) { + Platform.throwException(e); + } + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // Need to implement this because DiskObjectWriter uses it to close the compression stream + try { + s.close(); + } catch (IOException e) { + Platform.throwException(e); + } + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + throw new UnsupportedOperationException(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 000000000000..0b8b604e1849 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. This is essentially identical to + * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. + *

+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it + * simultaneously opens separate serializers and file streams for all partitions. As a result, + * {@link SortShuffleManager} only selects this write path when + *

    + *
  • no Ordering is specified,
  • + *
  • no Aggregator is specific, and
  • + *
  • the number of partitions is less than + * spark.shuffle.sort.bypassMergeThreshold.
  • + *
+ * + * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was + * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details. + *

+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private DiskBlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (DiskBlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 000000000000..656ea0401a14 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java new file mode 100644 index 000000000000..4ee6a82c0423 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +/** + * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + *

+ * Within the long, the data is laid out as follows: + *

+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * 
+ * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that + * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the + * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. + *

+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this + * optimization to future work as it will require more careful design to ensure that addresses are + * properly aligned (e.g. by padding records). + */ +final class PackedRecordPointer { + + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + + /** + * The maximum partition identifier that can be encoded. Note that partition ids start from 0. + */ + static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + + /** Bit mask for the lower 40 bits of a long. */ + private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; + + /** Bit mask for the upper 24 bits of a long */ + private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; + + /** Bit mask for the lower 27 bits of a long. */ + private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Pack a record address and partition id into a single word. + * + * @param recordPointer a record pointer encoded by TaskMemoryManager. + * @param partitionId a shuffle partition id (maximum value of 2^24). + * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. + */ + public static long packPointer(long recordPointer, int partitionId) { + assert (partitionId <= MAXIMUM_PARTITION_ID); + // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. + // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); + return (((long) partitionId) << 40) | compressedAddress; + } + + private long packedRecordPointer; + + public void set(long packedRecordPointer) { + this.packedRecordPointer = packedRecordPointer; + } + + public int getPartitionId() { + return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); + } + + public long getRecordPointer() { + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; + return pageNumber | offsetInPage; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java new file mode 100644 index 000000000000..7bac0dc0bbeb --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** + * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + */ +final class SpillInfo { + final long[] partitionLengths; + final File file; + final TempShuffleBlockId blockId; + + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java new file mode 100644 index 000000000000..3d1ef0c48adc --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import scala.Tuple2; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * An external sorter that is specialized for sort-based shuffle. + *

+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *

+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + */ +final class UnsafeShuffleExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + private final int initialSize; + private final int numPartitions; + private final int pageSizeBytes; + @VisibleForTesting + final int maxRecordSizeBytes; + private final TaskMemoryManager taskMemoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + private final LinkedList spills = new LinkedList(); + + /** Peak memory used by this sorter so far, in bytes. **/ + private long peakMemoryUsedBytes; + + // These variables are reset after spilling: + @Nullable private UnsafeShuffleInMemorySorter inMemSorter; + @Nullable private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + public UnsafeShuffleExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetrics writeMetrics) throws IOException { + this.taskMemoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.initialSize = initialSize; + this.peakMemoryUsedBytes = initialSize; + this.numPartitions = numPartitions; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.pageSizeBytes = (int) Math.min( + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); + this.maxRecordSizeBytes = pageSizeBytes - 4; + this.writeMetrics = writeMetrics; + initializeForWriting(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); + } + + /** + * Sorts the in-memory records and writes the sorted records to an on-disk file. + * This method does not free the sort data structures. + * + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. + */ + private void writeSortedFile(boolean isLastFile) throws IOException { + + final ShuffleWriteMetrics writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + // This call performs the actual sort. + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + inMemSorter.getSortedIterator(); + + // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this + // after SPARK-5581 is fixed. + DiskBlockObjectWriter writer; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. This array does not need to be large enough to hold a single + // record; + final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + final SerializerInstance ser = DummySerializerInstance.INSTANCE; + + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + + int currentPartition = -1; + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + } + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + writeSortedFile(false); + final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + shuffleMemoryManager.release(inMemSorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + private long freeMemory() { + updatePeakMemoryUsed(); + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + taskMemoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupResources() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (inMemSorter != null) { + shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); + inMemSorter = null; + } + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. + */ + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + inMemSorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). + */ + private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { + growPointerArrayIfNecessary(); + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > pageSizeBytes) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + pageSizeBytes + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); + } + } + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = pageSizeBytes; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the shuffle sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + int partitionId) throws IOException { + + growPointerArrayIfNecessary(); + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). + acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; + } + final Object dataPageBaseObject = dataPage.getBaseObject(); + + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + dataPagePosition += 4; + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ + public SpillInfo[] closeAndGetSpills() throws IOException { + try { + if (inMemSorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java new file mode 100644 index 000000000000..5bab501da936 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; + +final class UnsafeShuffleInMemorySorter { + + private final Sorter sorter; + private static final class SortComparator implements Comparator { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + } + private static final SortComparator SORT_COMPARATOR = new SortComparator(); + + /** + * An array of record pointers and partition ids that have been encoded by + * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating + * records. + */ + private long[] pointerArray; + + /** + * The position in the pointer array where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeShuffleInMemorySorter(int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize]; + this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 1 < pointerArray.length; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + /** + * Inserts a record to be sorted. + * + * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to + * certain pointer compression techniques used by the sorter, the sort can + * only operate on pointers that point to locations in the first + * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page. + * @param partitionId the partition id, which must be less than or equal to + * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. + */ + public void insertRecord(long recordPointer, int partitionId) { + if (!hasSpaceForAnotherRecord()) { + if (pointerArray.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort pointer array has reached maximum size"); + } else { + expandPointerArray(); + } + } + pointerArray[pointerArrayInsertPosition] = + PackedRecordPointer.packPointer(recordPointer, partitionId); + pointerArrayInsertPosition++; + } + + /** + * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. + */ + public static final class UnsafeShuffleSorterIterator { + + private final long[] pointerArray; + private final int numRecords; + final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + private int position = 0; + + public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + this.numRecords = numRecords; + this.pointerArray = pointerArray; + } + + public boolean hasNext() { + return position < numRecords; + } + + public void loadNext() { + packedRecordPointer.set(pointerArray[position]); + position++; + } + } + + /** + * Return an iterator over record pointers in sorted order. + */ + public UnsafeShuffleSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); + return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java new file mode 100644 index 000000000000..a66d74ee4478 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.util.collection.SortDataFormat; + +final class UnsafeShuffleSortDataFormat extends SortDataFormat { + + public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + + private UnsafeShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { + reuse.set(data[pos]); + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + final long temp = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = temp; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos] = src[srcPos]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos, dst, dstPos, length); + } + + @Override + public long[] allocate(int length) { + return new long[length]; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java new file mode 100644 index 000000000000..fdb309e365f6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import javax.annotation.Nullable; +import java.io.*; +import java.nio.channels.FileChannel; +import java.util.Iterator; + +import scala.Option; +import scala.Product2; +import scala.collection.JavaConverters; +import scala.collection.immutable.Map; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.*; +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +@Private +public class UnsafeShuffleWriter extends ShuffleWriter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); + + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + @VisibleForTesting + static final int INITIAL_SORT_BUFFER_SIZE = 4096; + + private final BlockManager blockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final boolean transferToEnabled; + + @Nullable private MapStatus mapStatus; + @Nullable private UnsafeShuffleExternalSorter sorter; + private long peakMemoryUsedBytes = 0; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; + + public UnsafeShuffleWriter( + BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + UnsafeShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf sparkConf) throws IOException { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + throw new IllegalArgumentException( + "UnsafeShuffleWriter can only be used for shuffles with at most " + + UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + } + this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.partitioner = dep.partitioner(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + open(); + } + + @VisibleForTesting + public int maxRecordSizeBytes() { + assert(sorter != null); + return sorter.maxRecordSizeBytes; + } + + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConverters.asScalaIteratorConverter(records).asScala()); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we encountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (sorter != null) { + try { + sorter.cleanupResources(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. + if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } + } + } + } + + private void open() throws IOException { + assert (sorter == null); + sorter = new UnsafeShuffleExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + INITIAL_SORT_BUFFER_SIZE, + partitioner.numPartitions(), + sparkConf, + writeMetrics); + serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); + serBuffer = null; + serOutputStream = null; + final SpillInfo[] spills = sorter.closeAndGetSpills(); + sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } + + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = + !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + try { + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled) { + logger.debug("Using transferTo-based fast merge"); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } + } else { + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incShuffleBytesWritten(outputFile.length()); + return partitionLengths; + } + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); + } + throw e; + } + } + + /** + * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, + * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in + * cases where the IO compression codec does not support concatenation of compressed data, or in + * cases where users have explicitly disabled use of {@code transferTo} in order to work around + * kernel bugs. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + OutputStream mergedFileOutputStream = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new FileInputStream(spills[i].file); + } + for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = outputFile.length(); + mergedFileOutputStream = + new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + if (compressionCodec != null) { + mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + } + + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + } + } + mergedFileOutputStream.flush(); + mergedFileOutputStream.close(); + partitionLengths[partition] = (outputFile.length() - initialFileLength); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + Closeables.close(mergedFileOutputStream, threwException); + } + return partitionLengths; + } + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. + * This is only safe when the IO compression codec and serializer support concatenation of + * serialized streams. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. + mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + long bytesToTransfer = partitionLengthInSpill; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + while (bytesToTransfer > 0) { + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], + bytesToTransfer, + mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; + } + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; + } + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." + ); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); + Closeables.close(spillInputChannels[i], threwException); + } + Closeables.close(mergedFileOutputChannel, threwException); + } + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) + Map> internalAccumulators = + taskContext.internalMetricsToAccumulators(); + if (internalAccumulators != null) { + internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) + .add(getPeakMemoryUsedBytes()); + } + + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupResources(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/status/api/v1/ApplicationStatus.java b/core/src/main/java/org/apache/spark/status/api/v1/ApplicationStatus.java new file mode 100644 index 000000000000..8c7dcf776fda --- /dev/null +++ b/core/src/main/java/org/apache/spark/status/api/v1/ApplicationStatus.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1; + +import org.apache.spark.util.EnumUtil; + +public enum ApplicationStatus { + COMPLETED, + RUNNING; + + public static ApplicationStatus fromString(String str) { + return EnumUtil.parseIgnoreCase(ApplicationStatus.class, str); + } + +} diff --git a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java new file mode 100644 index 000000000000..9dbb565aab70 --- /dev/null +++ b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1; + +import org.apache.spark.util.EnumUtil; + +public enum StageStatus { + ACTIVE, + COMPLETE, + FAILED, + PENDING; + + public static StageStatus fromString(String str) { + return EnumUtil.parseIgnoreCase(StageStatus.class, str); + } +} diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java new file mode 100644 index 000000000000..f19ed01d5aeb --- /dev/null +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1; + +import org.apache.spark.util.EnumUtil; + +import java.util.HashSet; +import java.util.Set; + +public enum TaskSorting { + ID, + INCREASING_RUNTIME("runtime"), + DECREASING_RUNTIME("-runtime"); + + private final Set alternateNames; + private TaskSorting(String... names) { + alternateNames = new HashSet(); + for (String n: names) { + alternateNames.add(n); + } + } + + public static TaskSorting fromString(String str) { + String lower = str.toLowerCase(); + for (TaskSorting t: values()) { + if (t.alternateNames.contains(lower)) { + return t; + } + } + return EnumUtil.parseIgnoreCase(TaskSorting.class, str); + } + +} diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java new file mode 100644 index 000000000000..dc2aa30466cc --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; + +/** + * Intercepts write calls and tracks total time spent writing in order to update shuffle write + * metrics. Not thread safe. + */ +@Private +public final class TimeTrackingOutputStream extends OutputStream { + + private final ShuffleWriteMetrics writeMetrics; + private final OutputStream outputStream; + + public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + this.writeMetrics = writeMetrics; + this.outputStream = outputStream; + } + + @Override + public void write(int b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b, off, len); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void flush() throws IOException { + final long startTime = System.nanoTime(); + outputStream.flush(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void close() throws IOException { + final long startTime = System.nanoTime(); + outputStream.close(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java new file mode 100644 index 000000000000..b24eed3952fd --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -0,0 +1,832 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import javax.annotation.Nullable; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.bitset.BitSet; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * An append-only hash map where keys and values are contiguous regions of bytes. + * + * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, + * which is guaranteed to exhaust the space. + * + * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should + * probably be using sorting instead of hashing for better cache locality. + * + * The key and values under the hood are stored together, in the following format: + * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4 + * Bytes 4 to 8: len(k) + * Bytes 8 to 8 + len(k): key data + * Bytes 8 + len(k) to 8 + len(k) + len(v): value data + * + * This means that the first four bytes store the entire record (key + value) length. This format + * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, + * so we can pass records from this map directly into the sorter to sort records in place. + */ +public final class BytesToBytesMap { + + private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); + + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); + + private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + + /** + * Special record length that is placed after the last record in a data page. + */ + private static final int END_OF_PAGE_MARKER = -1; + + private final TaskMemoryManager taskMemoryManager; + + private final ShuffleMemoryManager shuffleMemoryManager; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List dataPages = new LinkedList(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. This does not incorporate the page's base offset. + */ + private long pageCursor = 0; + + /** + * The maximum number of keys that BytesToBytesMap supports. The hash table has to be + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, + * since that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). + */ + @VisibleForTesting + static final int MAX_CAPACITY = (1 << 29); + + // This choice of page table size and page size means that we can address up to 500 gigabytes + // of memory. + + /** + * A single array to store the key and value. + * + * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, + * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. + */ + @Nullable private LongArray longArray; + // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode + // and exploit word-alignment to use fewer bits to hold the address. This might let us store + // only one long per map entry, increasing the chance that this array will fit in cache at the + // expense of maybe performing more lookups if we have hash collisions. Say that we stored only + // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte + // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a + // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store + // full base addresses in the page table for off-heap mode so that we can reconstruct the full + // absolute memory addresses. + + /** + * A {@link BitSet} used to track location of the map where the key is set. + * Size of the bitset should be half of the size of the long array. + */ + @Nullable private BitSet bitset; + + private final double loadFactor; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private final long pageSizeBytes; + + /** + * Number of keys defined in the map. + */ + private int numElements; + + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ + private int growthThreshold; + + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + * This is a strength reduction optimization; we're essentially performing a modulus operation, + * but doing so with a bitmask because this is a power-of-2-sized hash map. + */ + private int mask; + + /** + * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. + */ + private final Location loc; + + private final boolean enablePerfMetrics; + + private long timeSpentResizingNs = 0; + + private long numProbes = 0; + + private long numKeyLookups = 0; + + private long numHashCollisions = 0; + + private long peakMemoryUsedBytes = 0L; + + public BytesToBytesMap( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + int initialCapacity, + double loadFactor, + long pageSizeBytes, + boolean enablePerfMetrics) { + this.taskMemoryManager = taskMemoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.loadFactor = loadFactor; + this.loc = new Location(); + this.pageSizeBytes = pageSizeBytes; + this.enablePerfMetrics = enablePerfMetrics; + if (initialCapacity <= 0) { + throw new IllegalArgumentException("Initial capacity must be greater than 0"); + } + if (initialCapacity > MAX_CAPACITY) { + throw new IllegalArgumentException( + "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); + } + if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); + } + allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); + } + + public BytesToBytesMap( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + int initialCapacity, + long pageSizeBytes) { + this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); + } + + public BytesToBytesMap( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + int initialCapacity, + long pageSizeBytes, + boolean enablePerfMetrics) { + this( + taskMemoryManager, + shuffleMemoryManager, + initialCapacity, + 0.70, + pageSizeBytes, + enablePerfMetrics); + } + + /** + * Returns the number of keys defined in the map. + */ + public int numElements() { return numElements; } + + public static final class BytesToBytesMapIterator implements Iterator { + + private final int numRecords; + private final Iterator dataPagesIterator; + private final Location loc; + + private MemoryBlock currentPage = null; + private int currentRecordNumber = 0; + private Object pageBaseObject; + private long offsetInPage; + + // If this iterator destructive or not. When it is true, it frees each page as it moves onto + // next one. + private boolean destructive = false; + private BytesToBytesMap bmap; + + private BytesToBytesMapIterator( + int numRecords, Iterator dataPagesIterator, Location loc, + boolean destructive, BytesToBytesMap bmap) { + this.numRecords = numRecords; + this.dataPagesIterator = dataPagesIterator; + this.loc = loc; + this.destructive = destructive; + this.bmap = bmap; + if (dataPagesIterator.hasNext()) { + advanceToNextPage(); + } + } + + private void advanceToNextPage() { + if (destructive && currentPage != null) { + dataPagesIterator.remove(); + this.bmap.taskMemoryManager.freePage(currentPage); + this.bmap.shuffleMemoryManager.release(currentPage.size()); + } + currentPage = dataPagesIterator.next(); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public Location next() { + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + if (totalLength == END_OF_PAGE_MARKER) { + advanceToNextPage(); + totalLength = Platform.getInt(pageBaseObject, offsetInPage); + } + loc.with(currentPage, offsetInPage); + offsetInPage += 4 + totalLength; + currentRecordNumber++; + return loc; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + + /** + * Returns an iterator for iterating over the entries of this map. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public BytesToBytesMapIterator iterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this); + } + + /** + * Returns a destructive iterator for iterating over the entries of this map. It frees each page + * as it moves onto next one. Notice: it is illegal to call any method on the map after + * `destructiveIterator()` has been called. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public BytesToBytesMapIterator destructiveIterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this); + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes) { + safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. + * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes, + Location loc) { + assert(bitset != null); + assert(longArray != null); + + if (enablePerfMetrics) { + numKeyLookups++; + } + final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + int pos = hashcode & mask; + int step = 1; + while (true) { + if (enablePerfMetrics) { + numProbes++; + } + if (!bitset.isSet(pos)) { + // This is a new key. + loc.with(pos, hashcode, false); + return; + } else { + long stored = longArray.get(pos * 2 + 1); + if ((int) (stored) == hashcode) { + // Full hash code matches. Let's compare the keys for equality. + loc.with(pos, hashcode, true); + if (loc.getKeyLength() == keyRowLengthBytes) { + final MemoryLocation keyAddress = loc.getKeyAddress(); + final Object storedKeyBaseObject = keyAddress.getBaseObject(); + final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final boolean areEqual = ByteArrayMethods.arrayEquals( + keyBaseObject, + keyBaseOffset, + storedKeyBaseObject, + storedKeyBaseOffset, + keyRowLengthBytes + ); + if (areEqual) { + return; + } else { + if (enablePerfMetrics) { + numHashCollisions++; + } + } + } + } + } + pos = (pos + step) & mask; + step++; + } + } + + /** + * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. + */ + public final class Location { + /** An index into the hash map's Long array */ + private int pos; + /** True if this location points to a position where a key is defined, false otherwise */ + private boolean isDefined; + /** + * The hashcode of the most recent key passed to + * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to + * avoid re-hashing the key when storing a value for that key. + */ + private int keyHashcode; + private final MemoryLocation keyMemoryLocation = new MemoryLocation(); + private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private int keyLength; + private int valueLength; + + /** + * Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}. + */ + @Nullable private MemoryBlock memoryPage; + + private void updateAddressesAndSizes(long fullKeyAddress) { + updateAddressesAndSizes( + taskMemoryManager.getPage(fullKeyAddress), + taskMemoryManager.getOffsetInPage(fullKeyAddress)); + } + + private void updateAddressesAndSizes(final Object page, final long offsetInPage) { + long position = offsetInPage; + final int totalLength = Platform.getInt(page, position); + position += 4; + keyLength = Platform.getInt(page, position); + position += 4; + valueLength = totalLength - keyLength - 4; + + keyMemoryLocation.setObjAndOffset(page, position); + + position += keyLength; + valueMemoryLocation.setObjAndOffset(page, position); + } + + private Location with(int pos, int keyHashcode, boolean isDefined) { + assert(longArray != null); + this.pos = pos; + this.isDefined = isDefined; + this.keyHashcode = keyHashcode; + if (isDefined) { + final long fullKeyAddress = longArray.get(pos * 2); + updateAddressesAndSizes(fullKeyAddress); + } + return this; + } + + private Location with(MemoryBlock page, long offsetInPage) { + this.isDefined = true; + this.memoryPage = page; + updateAddressesAndSizes(page.getBaseObject(), offsetInPage); + return this; + } + + /** + * Returns the memory page that contains the current record. + * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. + */ + public MemoryBlock getMemoryPage() { + return this.memoryPage; + } + + /** + * Returns true if the key is defined at this position, and false otherwise. + */ + public boolean isDefined() { + return isDefined; + } + + /** + * Returns the address of the key defined at this position. + * This points to the first byte of the key data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getKeyAddress() { + assert (isDefined); + return keyMemoryLocation; + } + + /** + * Returns the length of the key defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getKeyLength() { + assert (isDefined); + return keyLength; + } + + /** + * Returns the address of the value defined at this position. + * This points to the first byte of the value data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getValueAddress() { + assert (isDefined); + return valueMemoryLocation; + } + + /** + * Returns the length of the value defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getValueLength() { + assert (isDefined); + return valueLength; + } + + /** + * Store a new key and value. This method may only be called once for a given key; if you want + * to update the value associated with a key, then you can directly manipulate the bytes stored + * at the value address. The return value indicates whether the put succeeded or whether it + * failed because additional memory could not be acquired. + *

+ * It is only valid to call this method immediately after calling `lookup()` using the same key. + *

+ *

+ * The key and value must be word-aligned (that is, their sizes must multiples of 8). + *

+ *

+ * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` + * will return information on the data stored by this `putNewKey` call. + *

+ *

+ * As an example usage, here's the proper way to store a new key: + *

+ *
+     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   if (!loc.isDefined()) {
+     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *       // handle failure to grow map (by spilling, for example)
+     *     }
+     *   }
+     * 
+ *

+ * Unspecified behavior if the key is not defined. + *

+ * + * @return true if the put() was successful and false if the put() failed because memory could + * not be acquired. + */ + public boolean putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes) { + assert (!isDefined) : "Can only set value once for a key"; + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + assert(bitset != null); + assert(longArray != null); + + if (numElements == MAX_CAPACITY) { + throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + } + + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. + // (8 byte key length) (key) (value) + final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; + + // --- Figure out where to insert the new record --------------------------------------------- + + final MemoryBlock dataPage; + final Object dataPageBaseObject; + final long dataPageInsertOffset; + boolean useOverflowPage = requiredSize > pageSizeBytes - 8; + if (useOverflowPage) { + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryRequested = requiredSize + 8; + final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryGranted != memoryRequested) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", memoryRequested); + return false; + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested); + dataPages.add(overflowPage); + dataPage = overflowPage; + dataPageBaseObject = overflowPage.getBaseObject(); + dataPageInsertOffset = overflowPage.getBaseOffset(); + } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { + // The record can fit in a data page, but either we have not allocated any pages yet or + // the current page does not have enough space. + if (currentDataPage != null) { + // There wasn't enough space in the current page, so write an end-of-page marker: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; + Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } + if (!acquireNewPage()) { + return false; + } + dataPage = currentDataPage; + dataPageBaseObject = currentDataPage.getBaseObject(); + dataPageInsertOffset = currentDataPage.getBaseOffset(); + } else { + // There is enough space in the current data page. + dataPage = currentDataPage; + dataPageBaseObject = currentDataPage.getBaseObject(); + dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor; + } + + // --- Append the key and value data to the current data page -------------------------------- + + long insertCursor = dataPageInsertOffset; + + // Compute all of our offsets up-front: + final long recordOffset = insertCursor; + insertCursor += 4; + final long keyLengthOffset = insertCursor; + insertCursor += 4; + final long keyDataOffsetInPage = insertCursor; + insertCursor += keyLengthBytes; + final long valueDataOffsetInPage = insertCursor; + insertCursor += valueLengthBytes; // word used to store the value size + + Platform.putInt(dataPageBaseObject, recordOffset, + keyLengthBytes + valueLengthBytes + 4); + Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); + // Copy the key + Platform.copyMemory( + keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + valueDataOffsetInPage, valueLengthBytes); + + // --- Update bookeeping data structures ----------------------------------------------------- + + if (useOverflowPage) { + // Store the end-of-page marker at the end of the data page + Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + } else { + pageCursor += requiredSize; + } + + numElements++; + bitset.set(pos); + final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( + dataPage, recordOffset); + longArray.set(pos * 2, storedKeyAddress); + longArray.set(pos * 2 + 1, keyHashcode); + updateAddressesAndSizes(storedKeyAddress); + isDefined = true; + if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) { + growAndRehash(); + } + return true; + } + } + + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. + */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; + } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; + } + + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ + private void allocate(int capacity) { + assert (capacity >= 0); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); + assert (capacity <= MAX_CAPACITY); + longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); + bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); + + this.growthThreshold = (int) (capacity * loadFactor); + this.mask = capacity - 1; + } + + /** + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent and can be called multiple times. + */ + public void free() { + updatePeakMemoryUsed(); + longArray = null; + bitset = null; + Iterator dataPagesIterator = dataPages.iterator(); + while (dataPagesIterator.hasNext()) { + MemoryBlock dataPage = dataPagesIterator.next(); + dataPagesIterator.remove(); + taskMemoryManager.freePage(dataPage); + shuffleMemoryManager.release(dataPage.size()); + } + assert(dataPages.isEmpty()); + } + + public TaskMemoryManager getTaskMemoryManager() { + return taskMemoryManager; + } + + public ShuffleMemoryManager getShuffleMemoryManager() { + return shuffleMemoryManager; + } + + public long getPageSizeBytes() { + return pageSizeBytes; + } + + /** + * Returns the total amount of memory, in bytes, consumed by this map's managed structures. + */ + public long getTotalMemoryConsumption() { + long totalDataPagesSize = 0L; + for (MemoryBlock dataPage : dataPages) { + totalDataPagesSize += dataPage.size(); + } + return totalDataPagesSize + + ((bitset != null) ? bitset.memoryBlock().size() : 0L) + + ((longArray != null) ? longArray.memoryBlock().size() : 0L); + } + + private void updatePeakMemoryUsed() { + long mem = getTotalMemoryConsumption(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + /** + * Returns the total amount of time spent resizing this map (in nanoseconds). + */ + public long getTimeSpentResizingNs() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return timeSpentResizingNs; + } + + /** + * Returns the average number of probes per key lookup. + */ + public double getAverageProbesPerLookup() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return (1.0 * numProbes) / numKeyLookups; + } + + public long getNumHashCollisions() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return numHashCollisions; + } + + @VisibleForTesting + public int getNumDataPages() { + return dataPages.size(); + } + + /** + * Grows the size of the hash table and re-hash everything. + */ + @VisibleForTesting + void growAndRehash() { + assert(bitset != null); + assert(longArray != null); + + long resizeStartTime = -1; + if (enablePerfMetrics) { + resizeStartTime = System.nanoTime(); + } + // Store references to the old data structures to be used when we re-hash + final LongArray oldLongArray = longArray; + final BitSet oldBitSet = bitset; + final int oldCapacity = (int) oldBitSet.capacity(); + + // Allocate the new data structures + allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); + + // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) + for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { + final long keyPointer = oldLongArray.get(pos * 2); + final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + int newPos = hashcode & mask; + int step = 1; + boolean keepGoing = true; + + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!bitset.isSet(newPos)) { + bitset.set(newPos); + longArray.set(newPos * 2, keyPointer); + longArray.set(newPos * 2 + 1, hashcode); + keepGoing = false; + } else { + newPos = (newPos + step) & mask; + step++; + } + } + } + + if (enablePerfMetrics) { + timeSpentResizingNs += System.nanoTime() - resizeStartTime; + } + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java new file mode 100644 index 000000000000..20654e4eeaa0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +/** + * Interface that defines how we can grow the size of a hash map when it is over a threshold. + */ +public interface HashMapGrowthStrategy { + + int nextCapacity(int currentCapacity); + + /** + * Double the size of the hash map every time. + */ + HashMapGrowthStrategy DOUBLING = new Doubling(); + + class Doubling implements HashMapGrowthStrategy { + @Override + public int nextCapacity(int currentCapacity) { + assert (currentCapacity > 0); + // Guard against overflow + return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE; + } + } + +} diff --git a/core/src/main/java/org/apache/spark/util/EnumUtil.java b/core/src/main/java/org/apache/spark/util/EnumUtil.java new file mode 100644 index 000000000000..c40c7e727613 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/EnumUtil.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util; + +import com.google.common.base.Joiner; +import org.apache.spark.annotation.Private; + +@Private +public class EnumUtil { + public static > E parseIgnoreCase(Class clz, String str) { + E[] constants = clz.getEnumConstants(); + if (str == null) { + return null; + } + for (E e : constants) { + if (e.name().equalsIgnoreCase(str)) { + return e; + } + } + throw new IllegalArgumentException( + String.format("Illegal type='%s'. Supported type values: %s", + str, Joiner.on(", ").join(constants))); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 409e1a41c5d4..a90cc0e761f6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) { private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { + if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) + || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { if (runLen[n - 1] < runLen[n + 1]) n--; - mergeAt(n); - } else if (runLen[n] <= runLen[n + 1]) { - mergeAt(n); - } else { + } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } + mergeAt(n); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 000000000000..45b78829e4cf --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +@Private +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 000000000000..71b76d5ddfaa --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.primitives.UnsignedLongs; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.Utils; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final StringPrefixComparator STRING = new StringPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator(); + public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); + + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + + public static long computePrefix(UTF8String value) { + return value == null ? 0L : value.getPrefix(); + } + } + + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class BinaryPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + + public static long computePrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + /** + * TODO: If a wrapper for BinaryType is created (SPARK-8786), + * these codes below will be in the wrapper class. + */ + final int minLen = Math.min(bytes.length, 8); + long p = 0; + for (int i = 0; i < minLen; ++i) { + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) + << (56 - 8 * i); + } + return p; + } + } + } + + public static final class BinaryPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } + + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); + } + + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } + + public static final class DoublePrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); + } + + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java new file mode 100644 index 000000000000..09e425879220 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 000000000000..0c4ebde407cf --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 000000000000..fc364e0a895b --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import javax.annotation.Nullable; + +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. + */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private final long pageSizeBytes; + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private final TaskMemoryManager taskMemoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList<>(); + + private final LinkedList spillWriters = new LinkedList<>(); + + // These variables are reset after spilling: + @Nullable private UnsafeInMemorySorter inMemSorter; + // Whether the in-mem sorter is created internally, or passed in from outside. + // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. + private boolean isInMemSorterExternal = false; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + private long peakMemoryUsedBytes = 0; + + public static UnsafeExternalSorter createWithExistingInMemorySorter( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + UnsafeInMemorySorter inMemorySorter) throws IOException { + return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + } + + public static UnsafeExternalSorter create( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes) throws IOException { + return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); + } + + private UnsafeExternalSorter( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException { + this.taskMemoryManager = taskMemoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.fileBufferSizeBytes = 32 * 1024; + this.pageSizeBytes = pageSizeBytes; + this.writeMetrics = new ShuffleWriteMetrics(); + + if (existingInMemorySorter == null) { + initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); + } else { + this.isInMemSorterExternal = true; + this.inMemSorter = existingInMemorySorter; + } + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + cleanupResources(); + return null; + } + }); + } + + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + final long pointerArrayMemory = + UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory); + if (memoryAcquired != pointerArrayMemory) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory"); + } + + this.inMemSorter = + new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); + this.isInMemSorterExternal = false; + } + + /** + * Marks the current page as no-more-space-available, and as a result, either allocate a + * new page or spill when we see the next record. + */ + @VisibleForTesting + public void closeCurrentPage() { + freeSpaceInCurrentPage = 0; + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + public void spill() throws IOException { + assert(inMemSorter != null); + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + // We only write out contents of the inMemSorter if it is not empty. + if (inMemSorter.numRecords() > 0) { + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + inMemSorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + } + + final long spillSize = freeMemory(); + // Note that this is more-or-less going to be a multiple of the page size, so wasted space in + // pages will currently be counted as memory spilled even though that space isn't actually + // written to disk. This also counts the space needed to store the sorter's pointer array. + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + /** + * Return the total memory usage of this sorter, including the data pages and the sorter's pointer + * array. + */ + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + + /** + * Free this sorter's in-memory data structures, including its data pages and pointer array. + * + * @return the number of bytes freed. + */ + private long freeMemory() { + updatePeakMemoryUsed(); + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + taskMemoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + if (inMemSorter != null) { + if (!isInMemSorterExternal) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + memoryFreed += sorterMemoryUsage; + shuffleMemoryManager.release(sorterMemoryUsage); + } + inMemSorter = null; + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Deletes any spill files created by this sorter. + */ + private void deleteSpillFiles() { + for (UnsafeSorterSpillWriter spill : spillWriters) { + File file = spill.getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + + /** + * Frees this sorter's in-memory data structures and cleans up its spill files. + */ + public void cleanupResources() { + deleteSpillFiles(); + freeMemory(); + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. + */ + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + inMemSorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). + */ + private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { + assert (requiredSpace <= pageSizeBytes); + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > pageSizeBytes) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + pageSizeBytes + ")"); + } else { + acquireNewPage(); + } + } + } + + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * + * If there is not enough space to allocate the new page, spill all existing ones + * and try again. If there is still not enough space, report error to the caller. + */ + private void acquireNewPage() throws IOException { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); + } + } + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = pageSizeBytes; + allocatedPages.add(currentPage); + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws IOException { + + growPointerArrayIfNecessary(); + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). + acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; + } + final Object dataPageBaseObject = dataPage.getBaseObject(); + + // --- Insert the record ---------------------------------------------------------------------- + + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + dataPagePosition += 4; + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, prefix); + } + + /** + * Write a key-value record to the sorter. The key and value will be put together in-memory, + * using the following format: + * + * record length (4 bytes), key length (4 bytes), key data, value data + * + * record length = key length + value length + 4 + */ + public void insertKVRecord( + Object keyBaseObj, long keyOffset, int keyLen, + Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException { + + growPointerArrayIfNecessary(); + final int totalSpaceRequired = keyLen + valueLen + 4 + 4; + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). + acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; + } + final Object dataPageBaseObject = dataPage.getBaseObject(); + + // --- Insert the record ---------------------------------------------------------------------- + + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); + dataPagePosition += 4; + + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); + dataPagePosition += 4; + + Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); + dataPagePosition += keyLen; + + Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, prefix); + } + + /** + * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(inMemSorter != null); + final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + spillMerger.addSpillIfNotEmpty(inMemoryIterator); + + return spillMerger.getSortedIterator(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 000000000000..f7787e1019c2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] pointerArray; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + static long getMemoryRequirementsForPointerArray(long numEntries) { + return numEntries * 2L * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 2 < pointerArray.length; + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandPointerArray(); + } + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; + } + + public static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + private SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length + recordLength = Platform.getInt(baseObject, baseOffset - 4); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public SortedIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 000000000000..d09c728a7a63 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. + * Used in {@link UnsafeInMemorySorter}. + *

+ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 000000000000..16ac2e8d821b --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 000000000000..3874a9f9cbdb --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + final int numSpills) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue(numSpills, comparator); + } + + /** + * Add an UnsafeSorterIterator to this merger + */ + public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + // We only add the spillReader to the priorityQueue if it is not empty. We do this to + // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator + // does not return wrong result because hasNext will returns true + // at least priorityQueue.size() times. If we allow n spillReaders in the + // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator. + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 000000000000..4989b05d63e2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.Platform; + +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). + */ +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private final File file; + private InputStream in; + private DataInputStream din; + + // Variables that change with every record read: + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + this.file = file; + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public void loadNext() throws IOException { + recordLength = din.readInt(); + keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + file.delete(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 000000000000..e59a84ff8d11 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.Platform; + +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] + */ +final class UnsafeSorterSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private DiskBlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeLong. + private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + * @param keyPrefix a sort key prefix + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + Platform.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + writer.recordWritten(); + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public File getFile() { + return file; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties new file mode 100644 index 000000000000..689afea64f8d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -0,0 +1,16 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f6..27006e45e932 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -6,7 +6,11 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 14ba37d7c9bd..3c8ddddf07b1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -19,6 +19,9 @@ * to be registered after the page loads. */ $(function() { $("span.expand-additional-metrics").click(function(){ + var status = window.localStorage.getItem("expand-additional-metrics") == "true"; + status = !status; + // Expand the list of additional metrics. var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); $(additionalMetricsDiv).toggleClass('collapsed'); @@ -26,28 +29,62 @@ $(function() { // Switch the class of the arrow from open to closed. $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-additional-metrics", "" + status); }); + if (window.localStorage.getItem("expand-additional-metrics") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-additional-metrics", "false"); + $("span.expand-additional-metrics").trigger("click"); + } + stripeSummaryTable(); - $("input:checkbox").click(function() { - var column = "table ." + $(this).attr("name"); + $('input[type="checkbox"]').click(function() { + var name = $(this).attr("name") + var column = "table ." + name; + var status = window.localStorage.getItem(name) == "true"; + status = !status; $(column).toggle(); stripeSummaryTable(); + window.localStorage.setItem(name, "" + status); }); $("#select-all-metrics").click(function() { + var status = window.localStorage.getItem("select-all-metrics") == "true"; + status = !status; if (this.checked) { // Toggle all un-checked options. - $('input:checkbox:not(:checked)').trigger('click'); + $('input[type="checkbox"]:not(:checked)').trigger('click'); } else { // Toggle all checked options. - $('input:checkbox:checked').trigger('click'); + $('input[type="checkbox"]:checked').trigger('click'); } + window.localStorage.setItem("select-all-metrics", "" + status); + }); + + if (window.localStorage.getItem("select-all-metrics") == "true") { + $("#select-all-metrics").attr('checked', status); + } + + $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() { + var name = $(this).attr("name") + // If name is undefined, then skip it because it's the "select-all-metrics" checkbox + if (name && window.localStorage.getItem(name) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(name, "false"); + $(this).trigger("click") + } }); // Trigger a click on the checkbox if a user clicks the label next to it. $("span.additional-metric-title").click(function() { - $(this).parent().find('input:checkbox').trigger('click'); + $(this).parent().find('input[type="checkbox"]').trigger('click'); + }); + + // Trigger a double click on the span to show full job description. + $(".description-input").dblclick(function() { + $(this).removeClass("description-input").addClass("description-input-full"); }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js b/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js index 2934181c1006..acd6096e6743 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js +++ b/core/src/main/resources/org/apache/spark/ui/static/bootstrap-tooltip.js @@ -1,9 +1,9 @@ /* =========================================================== - * bootstrap-tooltip.js v2.2.2 - * http://twitter.github.com/bootstrap/javascript.html#tooltips + * bootstrap-tooltip.js v2.3.2 + * http://getbootstrap.com/2.3.2/javascript.html#tooltips * Inspired by the original jQuery.tipsy by Jason Frame * =========================================================== - * Copyright 2012 Twitter, Inc. + * Copyright 2013 Twitter, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,19 +38,27 @@ , init: function (type, element, options) { var eventIn , eventOut + , triggers + , trigger + , i this.type = type this.$element = $(element) this.options = this.getOptions(options) this.enabled = true - if (this.options.trigger == 'click') { - this.$element.on('click.' + this.type, this.options.selector, $.proxy(this.toggle, this)) - } else if (this.options.trigger != 'manual') { - eventIn = this.options.trigger == 'hover' ? 'mouseenter' : 'focus' - eventOut = this.options.trigger == 'hover' ? 'mouseleave' : 'blur' - this.$element.on(eventIn + '.' + this.type, this.options.selector, $.proxy(this.enter, this)) - this.$element.on(eventOut + '.' + this.type, this.options.selector, $.proxy(this.leave, this)) + triggers = this.options.trigger.split(' ') + + for (i = triggers.length; i--;) { + trigger = triggers[i] + if (trigger == 'click') { + this.$element.on('click.' + this.type, this.options.selector, $.proxy(this.toggle, this)) + } else if (trigger != 'manual') { + eventIn = trigger == 'hover' ? 'mouseenter' : 'focus' + eventOut = trigger == 'hover' ? 'mouseleave' : 'blur' + this.$element.on(eventIn + '.' + this.type, this.options.selector, $.proxy(this.enter, this)) + this.$element.on(eventOut + '.' + this.type, this.options.selector, $.proxy(this.leave, this)) + } } this.options.selector ? @@ -59,7 +67,7 @@ } , getOptions: function (options) { - options = $.extend({}, $.fn[this.type].defaults, options, this.$element.data()) + options = $.extend({}, $.fn[this.type].defaults, this.$element.data(), options) if (options.delay && typeof options.delay == 'number') { options.delay = { @@ -72,7 +80,15 @@ } , enter: function (e) { - var self = $(e.currentTarget)[this.type](this._options).data(this.type) + var defaults = $.fn[this.type].defaults + , options = {} + , self + + this._options && $.each(this._options, function (key, value) { + if (defaults[key] != value) options[key] = value + }, this) + + self = $(e.currentTarget)[this.type](options).data(this.type) if (!self.options.delay || !self.options.delay.show) return self.show() @@ -97,14 +113,16 @@ , show: function () { var $tip - , inside , pos , actualWidth , actualHeight , placement , tp + , e = $.Event('show') if (this.hasContent() && this.enabled) { + this.$element.trigger(e) + if (e.isDefaultPrevented()) return $tip = this.tip() this.setContent() @@ -116,19 +134,18 @@ this.options.placement.call(this, $tip[0], this.$element[0]) : this.options.placement - inside = /in/.test(placement) - $tip .detach() .css({ top: 0, left: 0, display: 'block' }) - .insertAfter(this.$element) - pos = this.getPosition(inside) + this.options.container ? $tip.appendTo(this.options.container) : $tip.insertAfter(this.$element) + + pos = this.getPosition() actualWidth = $tip[0].offsetWidth actualHeight = $tip[0].offsetHeight - switch (inside ? placement.split(' ')[1] : placement) { + switch (placement) { case 'bottom': tp = {top: pos.top + pos.height, left: pos.left + pos.width / 2 - actualWidth / 2} break @@ -143,11 +160,56 @@ break } - $tip - .offset(tp) - .addClass(placement) - .addClass('in') + this.applyPlacement(tp, placement) + this.$element.trigger('shown') + } + } + + , applyPlacement: function(offset, placement){ + var $tip = this.tip() + , width = $tip[0].offsetWidth + , height = $tip[0].offsetHeight + , actualWidth + , actualHeight + , delta + , replace + + $tip + .offset(offset) + .addClass(placement) + .addClass('in') + + actualWidth = $tip[0].offsetWidth + actualHeight = $tip[0].offsetHeight + + if (placement == 'top' && actualHeight != height) { + offset.top = offset.top + height - actualHeight + replace = true + } + + if (placement == 'bottom' || placement == 'top') { + delta = 0 + + if (offset.left < 0){ + delta = offset.left * -2 + offset.left = 0 + $tip.offset(offset) + actualWidth = $tip[0].offsetWidth + actualHeight = $tip[0].offsetHeight + } + + this.replaceArrow(delta - width + actualWidth, actualWidth, 'left') + } else { + this.replaceArrow(actualHeight - height, actualHeight, 'top') } + + if (replace) $tip.offset(offset) + } + + , replaceArrow: function(delta, dimension, position){ + this + .arrow() + .css(position, delta ? (50 * (1 - delta / dimension) + "%") : '') } , setContent: function () { @@ -161,6 +223,10 @@ , hide: function () { var that = this , $tip = this.tip() + , e = $.Event('hide') + + this.$element.trigger(e) + if (e.isDefaultPrevented()) return $tip.removeClass('in') @@ -179,6 +245,8 @@ removeWithAnimation() : $tip.detach() + this.$element.trigger('hidden') + return this } @@ -193,11 +261,12 @@ return this.getTitle() } - , getPosition: function (inside) { - return $.extend({}, (inside ? {top: 0, left: 0} : this.$element.offset()), { - width: this.$element[0].offsetWidth - , height: this.$element[0].offsetHeight - }) + , getPosition: function () { + var el = this.$element[0] + return $.extend({}, (typeof el.getBoundingClientRect == 'function') ? el.getBoundingClientRect() : { + width: el.offsetWidth + , height: el.offsetHeight + }, this.$element.offset()) } , getTitle: function () { @@ -215,6 +284,10 @@ return this.$tip = this.$tip || $(this.options.template) } + , arrow: function(){ + return this.$arrow = this.$arrow || this.tip().find(".tooltip-arrow") + } + , validate: function () { if (!this.$element[0].parentNode) { this.hide() @@ -236,8 +309,8 @@ } , toggle: function (e) { - var self = $(e.currentTarget)[this.type](this._options).data(this.type) - self[self.tip().hasClass('in') ? 'hide' : 'show']() + var self = e ? $(e.currentTarget)[this.type](this._options).data(this.type) : this + self.tip().hasClass('in') ? self.hide() : self.show() } , destroy: function () { @@ -269,10 +342,11 @@ , placement: 'top' , selector: false , template: '

' - , trigger: 'hover' + , trigger: 'hover focus' , title: '' , delay: 0 , html: false + , container: false } @@ -285,4 +359,3 @@ } }(window.jQuery); - diff --git a/core/src/main/resources/org/apache/spark/ui/static/d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/d3.min.js new file mode 100644 index 000000000000..30cd292198b9 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/d3.min.js @@ -0,0 +1,5 @@ +/*v3.5.5*/!function(){function n(n){return n&&(n.ownerDocument||n.document||n).documentElement}function t(n){return n&&(n.ownerDocument&&n.ownerDocument.defaultView||n.document&&n||n.defaultView)}function e(n,t){return t>n?-1:n>t?1:n>=t?0:0/0}function r(n){return null===n?0/0:+n}function u(n){return!isNaN(n)}function i(n){return{left:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n(t[i],e)<0?r=i+1:u=i}return r},right:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n(t[i],e)>0?u=i:r=i+1}return r}}}function o(n){return n.length}function a(n){for(var t=1;n*t%1;)t*=10;return t}function c(n,t){for(var e in t)Object.defineProperty(n.prototype,e,{value:t[e],enumerable:!1})}function l(){this._=Object.create(null)}function s(n){return(n+="")===pa||n[0]===va?va+n:n}function f(n){return(n+="")[0]===va?n.slice(1):n}function h(n){return s(n)in this._}function g(n){return(n=s(n))in this._&&delete this._[n]}function p(){var n=[];for(var t in this._)n.push(f(t));return n}function v(){var n=0;for(var t in this._)++n;return n}function d(){for(var n in this._)return!1;return!0}function m(){this._=Object.create(null)}function y(n){return n}function M(n,t,e){return function(){var r=e.apply(t,arguments);return r===t?n:r}}function x(n,t){if(t in n)return t;t=t.charAt(0).toUpperCase()+t.slice(1);for(var e=0,r=da.length;r>e;++e){var u=da[e]+t;if(u in n)return u}}function b(){}function _(){}function w(n){function t(){for(var t,r=e,u=-1,i=r.length;++ue;e++)for(var u,i=n[e],o=0,a=i.length;a>o;o++)(u=i[o])&&t(u,o,e);return n}function Z(n){return ya(n,Sa),n}function V(n){var t,e;return function(r,u,i){var o,a=n[i].update,c=a.length;for(i!=e&&(e=i,t=0),u>=t&&(t=u+1);!(o=a[t])&&++t0&&(n=n.slice(0,a));var l=ka.get(n);return l&&(n=l,c=B),a?t?u:r:t?b:i}function $(n,t){return function(e){var r=ta.event;ta.event=e,t[0]=this.__data__;try{n.apply(this,t)}finally{ta.event=r}}}function B(n,t){var e=$(n,t);return function(n){var t=this,r=n.relatedTarget;r&&(r===t||8&r.compareDocumentPosition(t))||e.call(t,n)}}function W(e){var r=".dragsuppress-"+ ++Aa,u="click"+r,i=ta.select(t(e)).on("touchmove"+r,S).on("dragstart"+r,S).on("selectstart"+r,S);if(null==Ea&&(Ea="onselectstart"in e?!1:x(e.style,"userSelect")),Ea){var o=n(e).style,a=o[Ea];o[Ea]="none"}return function(n){if(i.on(r,null),Ea&&(o[Ea]=a),n){var t=function(){i.on(u,null)};i.on(u,function(){S(),t()},!0),setTimeout(t,0)}}}function J(n,e){e.changedTouches&&(e=e.changedTouches[0]);var r=n.ownerSVGElement||n;if(r.createSVGPoint){var u=r.createSVGPoint();if(0>Na){var i=t(n);if(i.scrollX||i.scrollY){r=ta.select("body").append("svg").style({position:"absolute",top:0,left:0,margin:0,padding:0,border:"none"},"important");var o=r[0][0].getScreenCTM();Na=!(o.f||o.e),r.remove()}}return Na?(u.x=e.pageX,u.y=e.pageY):(u.x=e.clientX,u.y=e.clientY),u=u.matrixTransform(n.getScreenCTM().inverse()),[u.x,u.y]}var a=n.getBoundingClientRect();return[e.clientX-a.left-n.clientLeft,e.clientY-a.top-n.clientTop]}function G(){return ta.event.changedTouches[0].identifier}function K(n){return n>0?1:0>n?-1:0}function Q(n,t,e){return(t[0]-n[0])*(e[1]-n[1])-(t[1]-n[1])*(e[0]-n[0])}function nt(n){return n>1?0:-1>n?qa:Math.acos(n)}function tt(n){return n>1?Ra:-1>n?-Ra:Math.asin(n)}function et(n){return((n=Math.exp(n))-1/n)/2}function rt(n){return((n=Math.exp(n))+1/n)/2}function ut(n){return((n=Math.exp(2*n))-1)/(n+1)}function it(n){return(n=Math.sin(n/2))*n}function ot(){}function at(n,t,e){return this instanceof at?(this.h=+n,this.s=+t,void(this.l=+e)):arguments.length<2?n instanceof at?new at(n.h,n.s,n.l):bt(""+n,_t,at):new at(n,t,e)}function ct(n,t,e){function r(n){return n>360?n-=360:0>n&&(n+=360),60>n?i+(o-i)*n/60:180>n?o:240>n?i+(o-i)*(240-n)/60:i}function u(n){return Math.round(255*r(n))}var i,o;return n=isNaN(n)?0:(n%=360)<0?n+360:n,t=isNaN(t)?0:0>t?0:t>1?1:t,e=0>e?0:e>1?1:e,o=.5>=e?e*(1+t):e+t-e*t,i=2*e-o,new mt(u(n+120),u(n),u(n-120))}function lt(n,t,e){return this instanceof lt?(this.h=+n,this.c=+t,void(this.l=+e)):arguments.length<2?n instanceof lt?new lt(n.h,n.c,n.l):n instanceof ft?gt(n.l,n.a,n.b):gt((n=wt((n=ta.rgb(n)).r,n.g,n.b)).l,n.a,n.b):new lt(n,t,e)}function st(n,t,e){return isNaN(n)&&(n=0),isNaN(t)&&(t=0),new ft(e,Math.cos(n*=Da)*t,Math.sin(n)*t)}function ft(n,t,e){return this instanceof ft?(this.l=+n,this.a=+t,void(this.b=+e)):arguments.length<2?n instanceof ft?new ft(n.l,n.a,n.b):n instanceof lt?st(n.h,n.c,n.l):wt((n=mt(n)).r,n.g,n.b):new ft(n,t,e)}function ht(n,t,e){var r=(n+16)/116,u=r+t/500,i=r-e/200;return u=pt(u)*Xa,r=pt(r)*$a,i=pt(i)*Ba,new mt(dt(3.2404542*u-1.5371385*r-.4985314*i),dt(-.969266*u+1.8760108*r+.041556*i),dt(.0556434*u-.2040259*r+1.0572252*i))}function gt(n,t,e){return n>0?new lt(Math.atan2(e,t)*Pa,Math.sqrt(t*t+e*e),n):new lt(0/0,0/0,n)}function pt(n){return n>.206893034?n*n*n:(n-4/29)/7.787037}function vt(n){return n>.008856?Math.pow(n,1/3):7.787037*n+4/29}function dt(n){return Math.round(255*(.00304>=n?12.92*n:1.055*Math.pow(n,1/2.4)-.055))}function mt(n,t,e){return this instanceof mt?(this.r=~~n,this.g=~~t,void(this.b=~~e)):arguments.length<2?n instanceof mt?new mt(n.r,n.g,n.b):bt(""+n,mt,ct):new mt(n,t,e)}function yt(n){return new mt(n>>16,n>>8&255,255&n)}function Mt(n){return yt(n)+""}function xt(n){return 16>n?"0"+Math.max(0,n).toString(16):Math.min(255,n).toString(16)}function bt(n,t,e){var r,u,i,o=0,a=0,c=0;if(r=/([a-z]+)\((.*)\)/i.exec(n))switch(u=r[2].split(","),r[1]){case"hsl":return e(parseFloat(u[0]),parseFloat(u[1])/100,parseFloat(u[2])/100);case"rgb":return t(kt(u[0]),kt(u[1]),kt(u[2]))}return(i=Ga.get(n.toLowerCase()))?t(i.r,i.g,i.b):(null==n||"#"!==n.charAt(0)||isNaN(i=parseInt(n.slice(1),16))||(4===n.length?(o=(3840&i)>>4,o=o>>4|o,a=240&i,a=a>>4|a,c=15&i,c=c<<4|c):7===n.length&&(o=(16711680&i)>>16,a=(65280&i)>>8,c=255&i)),t(o,a,c))}function _t(n,t,e){var r,u,i=Math.min(n/=255,t/=255,e/=255),o=Math.max(n,t,e),a=o-i,c=(o+i)/2;return a?(u=.5>c?a/(o+i):a/(2-o-i),r=n==o?(t-e)/a+(e>t?6:0):t==o?(e-n)/a+2:(n-t)/a+4,r*=60):(r=0/0,u=c>0&&1>c?0:r),new at(r,u,c)}function wt(n,t,e){n=St(n),t=St(t),e=St(e);var r=vt((.4124564*n+.3575761*t+.1804375*e)/Xa),u=vt((.2126729*n+.7151522*t+.072175*e)/$a),i=vt((.0193339*n+.119192*t+.9503041*e)/Ba);return ft(116*u-16,500*(r-u),200*(u-i))}function St(n){return(n/=255)<=.04045?n/12.92:Math.pow((n+.055)/1.055,2.4)}function kt(n){var t=parseFloat(n);return"%"===n.charAt(n.length-1)?Math.round(2.55*t):t}function Et(n){return"function"==typeof n?n:function(){return n}}function At(n){return function(t,e,r){return 2===arguments.length&&"function"==typeof e&&(r=e,e=null),Nt(t,e,n,r)}}function Nt(n,t,e,r){function u(){var n,t=c.status;if(!t&&zt(c)||t>=200&&300>t||304===t){try{n=e.call(i,c)}catch(r){return void o.error.call(i,r)}o.load.call(i,n)}else o.error.call(i,c)}var i={},o=ta.dispatch("beforesend","progress","load","error"),a={},c=new XMLHttpRequest,l=null;return!this.XDomainRequest||"withCredentials"in c||!/^(http(s)?:)?\/\//.test(n)||(c=new XDomainRequest),"onload"in c?c.onload=c.onerror=u:c.onreadystatechange=function(){c.readyState>3&&u()},c.onprogress=function(n){var t=ta.event;ta.event=n;try{o.progress.call(i,c)}finally{ta.event=t}},i.header=function(n,t){return n=(n+"").toLowerCase(),arguments.length<2?a[n]:(null==t?delete a[n]:a[n]=t+"",i)},i.mimeType=function(n){return arguments.length?(t=null==n?null:n+"",i):t},i.responseType=function(n){return arguments.length?(l=n,i):l},i.response=function(n){return e=n,i},["get","post"].forEach(function(n){i[n]=function(){return i.send.apply(i,[n].concat(ra(arguments)))}}),i.send=function(e,r,u){if(2===arguments.length&&"function"==typeof r&&(u=r,r=null),c.open(e,n,!0),null==t||"accept"in a||(a.accept=t+",*/*"),c.setRequestHeader)for(var s in a)c.setRequestHeader(s,a[s]);return null!=t&&c.overrideMimeType&&c.overrideMimeType(t),null!=l&&(c.responseType=l),null!=u&&i.on("error",u).on("load",function(n){u(null,n)}),o.beforesend.call(i,c),c.send(null==r?null:r),i},i.abort=function(){return c.abort(),i},ta.rebind(i,o,"on"),null==r?i:i.get(Ct(r))}function Ct(n){return 1===n.length?function(t,e){n(null==t?e:null)}:n}function zt(n){var t=n.responseType;return t&&"text"!==t?n.response:n.responseText}function qt(){var n=Lt(),t=Tt()-n;t>24?(isFinite(t)&&(clearTimeout(tc),tc=setTimeout(qt,t)),nc=0):(nc=1,rc(qt))}function Lt(){var n=Date.now();for(ec=Ka;ec;)n>=ec.t&&(ec.f=ec.c(n-ec.t)),ec=ec.n;return n}function Tt(){for(var n,t=Ka,e=1/0;t;)t.f?t=n?n.n=t.n:Ka=t.n:(t.t8?function(n){return n/e}:function(n){return n*e},symbol:n}}function Pt(n){var t=n.decimal,e=n.thousands,r=n.grouping,u=n.currency,i=r&&e?function(n,t){for(var u=n.length,i=[],o=0,a=r[0],c=0;u>0&&a>0&&(c+a+1>t&&(a=Math.max(1,t-c)),i.push(n.substring(u-=a,u+a)),!((c+=a+1)>t));)a=r[o=(o+1)%r.length];return i.reverse().join(e)}:y;return function(n){var e=ic.exec(n),r=e[1]||" ",o=e[2]||">",a=e[3]||"-",c=e[4]||"",l=e[5],s=+e[6],f=e[7],h=e[8],g=e[9],p=1,v="",d="",m=!1,y=!0;switch(h&&(h=+h.substring(1)),(l||"0"===r&&"="===o)&&(l=r="0",o="="),g){case"n":f=!0,g="g";break;case"%":p=100,d="%",g="f";break;case"p":p=100,d="%",g="r";break;case"b":case"o":case"x":case"X":"#"===c&&(v="0"+g.toLowerCase());case"c":y=!1;case"d":m=!0,h=0;break;case"s":p=-1,g="r"}"$"===c&&(v=u[0],d=u[1]),"r"!=g||h||(g="g"),null!=h&&("g"==g?h=Math.max(1,Math.min(21,h)):("e"==g||"f"==g)&&(h=Math.max(0,Math.min(20,h)))),g=oc.get(g)||Ut;var M=l&&f;return function(n){var e=d;if(m&&n%1)return"";var u=0>n||0===n&&0>1/n?(n=-n,"-"):"-"===a?"":a;if(0>p){var c=ta.formatPrefix(n,h);n=c.scale(n),e=c.symbol+d}else n*=p;n=g(n,h);var x,b,_=n.lastIndexOf(".");if(0>_){var w=y?n.lastIndexOf("e"):-1;0>w?(x=n,b=""):(x=n.substring(0,w),b=n.substring(w))}else x=n.substring(0,_),b=t+n.substring(_+1);!l&&f&&(x=i(x,1/0));var S=v.length+x.length+b.length+(M?0:u.length),k=s>S?new Array(S=s-S+1).join(r):"";return M&&(x=i(k+x,k.length?s-b.length:1/0)),u+=v,n=x+b,("<"===o?u+n+k:">"===o?k+u+n:"^"===o?k.substring(0,S>>=1)+u+n+k.substring(S):u+(M?n:k+n))+e}}}function Ut(n){return n+""}function jt(){this._=new Date(arguments.length>1?Date.UTC.apply(this,arguments):arguments[0])}function Ft(n,t,e){function r(t){var e=n(t),r=i(e,1);return r-t>t-e?e:r}function u(e){return t(e=n(new cc(e-1)),1),e}function i(n,e){return t(n=new cc(+n),e),n}function o(n,r,i){var o=u(n),a=[];if(i>1)for(;r>o;)e(o)%i||a.push(new Date(+o)),t(o,1);else for(;r>o;)a.push(new Date(+o)),t(o,1);return a}function a(n,t,e){try{cc=jt;var r=new jt;return r._=n,o(r,t,e)}finally{cc=Date}}n.floor=n,n.round=r,n.ceil=u,n.offset=i,n.range=o;var c=n.utc=Ht(n);return c.floor=c,c.round=Ht(r),c.ceil=Ht(u),c.offset=Ht(i),c.range=a,n}function Ht(n){return function(t,e){try{cc=jt;var r=new jt;return r._=t,n(r,e)._}finally{cc=Date}}}function Ot(n){function t(n){function t(t){for(var e,u,i,o=[],a=-1,c=0;++aa;){if(r>=l)return-1;if(u=t.charCodeAt(a++),37===u){if(o=t.charAt(a++),i=C[o in sc?t.charAt(a++):o],!i||(r=i(n,e,r))<0)return-1}else if(u!=e.charCodeAt(r++))return-1}return r}function r(n,t,e){_.lastIndex=0;var r=_.exec(t.slice(e));return r?(n.w=w.get(r[0].toLowerCase()),e+r[0].length):-1}function u(n,t,e){x.lastIndex=0;var r=x.exec(t.slice(e));return r?(n.w=b.get(r[0].toLowerCase()),e+r[0].length):-1}function i(n,t,e){E.lastIndex=0;var r=E.exec(t.slice(e));return r?(n.m=A.get(r[0].toLowerCase()),e+r[0].length):-1}function o(n,t,e){S.lastIndex=0;var r=S.exec(t.slice(e));return r?(n.m=k.get(r[0].toLowerCase()),e+r[0].length):-1}function a(n,t,r){return e(n,N.c.toString(),t,r)}function c(n,t,r){return e(n,N.x.toString(),t,r)}function l(n,t,r){return e(n,N.X.toString(),t,r)}function s(n,t,e){var r=M.get(t.slice(e,e+=2).toLowerCase());return null==r?-1:(n.p=r,e)}var f=n.dateTime,h=n.date,g=n.time,p=n.periods,v=n.days,d=n.shortDays,m=n.months,y=n.shortMonths;t.utc=function(n){function e(n){try{cc=jt;var t=new cc;return t._=n,r(t)}finally{cc=Date}}var r=t(n);return e.parse=function(n){try{cc=jt;var t=r.parse(n);return t&&t._}finally{cc=Date}},e.toString=r.toString,e},t.multi=t.utc.multi=ae;var M=ta.map(),x=Yt(v),b=Zt(v),_=Yt(d),w=Zt(d),S=Yt(m),k=Zt(m),E=Yt(y),A=Zt(y);p.forEach(function(n,t){M.set(n.toLowerCase(),t)});var N={a:function(n){return d[n.getDay()]},A:function(n){return v[n.getDay()]},b:function(n){return y[n.getMonth()]},B:function(n){return m[n.getMonth()]},c:t(f),d:function(n,t){return It(n.getDate(),t,2)},e:function(n,t){return It(n.getDate(),t,2)},H:function(n,t){return It(n.getHours(),t,2)},I:function(n,t){return It(n.getHours()%12||12,t,2)},j:function(n,t){return It(1+ac.dayOfYear(n),t,3)},L:function(n,t){return It(n.getMilliseconds(),t,3)},m:function(n,t){return It(n.getMonth()+1,t,2)},M:function(n,t){return It(n.getMinutes(),t,2)},p:function(n){return p[+(n.getHours()>=12)]},S:function(n,t){return It(n.getSeconds(),t,2)},U:function(n,t){return It(ac.sundayOfYear(n),t,2)},w:function(n){return n.getDay()},W:function(n,t){return It(ac.mondayOfYear(n),t,2)},x:t(h),X:t(g),y:function(n,t){return It(n.getFullYear()%100,t,2)},Y:function(n,t){return It(n.getFullYear()%1e4,t,4)},Z:ie,"%":function(){return"%"}},C={a:r,A:u,b:i,B:o,c:a,d:Qt,e:Qt,H:te,I:te,j:ne,L:ue,m:Kt,M:ee,p:s,S:re,U:Xt,w:Vt,W:$t,x:c,X:l,y:Wt,Y:Bt,Z:Jt,"%":oe};return t}function It(n,t,e){var r=0>n?"-":"",u=(r?-n:n)+"",i=u.length;return r+(e>i?new Array(e-i+1).join(t)+u:u)}function Yt(n){return new RegExp("^(?:"+n.map(ta.requote).join("|")+")","i")}function Zt(n){for(var t=new l,e=-1,r=n.length;++e68?1900:2e3)}function Kt(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.m=r[0]-1,e+r[0].length):-1}function Qt(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.d=+r[0],e+r[0].length):-1}function ne(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+3));return r?(n.j=+r[0],e+r[0].length):-1}function te(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.H=+r[0],e+r[0].length):-1}function ee(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.M=+r[0],e+r[0].length):-1}function re(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+2));return r?(n.S=+r[0],e+r[0].length):-1}function ue(n,t,e){fc.lastIndex=0;var r=fc.exec(t.slice(e,e+3));return r?(n.L=+r[0],e+r[0].length):-1}function ie(n){var t=n.getTimezoneOffset(),e=t>0?"-":"+",r=ga(t)/60|0,u=ga(t)%60;return e+It(r,"0",2)+It(u,"0",2)}function oe(n,t,e){hc.lastIndex=0;var r=hc.exec(t.slice(e,e+1));return r?e+r[0].length:-1}function ae(n){for(var t=n.length,e=-1;++e=0?1:-1,a=o*e,c=Math.cos(t),l=Math.sin(t),s=i*l,f=u*c+s*Math.cos(a),h=s*o*Math.sin(a);yc.add(Math.atan2(h,f)),r=n,u=c,i=l}var t,e,r,u,i;Mc.point=function(o,a){Mc.point=n,r=(t=o)*Da,u=Math.cos(a=(e=a)*Da/2+qa/4),i=Math.sin(a)},Mc.lineEnd=function(){n(t,e)}}function pe(n){var t=n[0],e=n[1],r=Math.cos(e);return[r*Math.cos(t),r*Math.sin(t),Math.sin(e)]}function ve(n,t){return n[0]*t[0]+n[1]*t[1]+n[2]*t[2]}function de(n,t){return[n[1]*t[2]-n[2]*t[1],n[2]*t[0]-n[0]*t[2],n[0]*t[1]-n[1]*t[0]]}function me(n,t){n[0]+=t[0],n[1]+=t[1],n[2]+=t[2]}function ye(n,t){return[n[0]*t,n[1]*t,n[2]*t]}function Me(n){var t=Math.sqrt(n[0]*n[0]+n[1]*n[1]+n[2]*n[2]);n[0]/=t,n[1]/=t,n[2]/=t}function xe(n){return[Math.atan2(n[1],n[0]),tt(n[2])]}function be(n,t){return ga(n[0]-t[0])a;++a)u.point((e=n[a])[0],e[1]);return void u.lineEnd()}var c=new qe(e,n,null,!0),l=new qe(e,null,c,!1);c.o=l,i.push(c),o.push(l),c=new qe(r,n,null,!1),l=new qe(r,null,c,!0),c.o=l,i.push(c),o.push(l)}}),o.sort(t),ze(i),ze(o),i.length){for(var a=0,c=e,l=o.length;l>a;++a)o[a].e=c=!c;for(var s,f,h=i[0];;){for(var g=h,p=!0;g.v;)if((g=g.n)===h)return;s=g.z,u.lineStart();do{if(g.v=g.o.v=!0,g.e){if(p)for(var a=0,l=s.length;l>a;++a)u.point((f=s[a])[0],f[1]);else r(g.x,g.n.x,1,u);g=g.n}else{if(p){s=g.p.z;for(var a=s.length-1;a>=0;--a)u.point((f=s[a])[0],f[1])}else r(g.x,g.p.x,-1,u);g=g.p}g=g.o,s=g.z,p=!p}while(!g.v);u.lineEnd()}}}function ze(n){if(t=n.length){for(var t,e,r=0,u=n[0];++r0){for(b||(i.polygonStart(),b=!0),i.lineStart();++o1&&2&t&&e.push(e.pop().concat(e.shift())),g.push(e.filter(Te))}var g,p,v,d=t(i),m=u.invert(r[0],r[1]),y={point:o,lineStart:c,lineEnd:l,polygonStart:function(){y.point=s,y.lineStart=f,y.lineEnd=h,g=[],p=[]},polygonEnd:function(){y.point=o,y.lineStart=c,y.lineEnd=l,g=ta.merge(g);var n=Fe(m,p);g.length?(b||(i.polygonStart(),b=!0),Ce(g,De,n,e,i)):n&&(b||(i.polygonStart(),b=!0),i.lineStart(),e(null,null,1,i),i.lineEnd()),b&&(i.polygonEnd(),b=!1),g=p=null},sphere:function(){i.polygonStart(),i.lineStart(),e(null,null,1,i),i.lineEnd(),i.polygonEnd()}},M=Re(),x=t(M),b=!1;return y}}function Te(n){return n.length>1}function Re(){var n,t=[];return{lineStart:function(){t.push(n=[])},point:function(t,e){n.push([t,e])},lineEnd:b,buffer:function(){var e=t;return t=[],n=null,e},rejoin:function(){t.length>1&&t.push(t.pop().concat(t.shift()))}}}function De(n,t){return((n=n.x)[0]<0?n[1]-Ra-Ca:Ra-n[1])-((t=t.x)[0]<0?t[1]-Ra-Ca:Ra-t[1])}function Pe(n){var t,e=0/0,r=0/0,u=0/0;return{lineStart:function(){n.lineStart(),t=1},point:function(i,o){var a=i>0?qa:-qa,c=ga(i-e);ga(c-qa)0?Ra:-Ra),n.point(u,r),n.lineEnd(),n.lineStart(),n.point(a,r),n.point(i,r),t=0):u!==a&&c>=qa&&(ga(e-u)Ca?Math.atan((Math.sin(t)*(i=Math.cos(r))*Math.sin(e)-Math.sin(r)*(u=Math.cos(t))*Math.sin(n))/(u*i*o)):(t+r)/2}function je(n,t,e,r){var u;if(null==n)u=e*Ra,r.point(-qa,u),r.point(0,u),r.point(qa,u),r.point(qa,0),r.point(qa,-u),r.point(0,-u),r.point(-qa,-u),r.point(-qa,0),r.point(-qa,u);else if(ga(n[0]-t[0])>Ca){var i=n[0]a;++a){var l=t[a],s=l.length;if(s)for(var f=l[0],h=f[0],g=f[1]/2+qa/4,p=Math.sin(g),v=Math.cos(g),d=1;;){d===s&&(d=0),n=l[d];var m=n[0],y=n[1]/2+qa/4,M=Math.sin(y),x=Math.cos(y),b=m-h,_=b>=0?1:-1,w=_*b,S=w>qa,k=p*M;if(yc.add(Math.atan2(k*_*Math.sin(w),v*x+k*Math.cos(w))),i+=S?b+_*La:b,S^h>=e^m>=e){var E=de(pe(f),pe(n));Me(E);var A=de(u,E);Me(A);var N=(S^b>=0?-1:1)*tt(A[2]);(r>N||r===N&&(E[0]||E[1]))&&(o+=S^b>=0?1:-1)}if(!d++)break;h=m,p=M,v=x,f=n}}return(-Ca>i||Ca>i&&0>yc)^1&o}function He(n){function t(n,t){return Math.cos(n)*Math.cos(t)>i}function e(n){var e,i,c,l,s;return{lineStart:function(){l=c=!1,s=1},point:function(f,h){var g,p=[f,h],v=t(f,h),d=o?v?0:u(f,h):v?u(f+(0>f?qa:-qa),h):0;if(!e&&(l=c=v)&&n.lineStart(),v!==c&&(g=r(e,p),(be(e,g)||be(p,g))&&(p[0]+=Ca,p[1]+=Ca,v=t(p[0],p[1]))),v!==c)s=0,v?(n.lineStart(),g=r(p,e),n.point(g[0],g[1])):(g=r(e,p),n.point(g[0],g[1]),n.lineEnd()),e=g;else if(a&&e&&o^v){var m;d&i||!(m=r(p,e,!0))||(s=0,o?(n.lineStart(),n.point(m[0][0],m[0][1]),n.point(m[1][0],m[1][1]),n.lineEnd()):(n.point(m[1][0],m[1][1]),n.lineEnd(),n.lineStart(),n.point(m[0][0],m[0][1])))}!v||e&&be(e,p)||n.point(p[0],p[1]),e=p,c=v,i=d},lineEnd:function(){c&&n.lineEnd(),e=null},clean:function(){return s|(l&&c)<<1}}}function r(n,t,e){var r=pe(n),u=pe(t),o=[1,0,0],a=de(r,u),c=ve(a,a),l=a[0],s=c-l*l;if(!s)return!e&&n;var f=i*c/s,h=-i*l/s,g=de(o,a),p=ye(o,f),v=ye(a,h);me(p,v);var d=g,m=ve(p,d),y=ve(d,d),M=m*m-y*(ve(p,p)-1);if(!(0>M)){var x=Math.sqrt(M),b=ye(d,(-m-x)/y);if(me(b,p),b=xe(b),!e)return b;var _,w=n[0],S=t[0],k=n[1],E=t[1];w>S&&(_=w,w=S,S=_);var A=S-w,N=ga(A-qa)A;if(!N&&k>E&&(_=k,k=E,E=_),C?N?k+E>0^b[1]<(ga(b[0]-w)qa^(w<=b[0]&&b[0]<=S)){var z=ye(d,(-m+x)/y);return me(z,p),[b,xe(z)]}}}function u(t,e){var r=o?n:qa-n,u=0;return-r>t?u|=1:t>r&&(u|=2),-r>e?u|=4:e>r&&(u|=8),u}var i=Math.cos(n),o=i>0,a=ga(i)>Ca,c=gr(n,6*Da);return Le(t,e,c,o?[0,-n]:[-qa,n-qa])}function Oe(n,t,e,r){return function(u){var i,o=u.a,a=u.b,c=o.x,l=o.y,s=a.x,f=a.y,h=0,g=1,p=s-c,v=f-l;if(i=n-c,p||!(i>0)){if(i/=p,0>p){if(h>i)return;g>i&&(g=i)}else if(p>0){if(i>g)return;i>h&&(h=i)}if(i=e-c,p||!(0>i)){if(i/=p,0>p){if(i>g)return;i>h&&(h=i)}else if(p>0){if(h>i)return;g>i&&(g=i)}if(i=t-l,v||!(i>0)){if(i/=v,0>v){if(h>i)return;g>i&&(g=i)}else if(v>0){if(i>g)return;i>h&&(h=i)}if(i=r-l,v||!(0>i)){if(i/=v,0>v){if(i>g)return;i>h&&(h=i)}else if(v>0){if(h>i)return;g>i&&(g=i)}return h>0&&(u.a={x:c+h*p,y:l+h*v}),1>g&&(u.b={x:c+g*p,y:l+g*v}),u}}}}}}function Ie(n,t,e,r){function u(r,u){return ga(r[0]-n)0?0:3:ga(r[0]-e)0?2:1:ga(r[1]-t)0?1:0:u>0?3:2}function i(n,t){return o(n.x,t.x)}function o(n,t){var e=u(n,1),r=u(t,1);return e!==r?e-r:0===e?t[1]-n[1]:1===e?n[0]-t[0]:2===e?n[1]-t[1]:t[0]-n[0]}return function(a){function c(n){for(var t=0,e=d.length,r=n[1],u=0;e>u;++u)for(var i,o=1,a=d[u],c=a.length,l=a[0];c>o;++o)i=a[o],l[1]<=r?i[1]>r&&Q(l,i,n)>0&&++t:i[1]<=r&&Q(l,i,n)<0&&--t,l=i;return 0!==t}function l(i,a,c,l){var s=0,f=0;if(null==i||(s=u(i,c))!==(f=u(a,c))||o(i,a)<0^c>0){do l.point(0===s||3===s?n:e,s>1?r:t);while((s=(s+c+4)%4)!==f)}else l.point(a[0],a[1])}function s(u,i){return u>=n&&e>=u&&i>=t&&r>=i}function f(n,t){s(n,t)&&a.point(n,t)}function h(){C.point=p,d&&d.push(m=[]),S=!0,w=!1,b=_=0/0}function g(){v&&(p(y,M),x&&w&&A.rejoin(),v.push(A.buffer())),C.point=f,w&&a.lineEnd()}function p(n,t){n=Math.max(-Tc,Math.min(Tc,n)),t=Math.max(-Tc,Math.min(Tc,t));var e=s(n,t);if(d&&m.push([n,t]),S)y=n,M=t,x=e,S=!1,e&&(a.lineStart(),a.point(n,t));else if(e&&w)a.point(n,t);else{var r={a:{x:b,y:_},b:{x:n,y:t}};N(r)?(w||(a.lineStart(),a.point(r.a.x,r.a.y)),a.point(r.b.x,r.b.y),e||a.lineEnd(),k=!1):e&&(a.lineStart(),a.point(n,t),k=!1)}b=n,_=t,w=e}var v,d,m,y,M,x,b,_,w,S,k,E=a,A=Re(),N=Oe(n,t,e,r),C={point:f,lineStart:h,lineEnd:g,polygonStart:function(){a=A,v=[],d=[],k=!0},polygonEnd:function(){a=E,v=ta.merge(v);var t=c([n,r]),e=k&&t,u=v.length;(e||u)&&(a.polygonStart(),e&&(a.lineStart(),l(null,null,1,a),a.lineEnd()),u&&Ce(v,i,t,l,a),a.polygonEnd()),v=d=m=null}};return C}}function Ye(n){var t=0,e=qa/3,r=ir(n),u=r(t,e);return u.parallels=function(n){return arguments.length?r(t=n[0]*qa/180,e=n[1]*qa/180):[t/qa*180,e/qa*180]},u}function Ze(n,t){function e(n,t){var e=Math.sqrt(i-2*u*Math.sin(t))/u;return[e*Math.sin(n*=u),o-e*Math.cos(n)]}var r=Math.sin(n),u=(r+Math.sin(t))/2,i=1+r*(2*u-r),o=Math.sqrt(i)/u;return e.invert=function(n,t){var e=o-t;return[Math.atan2(n,e)/u,tt((i-(n*n+e*e)*u*u)/(2*u))]},e}function Ve(){function n(n,t){Dc+=u*n-r*t,r=n,u=t}var t,e,r,u;Hc.point=function(i,o){Hc.point=n,t=r=i,e=u=o},Hc.lineEnd=function(){n(t,e)}}function Xe(n,t){Pc>n&&(Pc=n),n>jc&&(jc=n),Uc>t&&(Uc=t),t>Fc&&(Fc=t)}function $e(){function n(n,t){o.push("M",n,",",t,i)}function t(n,t){o.push("M",n,",",t),a.point=e}function e(n,t){o.push("L",n,",",t)}function r(){a.point=n}function u(){o.push("Z")}var i=Be(4.5),o=[],a={point:n,lineStart:function(){a.point=t},lineEnd:r,polygonStart:function(){a.lineEnd=u},polygonEnd:function(){a.lineEnd=r,a.point=n},pointRadius:function(n){return i=Be(n),a},result:function(){if(o.length){var n=o.join("");return o=[],n}}};return a}function Be(n){return"m0,"+n+"a"+n+","+n+" 0 1,1 0,"+-2*n+"a"+n+","+n+" 0 1,1 0,"+2*n+"z"}function We(n,t){_c+=n,wc+=t,++Sc}function Je(){function n(n,r){var u=n-t,i=r-e,o=Math.sqrt(u*u+i*i);kc+=o*(t+n)/2,Ec+=o*(e+r)/2,Ac+=o,We(t=n,e=r)}var t,e;Ic.point=function(r,u){Ic.point=n,We(t=r,e=u)}}function Ge(){Ic.point=We}function Ke(){function n(n,t){var e=n-r,i=t-u,o=Math.sqrt(e*e+i*i);kc+=o*(r+n)/2,Ec+=o*(u+t)/2,Ac+=o,o=u*n-r*t,Nc+=o*(r+n),Cc+=o*(u+t),zc+=3*o,We(r=n,u=t)}var t,e,r,u;Ic.point=function(i,o){Ic.point=n,We(t=r=i,e=u=o)},Ic.lineEnd=function(){n(t,e)}}function Qe(n){function t(t,e){n.moveTo(t+o,e),n.arc(t,e,o,0,La)}function e(t,e){n.moveTo(t,e),a.point=r}function r(t,e){n.lineTo(t,e)}function u(){a.point=t}function i(){n.closePath()}var o=4.5,a={point:t,lineStart:function(){a.point=e},lineEnd:u,polygonStart:function(){a.lineEnd=i},polygonEnd:function(){a.lineEnd=u,a.point=t},pointRadius:function(n){return o=n,a},result:b};return a}function nr(n){function t(n){return(a?r:e)(n)}function e(t){return rr(t,function(e,r){e=n(e,r),t.point(e[0],e[1])})}function r(t){function e(e,r){e=n(e,r),t.point(e[0],e[1])}function r(){M=0/0,S.point=i,t.lineStart()}function i(e,r){var i=pe([e,r]),o=n(e,r);u(M,x,y,b,_,w,M=o[0],x=o[1],y=e,b=i[0],_=i[1],w=i[2],a,t),t.point(M,x)}function o(){S.point=e,t.lineEnd()}function c(){r(),S.point=l,S.lineEnd=s}function l(n,t){i(f=n,h=t),g=M,p=x,v=b,d=_,m=w,S.point=i}function s(){u(M,x,y,b,_,w,g,p,f,v,d,m,a,t),S.lineEnd=o,o()}var f,h,g,p,v,d,m,y,M,x,b,_,w,S={point:e,lineStart:r,lineEnd:o,polygonStart:function(){t.polygonStart(),S.lineStart=c +},polygonEnd:function(){t.polygonEnd(),S.lineStart=r}};return S}function u(t,e,r,a,c,l,s,f,h,g,p,v,d,m){var y=s-t,M=f-e,x=y*y+M*M;if(x>4*i&&d--){var b=a+g,_=c+p,w=l+v,S=Math.sqrt(b*b+_*_+w*w),k=Math.asin(w/=S),E=ga(ga(w)-1)i||ga((y*z+M*q)/x-.5)>.3||o>a*g+c*p+l*v)&&(u(t,e,r,a,c,l,N,C,E,b/=S,_/=S,w,d,m),m.point(N,C),u(N,C,E,b,_,w,s,f,h,g,p,v,d,m))}}var i=.5,o=Math.cos(30*Da),a=16;return t.precision=function(n){return arguments.length?(a=(i=n*n)>0&&16,t):Math.sqrt(i)},t}function tr(n){var t=nr(function(t,e){return n([t*Pa,e*Pa])});return function(n){return or(t(n))}}function er(n){this.stream=n}function rr(n,t){return{point:t,sphere:function(){n.sphere()},lineStart:function(){n.lineStart()},lineEnd:function(){n.lineEnd()},polygonStart:function(){n.polygonStart()},polygonEnd:function(){n.polygonEnd()}}}function ur(n){return ir(function(){return n})()}function ir(n){function t(n){return n=a(n[0]*Da,n[1]*Da),[n[0]*h+c,l-n[1]*h]}function e(n){return n=a.invert((n[0]-c)/h,(l-n[1])/h),n&&[n[0]*Pa,n[1]*Pa]}function r(){a=Ae(o=lr(m,M,x),i);var n=i(v,d);return c=g-n[0]*h,l=p+n[1]*h,u()}function u(){return s&&(s.valid=!1,s=null),t}var i,o,a,c,l,s,f=nr(function(n,t){return n=i(n,t),[n[0]*h+c,l-n[1]*h]}),h=150,g=480,p=250,v=0,d=0,m=0,M=0,x=0,b=Lc,_=y,w=null,S=null;return t.stream=function(n){return s&&(s.valid=!1),s=or(b(o,f(_(n)))),s.valid=!0,s},t.clipAngle=function(n){return arguments.length?(b=null==n?(w=n,Lc):He((w=+n)*Da),u()):w},t.clipExtent=function(n){return arguments.length?(S=n,_=n?Ie(n[0][0],n[0][1],n[1][0],n[1][1]):y,u()):S},t.scale=function(n){return arguments.length?(h=+n,r()):h},t.translate=function(n){return arguments.length?(g=+n[0],p=+n[1],r()):[g,p]},t.center=function(n){return arguments.length?(v=n[0]%360*Da,d=n[1]%360*Da,r()):[v*Pa,d*Pa]},t.rotate=function(n){return arguments.length?(m=n[0]%360*Da,M=n[1]%360*Da,x=n.length>2?n[2]%360*Da:0,r()):[m*Pa,M*Pa,x*Pa]},ta.rebind(t,f,"precision"),function(){return i=n.apply(this,arguments),t.invert=i.invert&&e,r()}}function or(n){return rr(n,function(t,e){n.point(t*Da,e*Da)})}function ar(n,t){return[n,t]}function cr(n,t){return[n>qa?n-La:-qa>n?n+La:n,t]}function lr(n,t,e){return n?t||e?Ae(fr(n),hr(t,e)):fr(n):t||e?hr(t,e):cr}function sr(n){return function(t,e){return t+=n,[t>qa?t-La:-qa>t?t+La:t,e]}}function fr(n){var t=sr(n);return t.invert=sr(-n),t}function hr(n,t){function e(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,l=Math.sin(t),s=l*r+a*u;return[Math.atan2(c*i-s*o,a*r-l*u),tt(s*i+c*o)]}var r=Math.cos(n),u=Math.sin(n),i=Math.cos(t),o=Math.sin(t);return e.invert=function(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,l=Math.sin(t),s=l*i-c*o;return[Math.atan2(c*i+l*o,a*r+s*u),tt(s*r-a*u)]},e}function gr(n,t){var e=Math.cos(n),r=Math.sin(n);return function(u,i,o,a){var c=o*t;null!=u?(u=pr(e,u),i=pr(e,i),(o>0?i>u:u>i)&&(u+=o*La)):(u=n+o*La,i=n-.5*c);for(var l,s=u;o>0?s>i:i>s;s-=c)a.point((l=xe([e,-r*Math.cos(s),-r*Math.sin(s)]))[0],l[1])}}function pr(n,t){var e=pe(t);e[0]-=n,Me(e);var r=nt(-e[1]);return((-e[2]<0?-r:r)+2*Math.PI-Ca)%(2*Math.PI)}function vr(n,t,e){var r=ta.range(n,t-Ca,e).concat(t);return function(n){return r.map(function(t){return[n,t]})}}function dr(n,t,e){var r=ta.range(n,t-Ca,e).concat(t);return function(n){return r.map(function(t){return[t,n]})}}function mr(n){return n.source}function yr(n){return n.target}function Mr(n,t,e,r){var u=Math.cos(t),i=Math.sin(t),o=Math.cos(r),a=Math.sin(r),c=u*Math.cos(n),l=u*Math.sin(n),s=o*Math.cos(e),f=o*Math.sin(e),h=2*Math.asin(Math.sqrt(it(r-t)+u*o*it(e-n))),g=1/Math.sin(h),p=h?function(n){var t=Math.sin(n*=h)*g,e=Math.sin(h-n)*g,r=e*c+t*s,u=e*l+t*f,o=e*i+t*a;return[Math.atan2(u,r)*Pa,Math.atan2(o,Math.sqrt(r*r+u*u))*Pa]}:function(){return[n*Pa,t*Pa]};return p.distance=h,p}function xr(){function n(n,u){var i=Math.sin(u*=Da),o=Math.cos(u),a=ga((n*=Da)-t),c=Math.cos(a);Yc+=Math.atan2(Math.sqrt((a=o*Math.sin(a))*a+(a=r*i-e*o*c)*a),e*i+r*o*c),t=n,e=i,r=o}var t,e,r;Zc.point=function(u,i){t=u*Da,e=Math.sin(i*=Da),r=Math.cos(i),Zc.point=n},Zc.lineEnd=function(){Zc.point=Zc.lineEnd=b}}function br(n,t){function e(t,e){var r=Math.cos(t),u=Math.cos(e),i=n(r*u);return[i*u*Math.sin(t),i*Math.sin(e)]}return e.invert=function(n,e){var r=Math.sqrt(n*n+e*e),u=t(r),i=Math.sin(u),o=Math.cos(u);return[Math.atan2(n*i,r*o),Math.asin(r&&e*i/r)]},e}function _r(n,t){function e(n,t){o>0?-Ra+Ca>t&&(t=-Ra+Ca):t>Ra-Ca&&(t=Ra-Ca);var e=o/Math.pow(u(t),i);return[e*Math.sin(i*n),o-e*Math.cos(i*n)]}var r=Math.cos(n),u=function(n){return Math.tan(qa/4+n/2)},i=n===t?Math.sin(n):Math.log(r/Math.cos(t))/Math.log(u(t)/u(n)),o=r*Math.pow(u(n),i)/i;return i?(e.invert=function(n,t){var e=o-t,r=K(i)*Math.sqrt(n*n+e*e);return[Math.atan2(n,e)/i,2*Math.atan(Math.pow(o/r,1/i))-Ra]},e):Sr}function wr(n,t){function e(n,t){var e=i-t;return[e*Math.sin(u*n),i-e*Math.cos(u*n)]}var r=Math.cos(n),u=n===t?Math.sin(n):(r-Math.cos(t))/(t-n),i=r/u+n;return ga(u)u;u++){for(;r>1&&Q(n[e[r-2]],n[e[r-1]],n[u])<=0;)--r;e[r++]=u}return e.slice(0,r)}function zr(n,t){return n[0]-t[0]||n[1]-t[1]}function qr(n,t,e){return(e[0]-t[0])*(n[1]-t[1])<(e[1]-t[1])*(n[0]-t[0])}function Lr(n,t,e,r){var u=n[0],i=e[0],o=t[0]-u,a=r[0]-i,c=n[1],l=e[1],s=t[1]-c,f=r[1]-l,h=(a*(c-l)-f*(u-i))/(f*o-a*s);return[u+h*o,c+h*s]}function Tr(n){var t=n[0],e=n[n.length-1];return!(t[0]-e[0]||t[1]-e[1])}function Rr(){tu(this),this.edge=this.site=this.circle=null}function Dr(n){var t=el.pop()||new Rr;return t.site=n,t}function Pr(n){Xr(n),Qc.remove(n),el.push(n),tu(n)}function Ur(n){var t=n.circle,e=t.x,r=t.cy,u={x:e,y:r},i=n.P,o=n.N,a=[n];Pr(n);for(var c=i;c.circle&&ga(e-c.circle.x)s;++s)l=a[s],c=a[s-1],Kr(l.edge,c.site,l.site,u);c=a[0],l=a[f-1],l.edge=Jr(c.site,l.site,null,u),Vr(c),Vr(l)}function jr(n){for(var t,e,r,u,i=n.x,o=n.y,a=Qc._;a;)if(r=Fr(a,o)-i,r>Ca)a=a.L;else{if(u=i-Hr(a,o),!(u>Ca)){r>-Ca?(t=a.P,e=a):u>-Ca?(t=a,e=a.N):t=e=a;break}if(!a.R){t=a;break}a=a.R}var c=Dr(n);if(Qc.insert(t,c),t||e){if(t===e)return Xr(t),e=Dr(t.site),Qc.insert(c,e),c.edge=e.edge=Jr(t.site,c.site),Vr(t),void Vr(e);if(!e)return void(c.edge=Jr(t.site,c.site));Xr(t),Xr(e);var l=t.site,s=l.x,f=l.y,h=n.x-s,g=n.y-f,p=e.site,v=p.x-s,d=p.y-f,m=2*(h*d-g*v),y=h*h+g*g,M=v*v+d*d,x={x:(d*y-g*M)/m+s,y:(h*M-v*y)/m+f};Kr(e.edge,l,p,x),c.edge=Jr(l,n,null,x),e.edge=Jr(n,p,null,x),Vr(t),Vr(e)}}function Fr(n,t){var e=n.site,r=e.x,u=e.y,i=u-t;if(!i)return r;var o=n.P;if(!o)return-1/0;e=o.site;var a=e.x,c=e.y,l=c-t;if(!l)return a;var s=a-r,f=1/i-1/l,h=s/l;return f?(-h+Math.sqrt(h*h-2*f*(s*s/(-2*l)-c+l/2+u-i/2)))/f+r:(r+a)/2}function Hr(n,t){var e=n.N;if(e)return Fr(e,t);var r=n.site;return r.y===t?r.x:1/0}function Or(n){this.site=n,this.edges=[]}function Ir(n){for(var t,e,r,u,i,o,a,c,l,s,f=n[0][0],h=n[1][0],g=n[0][1],p=n[1][1],v=Kc,d=v.length;d--;)if(i=v[d],i&&i.prepare())for(a=i.edges,c=a.length,o=0;c>o;)s=a[o].end(),r=s.x,u=s.y,l=a[++o%c].start(),t=l.x,e=l.y,(ga(r-t)>Ca||ga(u-e)>Ca)&&(a.splice(o,0,new Qr(Gr(i.site,s,ga(r-f)Ca?{x:f,y:ga(t-f)Ca?{x:ga(e-p)Ca?{x:h,y:ga(t-h)Ca?{x:ga(e-g)=-za)){var g=c*c+l*l,p=s*s+f*f,v=(f*g-l*p)/h,d=(c*p-s*g)/h,f=d+a,m=rl.pop()||new Zr;m.arc=n,m.site=u,m.x=v+o,m.y=f+Math.sqrt(v*v+d*d),m.cy=f,n.circle=m;for(var y=null,M=tl._;M;)if(m.yd||d>=a)return;if(h>p){if(i){if(i.y>=l)return}else i={x:d,y:c};e={x:d,y:l}}else{if(i){if(i.yr||r>1)if(h>p){if(i){if(i.y>=l)return}else i={x:(c-u)/r,y:c};e={x:(l-u)/r,y:l}}else{if(i){if(i.yg){if(i){if(i.x>=a)return}else i={x:o,y:r*o+u};e={x:a,y:r*a+u}}else{if(i){if(i.xi||f>o||r>h||u>g)){if(p=n.point){var p,v=t-n.x,d=e-n.y,m=v*v+d*d;if(c>m){var y=Math.sqrt(c=m);r=t-y,u=e-y,i=t+y,o=e+y,a=p}}for(var M=n.nodes,x=.5*(s+h),b=.5*(f+g),_=t>=x,w=e>=b,S=w<<1|_,k=S+4;k>S;++S)if(n=M[3&S])switch(3&S){case 0:l(n,s,f,x,b);break;case 1:l(n,x,f,h,b);break;case 2:l(n,s,b,x,g);break;case 3:l(n,x,b,h,g)}}}(n,r,u,i,o),a}function gu(n,t){n=ta.rgb(n),t=ta.rgb(t);var e=n.r,r=n.g,u=n.b,i=t.r-e,o=t.g-r,a=t.b-u;return function(n){return"#"+xt(Math.round(e+i*n))+xt(Math.round(r+o*n))+xt(Math.round(u+a*n))}}function pu(n,t){var e,r={},u={};for(e in n)e in t?r[e]=mu(n[e],t[e]):u[e]=n[e];for(e in t)e in n||(u[e]=t[e]);return function(n){for(e in r)u[e]=r[e](n);return u}}function vu(n,t){return n=+n,t=+t,function(e){return n*(1-e)+t*e}}function du(n,t){var e,r,u,i=il.lastIndex=ol.lastIndex=0,o=-1,a=[],c=[];for(n+="",t+="";(e=il.exec(n))&&(r=ol.exec(t));)(u=r.index)>i&&(u=t.slice(i,u),a[o]?a[o]+=u:a[++o]=u),(e=e[0])===(r=r[0])?a[o]?a[o]+=r:a[++o]=r:(a[++o]=null,c.push({i:o,x:vu(e,r)})),i=ol.lastIndex;return ir;++r)a[(e=c[r]).i]=e.x(n);return a.join("")})}function mu(n,t){for(var e,r=ta.interpolators.length;--r>=0&&!(e=ta.interpolators[r](n,t)););return e}function yu(n,t){var e,r=[],u=[],i=n.length,o=t.length,a=Math.min(n.length,t.length);for(e=0;a>e;++e)r.push(mu(n[e],t[e]));for(;i>e;++e)u[e]=n[e];for(;o>e;++e)u[e]=t[e];return function(n){for(e=0;a>e;++e)u[e]=r[e](n);return u}}function Mu(n){return function(t){return 0>=t?0:t>=1?1:n(t)}}function xu(n){return function(t){return 1-n(1-t)}}function bu(n){return function(t){return.5*(.5>t?n(2*t):2-n(2-2*t))}}function _u(n){return n*n}function wu(n){return n*n*n}function Su(n){if(0>=n)return 0;if(n>=1)return 1;var t=n*n,e=t*n;return 4*(.5>n?e:3*(n-t)+e-.75)}function ku(n){return function(t){return Math.pow(t,n)}}function Eu(n){return 1-Math.cos(n*Ra)}function Au(n){return Math.pow(2,10*(n-1))}function Nu(n){return 1-Math.sqrt(1-n*n)}function Cu(n,t){var e;return arguments.length<2&&(t=.45),arguments.length?e=t/La*Math.asin(1/n):(n=1,e=t/4),function(r){return 1+n*Math.pow(2,-10*r)*Math.sin((r-e)*La/t)}}function zu(n){return n||(n=1.70158),function(t){return t*t*((n+1)*t-n)}}function qu(n){return 1/2.75>n?7.5625*n*n:2/2.75>n?7.5625*(n-=1.5/2.75)*n+.75:2.5/2.75>n?7.5625*(n-=2.25/2.75)*n+.9375:7.5625*(n-=2.625/2.75)*n+.984375}function Lu(n,t){n=ta.hcl(n),t=ta.hcl(t);var e=n.h,r=n.c,u=n.l,i=t.h-e,o=t.c-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.c:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return st(e+i*n,r+o*n,u+a*n)+""}}function Tu(n,t){n=ta.hsl(n),t=ta.hsl(t);var e=n.h,r=n.s,u=n.l,i=t.h-e,o=t.s-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.s:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return ct(e+i*n,r+o*n,u+a*n)+""}}function Ru(n,t){n=ta.lab(n),t=ta.lab(t);var e=n.l,r=n.a,u=n.b,i=t.l-e,o=t.a-r,a=t.b-u;return function(n){return ht(e+i*n,r+o*n,u+a*n)+""}}function Du(n,t){return t-=n,function(e){return Math.round(n+t*e)}}function Pu(n){var t=[n.a,n.b],e=[n.c,n.d],r=ju(t),u=Uu(t,e),i=ju(Fu(e,t,-u))||0;t[0]*e[1]180?s+=360:s-l>180&&(l+=360),u.push({i:r.push(r.pop()+"rotate(",null,")")-2,x:vu(l,s)})):s&&r.push(r.pop()+"rotate("+s+")"),f!=h?u.push({i:r.push(r.pop()+"skewX(",null,")")-2,x:vu(f,h)}):h&&r.push(r.pop()+"skewX("+h+")"),g[0]!=p[0]||g[1]!=p[1]?(e=r.push(r.pop()+"scale(",null,",",null,")"),u.push({i:e-4,x:vu(g[0],p[0])},{i:e-2,x:vu(g[1],p[1])})):(1!=p[0]||1!=p[1])&&r.push(r.pop()+"scale("+p+")"),e=u.length,function(n){for(var t,i=-1;++i=0;)e.push(u[r])}function Qu(n,t){for(var e=[n],r=[];null!=(n=e.pop());)if(r.push(n),(i=n.children)&&(u=i.length))for(var u,i,o=-1;++oe;++e)(t=n[e][1])>u&&(r=e,u=t);return r}function si(n){return n.reduce(fi,0)}function fi(n,t){return n+t[1]}function hi(n,t){return gi(n,Math.ceil(Math.log(t.length)/Math.LN2+1))}function gi(n,t){for(var e=-1,r=+n[0],u=(n[1]-r)/t,i=[];++e<=t;)i[e]=u*e+r;return i}function pi(n){return[ta.min(n),ta.max(n)]}function vi(n,t){return n.value-t.value}function di(n,t){var e=n._pack_next;n._pack_next=t,t._pack_prev=n,t._pack_next=e,e._pack_prev=t}function mi(n,t){n._pack_next=t,t._pack_prev=n}function yi(n,t){var e=t.x-n.x,r=t.y-n.y,u=n.r+t.r;return.999*u*u>e*e+r*r}function Mi(n){function t(n){s=Math.min(n.x-n.r,s),f=Math.max(n.x+n.r,f),h=Math.min(n.y-n.r,h),g=Math.max(n.y+n.r,g)}if((e=n.children)&&(l=e.length)){var e,r,u,i,o,a,c,l,s=1/0,f=-1/0,h=1/0,g=-1/0;if(e.forEach(xi),r=e[0],r.x=-r.r,r.y=0,t(r),l>1&&(u=e[1],u.x=u.r,u.y=0,t(u),l>2))for(i=e[2],wi(r,u,i),t(i),di(r,i),r._pack_prev=i,di(i,u),u=r._pack_next,o=3;l>o;o++){wi(r,u,i=e[o]);var p=0,v=1,d=1;for(a=u._pack_next;a!==u;a=a._pack_next,v++)if(yi(a,i)){p=1;break}if(1==p)for(c=r._pack_prev;c!==a._pack_prev&&!yi(c,i);c=c._pack_prev,d++);p?(d>v||v==d&&u.ro;o++)i=e[o],i.x-=m,i.y-=y,M=Math.max(M,i.r+Math.sqrt(i.x*i.x+i.y*i.y));n.r=M,e.forEach(bi)}}function xi(n){n._pack_next=n._pack_prev=n}function bi(n){delete n._pack_next,delete n._pack_prev}function _i(n,t,e,r){var u=n.children;if(n.x=t+=r*n.x,n.y=e+=r*n.y,n.r*=r,u)for(var i=-1,o=u.length;++i=0;)t=u[i],t.z+=e,t.m+=e,e+=t.s+(r+=t.c)}function Ci(n,t,e){return n.a.parent===t.parent?n.a:e}function zi(n){return 1+ta.max(n,function(n){return n.y})}function qi(n){return n.reduce(function(n,t){return n+t.x},0)/n.length}function Li(n){var t=n.children;return t&&t.length?Li(t[0]):n}function Ti(n){var t,e=n.children;return e&&(t=e.length)?Ti(e[t-1]):n}function Ri(n){return{x:n.x,y:n.y,dx:n.dx,dy:n.dy}}function Di(n,t){var e=n.x+t[3],r=n.y+t[0],u=n.dx-t[1]-t[3],i=n.dy-t[0]-t[2];return 0>u&&(e+=u/2,u=0),0>i&&(r+=i/2,i=0),{x:e,y:r,dx:u,dy:i}}function Pi(n){var t=n[0],e=n[n.length-1];return e>t?[t,e]:[e,t]}function Ui(n){return n.rangeExtent?n.rangeExtent():Pi(n.range())}function ji(n,t,e,r){var u=e(n[0],n[1]),i=r(t[0],t[1]);return function(n){return i(u(n))}}function Fi(n,t){var e,r=0,u=n.length-1,i=n[r],o=n[u];return i>o&&(e=r,r=u,u=e,e=i,i=o,o=e),n[r]=t.floor(i),n[u]=t.ceil(o),n}function Hi(n){return n?{floor:function(t){return Math.floor(t/n)*n},ceil:function(t){return Math.ceil(t/n)*n}}:ml}function Oi(n,t,e,r){var u=[],i=[],o=0,a=Math.min(n.length,t.length)-1;for(n[a]2?Oi:ji,c=r?Iu:Ou;return o=u(n,t,c,e),a=u(t,n,c,mu),i}function i(n){return o(n)}var o,a;return i.invert=function(n){return a(n)},i.domain=function(t){return arguments.length?(n=t.map(Number),u()):n},i.range=function(n){return arguments.length?(t=n,u()):t},i.rangeRound=function(n){return i.range(n).interpolate(Du)},i.clamp=function(n){return arguments.length?(r=n,u()):r},i.interpolate=function(n){return arguments.length?(e=n,u()):e},i.ticks=function(t){return Xi(n,t)},i.tickFormat=function(t,e){return $i(n,t,e)},i.nice=function(t){return Zi(n,t),u()},i.copy=function(){return Ii(n,t,e,r)},u()}function Yi(n,t){return ta.rebind(n,t,"range","rangeRound","interpolate","clamp")}function Zi(n,t){return Fi(n,Hi(Vi(n,t)[2]))}function Vi(n,t){null==t&&(t=10);var e=Pi(n),r=e[1]-e[0],u=Math.pow(10,Math.floor(Math.log(r/t)/Math.LN10)),i=t/r*u;return.15>=i?u*=10:.35>=i?u*=5:.75>=i&&(u*=2),e[0]=Math.ceil(e[0]/u)*u,e[1]=Math.floor(e[1]/u)*u+.5*u,e[2]=u,e}function Xi(n,t){return ta.range.apply(ta,Vi(n,t))}function $i(n,t,e){var r=Vi(n,t);if(e){var u=ic.exec(e);if(u.shift(),"s"===u[8]){var i=ta.formatPrefix(Math.max(ga(r[0]),ga(r[1])));return u[7]||(u[7]="."+Bi(i.scale(r[2]))),u[8]="f",e=ta.format(u.join("")),function(n){return e(i.scale(n))+i.symbol}}u[7]||(u[7]="."+Wi(u[8],r)),e=u.join("")}else e=",."+Bi(r[2])+"f";return ta.format(e)}function Bi(n){return-Math.floor(Math.log(n)/Math.LN10+.01)}function Wi(n,t){var e=Bi(t[2]);return n in yl?Math.abs(e-Bi(Math.max(ga(t[0]),ga(t[1]))))+ +("e"!==n):e-2*("%"===n)}function Ji(n,t,e,r){function u(n){return(e?Math.log(0>n?0:n):-Math.log(n>0?0:-n))/Math.log(t)}function i(n){return e?Math.pow(t,n):-Math.pow(t,-n)}function o(t){return n(u(t))}return o.invert=function(t){return i(n.invert(t))},o.domain=function(t){return arguments.length?(e=t[0]>=0,n.domain((r=t.map(Number)).map(u)),o):r},o.base=function(e){return arguments.length?(t=+e,n.domain(r.map(u)),o):t},o.nice=function(){var t=Fi(r.map(u),e?Math:xl);return n.domain(t),r=t.map(i),o},o.ticks=function(){var n=Pi(r),o=[],a=n[0],c=n[1],l=Math.floor(u(a)),s=Math.ceil(u(c)),f=t%1?2:t;if(isFinite(s-l)){if(e){for(;s>l;l++)for(var h=1;f>h;h++)o.push(i(l)*h);o.push(i(l))}else for(o.push(i(l));l++0;h--)o.push(i(l)*h);for(l=0;o[l]c;s--);o=o.slice(l,s)}return o},o.tickFormat=function(n,t){if(!arguments.length)return Ml;arguments.length<2?t=Ml:"function"!=typeof t&&(t=ta.format(t));var r,a=Math.max(.1,n/o.ticks().length),c=e?(r=1e-12,Math.ceil):(r=-1e-12,Math.floor);return function(n){return n/i(c(u(n)+r))<=a?t(n):""}},o.copy=function(){return Ji(n.copy(),t,e,r)},Yi(o,n)}function Gi(n,t,e){function r(t){return n(u(t))}var u=Ki(t),i=Ki(1/t);return r.invert=function(t){return i(n.invert(t))},r.domain=function(t){return arguments.length?(n.domain((e=t.map(Number)).map(u)),r):e},r.ticks=function(n){return Xi(e,n)},r.tickFormat=function(n,t){return $i(e,n,t)},r.nice=function(n){return r.domain(Zi(e,n))},r.exponent=function(o){return arguments.length?(u=Ki(t=o),i=Ki(1/t),n.domain(e.map(u)),r):t},r.copy=function(){return Gi(n.copy(),t,e)},Yi(r,n)}function Ki(n){return function(t){return 0>t?-Math.pow(-t,n):Math.pow(t,n)}}function Qi(n,t){function e(e){return i[((u.get(e)||("range"===t.t?u.set(e,n.push(e)):0/0))-1)%i.length]}function r(t,e){return ta.range(n.length).map(function(n){return t+e*n})}var u,i,o;return e.domain=function(r){if(!arguments.length)return n;n=[],u=new l;for(var i,o=-1,a=r.length;++oe?[0/0,0/0]:[e>0?a[e-1]:n[0],et?0/0:t/i+n,[t,t+1/i]},r.copy=function(){return to(n,t,e)},u()}function eo(n,t){function e(e){return e>=e?t[ta.bisect(n,e)]:void 0}return e.domain=function(t){return arguments.length?(n=t,e):n},e.range=function(n){return arguments.length?(t=n,e):t},e.invertExtent=function(e){return e=t.indexOf(e),[n[e-1],n[e]]},e.copy=function(){return eo(n,t)},e}function ro(n){function t(n){return+n}return t.invert=t,t.domain=t.range=function(e){return arguments.length?(n=e.map(t),t):n},t.ticks=function(t){return Xi(n,t)},t.tickFormat=function(t,e){return $i(n,t,e)},t.copy=function(){return ro(n)},t}function uo(){return 0}function io(n){return n.innerRadius}function oo(n){return n.outerRadius}function ao(n){return n.startAngle}function co(n){return n.endAngle}function lo(n){return n&&n.padAngle}function so(n,t,e,r){return(n-e)*t-(t-r)*n>0?0:1}function fo(n,t,e,r,u){var i=n[0]-t[0],o=n[1]-t[1],a=(u?r:-r)/Math.sqrt(i*i+o*o),c=a*o,l=-a*i,s=n[0]+c,f=n[1]+l,h=t[0]+c,g=t[1]+l,p=(s+h)/2,v=(f+g)/2,d=h-s,m=g-f,y=d*d+m*m,M=e-r,x=s*g-h*f,b=(0>m?-1:1)*Math.sqrt(M*M*y-x*x),_=(x*m-d*b)/y,w=(-x*d-m*b)/y,S=(x*m+d*b)/y,k=(-x*d+m*b)/y,E=_-p,A=w-v,N=S-p,C=k-v;return E*E+A*A>N*N+C*C&&(_=S,w=k),[[_-c,w-l],[_*e/M,w*e/M]]}function ho(n){function t(t){function o(){l.push("M",i(n(s),a))}for(var c,l=[],s=[],f=-1,h=t.length,g=Et(e),p=Et(r);++f1&&u.push("H",r[0]),u.join("")}function mo(n){for(var t=0,e=n.length,r=n[0],u=[r[0],",",r[1]];++t1){a=t[1],i=n[c],c++,r+="C"+(u[0]+o[0])+","+(u[1]+o[1])+","+(i[0]-a[0])+","+(i[1]-a[1])+","+i[0]+","+i[1];for(var l=2;l9&&(u=3*t/Math.sqrt(u),o[a]=u*e,o[a+1]=u*r));for(a=-1;++a<=c;)u=(n[Math.min(c,a+1)][0]-n[Math.max(0,a-1)][0])/(6*(1+o[a]*o[a])),i.push([u||0,o[a]*u||0]);return i}function To(n){return n.length<3?go(n):n[0]+_o(n,Lo(n))}function Ro(n){for(var t,e,r,u=-1,i=n.length;++ur)return s();var u=i[i.active];u&&(--i.count,delete i[i.active],u.event&&u.event.interrupt.call(n,n.__data__,u.index)),i.active=r,o.event&&o.event.start.call(n,n.__data__,t),o.tween.forEach(function(e,r){(r=r.call(n,n.__data__,t))&&v.push(r)}),h=o.ease,f=o.duration,ta.timer(function(){return p.c=l(e||1)?Ne:l,1},0,a)}function l(e){if(i.active!==r)return 1;for(var u=e/f,a=h(u),c=v.length;c>0;)v[--c].call(n,a);return u>=1?(o.event&&o.event.end.call(n,n.__data__,t),s()):void 0}function s(){return--i.count?delete i[r]:delete n[e],1}var f,h,g=o.delay,p=ec,v=[];return p.t=g+a,u>=g?c(u-g):void(p.c=c)},0,a)}}function Bo(n,t,e){n.attr("transform",function(n){var r=t(n);return"translate("+(isFinite(r)?r:e(n))+",0)"})}function Wo(n,t,e){n.attr("transform",function(n){var r=t(n);return"translate(0,"+(isFinite(r)?r:e(n))+")"})}function Jo(n){return n.toISOString()}function Go(n,t,e){function r(t){return n(t)}function u(n,e){var r=n[1]-n[0],u=r/e,i=ta.bisect(Vl,u);return i==Vl.length?[t.year,Vi(n.map(function(n){return n/31536e6}),e)[2]]:i?t[u/Vl[i-1]1?{floor:function(t){for(;e(t=n.floor(t));)t=Ko(t-1);return t},ceil:function(t){for(;e(t=n.ceil(t));)t=Ko(+t+1);return t}}:n))},r.ticks=function(n,t){var e=Pi(r.domain()),i=null==n?u(e,10):"number"==typeof n?u(e,n):!n.range&&[{range:n},t];return i&&(n=i[0],t=i[1]),n.range(e[0],Ko(+e[1]+1),1>t?1:t)},r.tickFormat=function(){return e},r.copy=function(){return Go(n.copy(),t,e)},Yi(r,n)}function Ko(n){return new Date(n)}function Qo(n){return JSON.parse(n.responseText)}function na(n){var t=ua.createRange();return t.selectNode(ua.body),t.createContextualFragment(n.responseText)}var ta={version:"3.5.5"},ea=[].slice,ra=function(n){return ea.call(n)},ua=this.document;if(ua)try{ra(ua.documentElement.childNodes)[0].nodeType}catch(ia){ra=function(n){for(var t=n.length,e=new Array(t);t--;)e[t]=n[t];return e}}if(Date.now||(Date.now=function(){return+new Date}),ua)try{ua.createElement("DIV").style.setProperty("opacity",0,"")}catch(oa){var aa=this.Element.prototype,ca=aa.setAttribute,la=aa.setAttributeNS,sa=this.CSSStyleDeclaration.prototype,fa=sa.setProperty;aa.setAttribute=function(n,t){ca.call(this,n,t+"")},aa.setAttributeNS=function(n,t,e){la.call(this,n,t,e+"")},sa.setProperty=function(n,t,e){fa.call(this,n,t+"",e)}}ta.ascending=e,ta.descending=function(n,t){return n>t?-1:t>n?1:t>=n?0:0/0},ta.min=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=r){e=r;break}for(;++ur&&(e=r)}else{for(;++u=r){e=r;break}for(;++ur&&(e=r)}return e},ta.max=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=r){e=r;break}for(;++ue&&(e=r)}else{for(;++u=r){e=r;break}for(;++ue&&(e=r)}return e},ta.extent=function(n,t){var e,r,u,i=-1,o=n.length;if(1===arguments.length){for(;++i=r){e=u=r;break}for(;++ir&&(e=r),r>u&&(u=r))}else{for(;++i=r){e=u=r;break}for(;++ir&&(e=r),r>u&&(u=r))}return[e,u]},ta.sum=function(n,t){var e,r=0,i=n.length,o=-1;if(1===arguments.length)for(;++o1?c/(s-1):void 0},ta.deviation=function(){var n=ta.variance.apply(this,arguments);return n?Math.sqrt(n):n};var ha=i(e);ta.bisectLeft=ha.left,ta.bisect=ta.bisectRight=ha.right,ta.bisector=function(n){return i(1===n.length?function(t,r){return e(n(t),r)}:n)},ta.shuffle=function(n,t,e){(i=arguments.length)<3&&(e=n.length,2>i&&(t=0));for(var r,u,i=e-t;i;)u=Math.random()*i--|0,r=n[i+t],n[i+t]=n[u+t],n[u+t]=r;return n},ta.permute=function(n,t){for(var e=t.length,r=new Array(e);e--;)r[e]=n[t[e]];return r},ta.pairs=function(n){for(var t,e=0,r=n.length-1,u=n[0],i=new Array(0>r?0:r);r>e;)i[e]=[t=u,u=n[++e]];return i},ta.zip=function(){if(!(r=arguments.length))return[];for(var n=-1,t=ta.min(arguments,o),e=new Array(t);++n=0;)for(r=n[u],t=r.length;--t>=0;)e[--o]=r[t];return e};var ga=Math.abs;ta.range=function(n,t,e){if(arguments.length<3&&(e=1,arguments.length<2&&(t=n,n=0)),(t-n)/e===1/0)throw new Error("infinite range");var r,u=[],i=a(ga(e)),o=-1;if(n*=i,t*=i,e*=i,0>e)for(;(r=n+e*++o)>t;)u.push(r/i);else for(;(r=n+e*++o)=i.length)return r?r.call(u,o):e?o.sort(e):o;for(var c,s,f,h,g=-1,p=o.length,v=i[a++],d=new l;++g=i.length)return n;var r=[],u=o[e++];return n.forEach(function(n,u){r.push({key:n,values:t(u,e)})}),u?r.sort(function(n,t){return u(n.key,t.key)}):r}var e,r,u={},i=[],o=[];return u.map=function(t,e){return n(e,t,0)},u.entries=function(e){return t(n(ta.map,e,0),0)},u.key=function(n){return i.push(n),u},u.sortKeys=function(n){return o[i.length-1]=n,u},u.sortValues=function(n){return e=n,u},u.rollup=function(n){return r=n,u},u},ta.set=function(n){var t=new m;if(n)for(var e=0,r=n.length;r>e;++e)t.add(n[e]);return t},c(m,{has:h,add:function(n){return this._[s(n+="")]=!0,n},remove:g,values:p,size:v,empty:d,forEach:function(n){for(var t in this._)n.call(this,f(t))}}),ta.behavior={},ta.rebind=function(n,t){for(var e,r=1,u=arguments.length;++r=0&&(r=n.slice(e+1),n=n.slice(0,e)),n)return arguments.length<2?this[n].on(r):this[n].on(r,t);if(2===arguments.length){if(null==t)for(n in this)this.hasOwnProperty(n)&&this[n].on(r,null);return this}},ta.event=null,ta.requote=function(n){return n.replace(ma,"\\$&")};var ma=/[\\\^\$\*\+\?\|\[\]\(\)\.\{\}]/g,ya={}.__proto__?function(n,t){n.__proto__=t}:function(n,t){for(var e in t)n[e]=t[e]},Ma=function(n,t){return t.querySelector(n)},xa=function(n,t){return t.querySelectorAll(n)},ba=function(n,t){var e=n.matches||n[x(n,"matchesSelector")];return(ba=function(n,t){return e.call(n,t)})(n,t)};"function"==typeof Sizzle&&(Ma=function(n,t){return Sizzle(n,t)[0]||null},xa=Sizzle,ba=Sizzle.matchesSelector),ta.selection=function(){return ta.select(ua.documentElement)};var _a=ta.selection.prototype=[];_a.select=function(n){var t,e,r,u,i=[];n=N(n);for(var o=-1,a=this.length;++o=0&&(e=n.slice(0,t),n=n.slice(t+1)),wa.hasOwnProperty(e)?{space:wa[e],local:n}:n}},_a.attr=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node();return n=ta.ns.qualify(n),n.local?e.getAttributeNS(n.space,n.local):e.getAttribute(n)}for(t in n)this.each(z(t,n[t]));return this}return this.each(z(n,t))},_a.classed=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node(),r=(n=T(n)).length,u=-1;if(t=e.classList){for(;++uu){if("string"!=typeof n){2>u&&(e="");for(r in n)this.each(P(r,n[r],e));return this}if(2>u){var i=this.node();return t(i).getComputedStyle(i,null).getPropertyValue(n)}r=""}return this.each(P(n,e,r))},_a.property=function(n,t){if(arguments.length<2){if("string"==typeof n)return this.node()[n];for(t in n)this.each(U(t,n[t]));return this}return this.each(U(n,t))},_a.text=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.textContent=null==t?"":t}:null==n?function(){this.textContent=""}:function(){this.textContent=n}):this.node().textContent},_a.html=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.innerHTML=null==t?"":t}:null==n?function(){this.innerHTML=""}:function(){this.innerHTML=n}):this.node().innerHTML},_a.append=function(n){return n=j(n),this.select(function(){return this.appendChild(n.apply(this,arguments))})},_a.insert=function(n,t){return n=j(n),t=N(t),this.select(function(){return this.insertBefore(n.apply(this,arguments),t.apply(this,arguments)||null)})},_a.remove=function(){return this.each(F)},_a.data=function(n,t){function e(n,e){var r,u,i,o=n.length,f=e.length,h=Math.min(o,f),g=new Array(f),p=new Array(f),v=new Array(o);if(t){var d,m=new l,y=new Array(o);for(r=-1;++rr;++r)p[r]=H(e[r]);for(;o>r;++r)v[r]=n[r]}p.update=g,p.parentNode=g.parentNode=v.parentNode=n.parentNode,a.push(p),c.push(g),s.push(v)}var r,u,i=-1,o=this.length;if(!arguments.length){for(n=new Array(o=(r=this[0]).length);++ii;i++){u.push(t=[]),t.parentNode=(e=this[i]).parentNode;for(var a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return A(u)},_a.order=function(){for(var n=-1,t=this.length;++n=0;)(e=r[u])&&(i&&i!==e.nextSibling&&i.parentNode.insertBefore(e,i),i=e);return this},_a.sort=function(n){n=I.apply(this,arguments);for(var t=-1,e=this.length;++tn;n++)for(var e=this[n],r=0,u=e.length;u>r;r++){var i=e[r];if(i)return i}return null},_a.size=function(){var n=0;return Y(this,function(){++n}),n};var Sa=[];ta.selection.enter=Z,ta.selection.enter.prototype=Sa,Sa.append=_a.append,Sa.empty=_a.empty,Sa.node=_a.node,Sa.call=_a.call,Sa.size=_a.size,Sa.select=function(n){for(var t,e,r,u,i,o=[],a=-1,c=this.length;++ar){if("string"!=typeof n){2>r&&(t=!1);for(e in n)this.each(X(e,n[e],t));return this}if(2>r)return(r=this.node()["__on"+n])&&r._;e=!1}return this.each(X(n,t,e))};var ka=ta.map({mouseenter:"mouseover",mouseleave:"mouseout"});ua&&ka.forEach(function(n){"on"+n in ua&&ka.remove(n)});var Ea,Aa=0;ta.mouse=function(n){return J(n,k())};var Na=this.navigator&&/WebKit/.test(this.navigator.userAgent)?-1:0;ta.touch=function(n,t,e){if(arguments.length<3&&(e=t,t=k().changedTouches),t)for(var r,u=0,i=t.length;i>u;++u)if((r=t[u]).identifier===e)return J(n,r)},ta.behavior.drag=function(){function n(){this.on("mousedown.drag",i).on("touchstart.drag",o)}function e(n,t,e,i,o){return function(){function a(){var n,e,r=t(h,v);r&&(n=r[0]-M[0],e=r[1]-M[1],p|=n|e,M=r,g({type:"drag",x:r[0]+l[0],y:r[1]+l[1],dx:n,dy:e}))}function c(){t(h,v)&&(m.on(i+d,null).on(o+d,null),y(p&&ta.event.target===f),g({type:"dragend"}))}var l,s=this,f=ta.event.target,h=s.parentNode,g=r.of(s,arguments),p=0,v=n(),d=".drag"+(null==v?"":"-"+v),m=ta.select(e(f)).on(i+d,a).on(o+d,c),y=W(f),M=t(h,v);u?(l=u.apply(s,arguments),l=[l.x-M[0],l.y-M[1]]):l=[0,0],g({type:"dragstart"})}}var r=E(n,"drag","dragstart","dragend"),u=null,i=e(b,ta.mouse,t,"mousemove","mouseup"),o=e(G,ta.touch,y,"touchmove","touchend");return n.origin=function(t){return arguments.length?(u=t,n):u},ta.rebind(n,r,"on")},ta.touches=function(n,t){return arguments.length<2&&(t=k().touches),t?ra(t).map(function(t){var e=J(n,t);return e.identifier=t.identifier,e}):[]};var Ca=1e-6,za=Ca*Ca,qa=Math.PI,La=2*qa,Ta=La-Ca,Ra=qa/2,Da=qa/180,Pa=180/qa,Ua=Math.SQRT2,ja=2,Fa=4;ta.interpolateZoom=function(n,t){function e(n){var t=n*y;if(m){var e=rt(v),o=i/(ja*h)*(e*ut(Ua*t+v)-et(v));return[r+o*l,u+o*s,i*e/rt(Ua*t+v)]}return[r+n*l,u+n*s,i*Math.exp(Ua*t)]}var r=n[0],u=n[1],i=n[2],o=t[0],a=t[1],c=t[2],l=o-r,s=a-u,f=l*l+s*s,h=Math.sqrt(f),g=(c*c-i*i+Fa*f)/(2*i*ja*h),p=(c*c-i*i-Fa*f)/(2*c*ja*h),v=Math.log(Math.sqrt(g*g+1)-g),d=Math.log(Math.sqrt(p*p+1)-p),m=d-v,y=(m||Math.log(c/i))/Ua;return e.duration=1e3*y,e},ta.behavior.zoom=function(){function n(n){n.on(q,f).on(Oa+".zoom",g).on("dblclick.zoom",p).on(R,h)}function e(n){return[(n[0]-k.x)/k.k,(n[1]-k.y)/k.k]}function r(n){return[n[0]*k.k+k.x,n[1]*k.k+k.y]}function u(n){k.k=Math.max(N[0],Math.min(N[1],n))}function i(n,t){t=r(t),k.x+=n[0]-t[0],k.y+=n[1]-t[1]}function o(t,e,r,o){t.__chart__={x:k.x,y:k.y,k:k.k},u(Math.pow(2,o)),i(d=e,r),t=ta.select(t),C>0&&(t=t.transition().duration(C)),t.call(n.event)}function a(){b&&b.domain(x.range().map(function(n){return(n-k.x)/k.k}).map(x.invert)),w&&w.domain(_.range().map(function(n){return(n-k.y)/k.k}).map(_.invert))}function c(n){z++||n({type:"zoomstart"})}function l(n){a(),n({type:"zoom",scale:k.k,translate:[k.x,k.y]})}function s(n){--z||n({type:"zoomend"}),d=null}function f(){function n(){f=1,i(ta.mouse(u),g),l(a)}function r(){h.on(L,null).on(T,null),p(f&&ta.event.target===o),s(a)}var u=this,o=ta.event.target,a=D.of(u,arguments),f=0,h=ta.select(t(u)).on(L,n).on(T,r),g=e(ta.mouse(u)),p=W(u);Dl.call(u),c(a)}function h(){function n(){var n=ta.touches(p);return g=k.k,n.forEach(function(n){n.identifier in d&&(d[n.identifier]=e(n))}),n}function t(){var t=ta.event.target;ta.select(t).on(x,r).on(b,a),_.push(t);for(var e=ta.event.changedTouches,u=0,i=e.length;i>u;++u)d[e[u].identifier]=null;var c=n(),l=Date.now();if(1===c.length){if(500>l-M){var s=c[0];o(p,s,d[s.identifier],Math.floor(Math.log(k.k)/Math.LN2)+1),S()}M=l}else if(c.length>1){var s=c[0],f=c[1],h=s[0]-f[0],g=s[1]-f[1];m=h*h+g*g}}function r(){var n,t,e,r,o=ta.touches(p);Dl.call(p);for(var a=0,c=o.length;c>a;++a,r=null)if(e=o[a],r=d[e.identifier]){if(t)break;n=e,t=r}if(r){var s=(s=e[0]-n[0])*s+(s=e[1]-n[1])*s,f=m&&Math.sqrt(s/m);n=[(n[0]+e[0])/2,(n[1]+e[1])/2],t=[(t[0]+r[0])/2,(t[1]+r[1])/2],u(f*g)}M=null,i(n,t),l(v)}function a(){if(ta.event.touches.length){for(var t=ta.event.changedTouches,e=0,r=t.length;r>e;++e)delete d[t[e].identifier];for(var u in d)return void n()}ta.selectAll(_).on(y,null),w.on(q,f).on(R,h),E(),s(v)}var g,p=this,v=D.of(p,arguments),d={},m=0,y=".zoom-"+ta.event.changedTouches[0].identifier,x="touchmove"+y,b="touchend"+y,_=[],w=ta.select(p),E=W(p);t(),c(v),w.on(q,null).on(R,t)}function g(){var n=D.of(this,arguments);y?clearTimeout(y):(v=e(d=m||ta.mouse(this)),Dl.call(this),c(n)),y=setTimeout(function(){y=null,s(n)},50),S(),u(Math.pow(2,.002*Ha())*k.k),i(d,v),l(n)}function p(){var n=ta.mouse(this),t=Math.log(k.k)/Math.LN2;o(this,n,e(n),ta.event.shiftKey?Math.ceil(t)-1:Math.floor(t)+1)}var v,d,m,y,M,x,b,_,w,k={x:0,y:0,k:1},A=[960,500],N=Ia,C=250,z=0,q="mousedown.zoom",L="mousemove.zoom",T="mouseup.zoom",R="touchstart.zoom",D=E(n,"zoomstart","zoom","zoomend");return Oa||(Oa="onwheel"in ua?(Ha=function(){return-ta.event.deltaY*(ta.event.deltaMode?120:1)},"wheel"):"onmousewheel"in ua?(Ha=function(){return ta.event.wheelDelta},"mousewheel"):(Ha=function(){return-ta.event.detail},"MozMousePixelScroll")),n.event=function(n){n.each(function(){var n=D.of(this,arguments),t=k;Tl?ta.select(this).transition().each("start.zoom",function(){k=this.__chart__||{x:0,y:0,k:1},c(n)}).tween("zoom:zoom",function(){var e=A[0],r=A[1],u=d?d[0]:e/2,i=d?d[1]:r/2,o=ta.interpolateZoom([(u-k.x)/k.k,(i-k.y)/k.k,e/k.k],[(u-t.x)/t.k,(i-t.y)/t.k,e/t.k]);return function(t){var r=o(t),a=e/r[2];this.__chart__=k={x:u-r[0]*a,y:i-r[1]*a,k:a},l(n)}}).each("interrupt.zoom",function(){s(n)}).each("end.zoom",function(){s(n)}):(this.__chart__=k,c(n),l(n),s(n))})},n.translate=function(t){return arguments.length?(k={x:+t[0],y:+t[1],k:k.k},a(),n):[k.x,k.y]},n.scale=function(t){return arguments.length?(k={x:k.x,y:k.y,k:+t},a(),n):k.k},n.scaleExtent=function(t){return arguments.length?(N=null==t?Ia:[+t[0],+t[1]],n):N},n.center=function(t){return arguments.length?(m=t&&[+t[0],+t[1]],n):m},n.size=function(t){return arguments.length?(A=t&&[+t[0],+t[1]],n):A},n.duration=function(t){return arguments.length?(C=+t,n):C},n.x=function(t){return arguments.length?(b=t,x=t.copy(),k={x:0,y:0,k:1},n):b},n.y=function(t){return arguments.length?(w=t,_=t.copy(),k={x:0,y:0,k:1},n):w},ta.rebind(n,D,"on")};var Ha,Oa,Ia=[0,1/0];ta.color=ot,ot.prototype.toString=function(){return this.rgb()+""},ta.hsl=at;var Ya=at.prototype=new ot;Ya.brighter=function(n){return n=Math.pow(.7,arguments.length?n:1),new at(this.h,this.s,this.l/n)},Ya.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),new at(this.h,this.s,n*this.l)},Ya.rgb=function(){return ct(this.h,this.s,this.l)},ta.hcl=lt;var Za=lt.prototype=new ot;Za.brighter=function(n){return new lt(this.h,this.c,Math.min(100,this.l+Va*(arguments.length?n:1)))},Za.darker=function(n){return new lt(this.h,this.c,Math.max(0,this.l-Va*(arguments.length?n:1)))},Za.rgb=function(){return st(this.h,this.c,this.l).rgb()},ta.lab=ft;var Va=18,Xa=.95047,$a=1,Ba=1.08883,Wa=ft.prototype=new ot;Wa.brighter=function(n){return new ft(Math.min(100,this.l+Va*(arguments.length?n:1)),this.a,this.b)},Wa.darker=function(n){return new ft(Math.max(0,this.l-Va*(arguments.length?n:1)),this.a,this.b)},Wa.rgb=function(){return ht(this.l,this.a,this.b)},ta.rgb=mt;var Ja=mt.prototype=new ot;Ja.brighter=function(n){n=Math.pow(.7,arguments.length?n:1);var t=this.r,e=this.g,r=this.b,u=30;return t||e||r?(t&&u>t&&(t=u),e&&u>e&&(e=u),r&&u>r&&(r=u),new mt(Math.min(255,t/n),Math.min(255,e/n),Math.min(255,r/n))):new mt(u,u,u)},Ja.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),new mt(n*this.r,n*this.g,n*this.b)},Ja.hsl=function(){return _t(this.r,this.g,this.b)},Ja.toString=function(){return"#"+xt(this.r)+xt(this.g)+xt(this.b)};var Ga=ta.map({aliceblue:15792383,antiquewhite:16444375,aqua:65535,aquamarine:8388564,azure:15794175,beige:16119260,bisque:16770244,black:0,blanchedalmond:16772045,blue:255,blueviolet:9055202,brown:10824234,burlywood:14596231,cadetblue:6266528,chartreuse:8388352,chocolate:13789470,coral:16744272,cornflowerblue:6591981,cornsilk:16775388,crimson:14423100,cyan:65535,darkblue:139,darkcyan:35723,darkgoldenrod:12092939,darkgray:11119017,darkgreen:25600,darkgrey:11119017,darkkhaki:12433259,darkmagenta:9109643,darkolivegreen:5597999,darkorange:16747520,darkorchid:10040012,darkred:9109504,darksalmon:15308410,darkseagreen:9419919,darkslateblue:4734347,darkslategray:3100495,darkslategrey:3100495,darkturquoise:52945,darkviolet:9699539,deeppink:16716947,deepskyblue:49151,dimgray:6908265,dimgrey:6908265,dodgerblue:2003199,firebrick:11674146,floralwhite:16775920,forestgreen:2263842,fuchsia:16711935,gainsboro:14474460,ghostwhite:16316671,gold:16766720,goldenrod:14329120,gray:8421504,green:32768,greenyellow:11403055,grey:8421504,honeydew:15794160,hotpink:16738740,indianred:13458524,indigo:4915330,ivory:16777200,khaki:15787660,lavender:15132410,lavenderblush:16773365,lawngreen:8190976,lemonchiffon:16775885,lightblue:11393254,lightcoral:15761536,lightcyan:14745599,lightgoldenrodyellow:16448210,lightgray:13882323,lightgreen:9498256,lightgrey:13882323,lightpink:16758465,lightsalmon:16752762,lightseagreen:2142890,lightskyblue:8900346,lightslategray:7833753,lightslategrey:7833753,lightsteelblue:11584734,lightyellow:16777184,lime:65280,limegreen:3329330,linen:16445670,magenta:16711935,maroon:8388608,mediumaquamarine:6737322,mediumblue:205,mediumorchid:12211667,mediumpurple:9662683,mediumseagreen:3978097,mediumslateblue:8087790,mediumspringgreen:64154,mediumturquoise:4772300,mediumvioletred:13047173,midnightblue:1644912,mintcream:16121850,mistyrose:16770273,moccasin:16770229,navajowhite:16768685,navy:128,oldlace:16643558,olive:8421376,olivedrab:7048739,orange:16753920,orangered:16729344,orchid:14315734,palegoldenrod:15657130,palegreen:10025880,paleturquoise:11529966,palevioletred:14381203,papayawhip:16773077,peachpuff:16767673,peru:13468991,pink:16761035,plum:14524637,powderblue:11591910,purple:8388736,rebeccapurple:6697881,red:16711680,rosybrown:12357519,royalblue:4286945,saddlebrown:9127187,salmon:16416882,sandybrown:16032864,seagreen:3050327,seashell:16774638,sienna:10506797,silver:12632256,skyblue:8900331,slateblue:6970061,slategray:7372944,slategrey:7372944,snow:16775930,springgreen:65407,steelblue:4620980,tan:13808780,teal:32896,thistle:14204888,tomato:16737095,turquoise:4251856,violet:15631086,wheat:16113331,white:16777215,whitesmoke:16119285,yellow:16776960,yellowgreen:10145074});Ga.forEach(function(n,t){Ga.set(n,yt(t))}),ta.functor=Et,ta.xhr=At(y),ta.dsv=function(n,t){function e(n,e,i){arguments.length<3&&(i=e,e=null);var o=Nt(n,t,null==e?r:u(e),i);return o.row=function(n){return arguments.length?o.response(null==(e=n)?r:u(n)):e},o}function r(n){return e.parse(n.responseText)}function u(n){return function(t){return e.parse(t.responseText,n)}}function i(t){return t.map(o).join(n)}function o(n){return a.test(n)?'"'+n.replace(/\"/g,'""')+'"':n}var a=new RegExp('["'+n+"\n]"),c=n.charCodeAt(0);return e.parse=function(n,t){var r;return e.parseRows(n,function(n,e){if(r)return r(n,e-1);var u=new Function("d","return {"+n.map(function(n,t){return JSON.stringify(n)+": d["+t+"]"}).join(",")+"}");r=t?function(n,e){return t(u(n),e)}:u})},e.parseRows=function(n,t){function e(){if(s>=l)return o;if(u)return u=!1,i;var t=s;if(34===n.charCodeAt(t)){for(var e=t;e++s;){var r=n.charCodeAt(s++),a=1;if(10===r)u=!0;else if(13===r)u=!0,10===n.charCodeAt(s)&&(++s,++a);else if(r!==c)continue;return n.slice(t,s-a)}return n.slice(t)}for(var r,u,i={},o={},a=[],l=n.length,s=0,f=0;(r=e())!==o;){for(var h=[];r!==i&&r!==o;)h.push(r),r=e();t&&null==(h=t(h,f++))||a.push(h)}return a},e.format=function(t){if(Array.isArray(t[0]))return e.formatRows(t);var r=new m,u=[];return t.forEach(function(n){for(var t in n)r.has(t)||u.push(r.add(t))}),[u.map(o).join(n)].concat(t.map(function(t){return u.map(function(n){return o(t[n])}).join(n)})).join("\n")},e.formatRows=function(n){return n.map(i).join("\n")},e},ta.csv=ta.dsv(",","text/csv"),ta.tsv=ta.dsv(" ","text/tab-separated-values");var Ka,Qa,nc,tc,ec,rc=this[x(this,"requestAnimationFrame")]||function(n){setTimeout(n,17)};ta.timer=function(n,t,e){var r=arguments.length;2>r&&(t=0),3>r&&(e=Date.now());var u=e+t,i={c:n,t:u,f:!1,n:null};Qa?Qa.n=i:Ka=i,Qa=i,nc||(tc=clearTimeout(tc),nc=1,rc(qt))},ta.timer.flush=function(){Lt(),Tt()},ta.round=function(n,t){return t?Math.round(n*(t=Math.pow(10,t)))/t:Math.round(n)};var uc=["y","z","a","f","p","n","\xb5","m","","k","M","G","T","P","E","Z","Y"].map(Dt);ta.formatPrefix=function(n,t){var e=0;return n&&(0>n&&(n*=-1),t&&(n=ta.round(n,Rt(n,t))),e=1+Math.floor(1e-12+Math.log(n)/Math.LN10),e=Math.max(-24,Math.min(24,3*Math.floor((e-1)/3)))),uc[8+e/3]};var ic=/(?:([^{])?([<>=^]))?([+\- ])?([$#])?(0)?(\d+)?(,)?(\.-?\d+)?([a-z%])?/i,oc=ta.map({b:function(n){return n.toString(2)},c:function(n){return String.fromCharCode(n)},o:function(n){return n.toString(8)},x:function(n){return n.toString(16)},X:function(n){return n.toString(16).toUpperCase()},g:function(n,t){return n.toPrecision(t)},e:function(n,t){return n.toExponential(t)},f:function(n,t){return n.toFixed(t)},r:function(n,t){return(n=ta.round(n,Rt(n,t))).toFixed(Math.max(0,Math.min(20,Rt(n*(1+1e-15),t))))}}),ac=ta.time={},cc=Date;jt.prototype={getDate:function(){return this._.getUTCDate()},getDay:function(){return this._.getUTCDay()},getFullYear:function(){return this._.getUTCFullYear()},getHours:function(){return this._.getUTCHours()},getMilliseconds:function(){return this._.getUTCMilliseconds()},getMinutes:function(){return this._.getUTCMinutes()},getMonth:function(){return this._.getUTCMonth()},getSeconds:function(){return this._.getUTCSeconds()},getTime:function(){return this._.getTime()},getTimezoneOffset:function(){return 0},valueOf:function(){return this._.valueOf()},setDate:function(){lc.setUTCDate.apply(this._,arguments)},setDay:function(){lc.setUTCDay.apply(this._,arguments)},setFullYear:function(){lc.setUTCFullYear.apply(this._,arguments)},setHours:function(){lc.setUTCHours.apply(this._,arguments)},setMilliseconds:function(){lc.setUTCMilliseconds.apply(this._,arguments)},setMinutes:function(){lc.setUTCMinutes.apply(this._,arguments)},setMonth:function(){lc.setUTCMonth.apply(this._,arguments)},setSeconds:function(){lc.setUTCSeconds.apply(this._,arguments)},setTime:function(){lc.setTime.apply(this._,arguments)}};var lc=Date.prototype;ac.year=Ft(function(n){return n=ac.day(n),n.setMonth(0,1),n},function(n,t){n.setFullYear(n.getFullYear()+t)},function(n){return n.getFullYear()}),ac.years=ac.year.range,ac.years.utc=ac.year.utc.range,ac.day=Ft(function(n){var t=new cc(2e3,0);return t.setFullYear(n.getFullYear(),n.getMonth(),n.getDate()),t},function(n,t){n.setDate(n.getDate()+t)},function(n){return n.getDate()-1}),ac.days=ac.day.range,ac.days.utc=ac.day.utc.range,ac.dayOfYear=function(n){var t=ac.year(n);return Math.floor((n-t-6e4*(n.getTimezoneOffset()-t.getTimezoneOffset()))/864e5)},["sunday","monday","tuesday","wednesday","thursday","friday","saturday"].forEach(function(n,t){t=7-t;var e=ac[n]=Ft(function(n){return(n=ac.day(n)).setDate(n.getDate()-(n.getDay()+t)%7),n},function(n,t){n.setDate(n.getDate()+7*Math.floor(t))},function(n){var e=ac.year(n).getDay();return Math.floor((ac.dayOfYear(n)+(e+t)%7)/7)-(e!==t)});ac[n+"s"]=e.range,ac[n+"s"].utc=e.utc.range,ac[n+"OfYear"]=function(n){var e=ac.year(n).getDay();return Math.floor((ac.dayOfYear(n)+(e+t)%7)/7)}}),ac.week=ac.sunday,ac.weeks=ac.sunday.range,ac.weeks.utc=ac.sunday.utc.range,ac.weekOfYear=ac.sundayOfYear;var sc={"-":"",_:" ",0:"0"},fc=/^\s*\d+/,hc=/^%/;ta.locale=function(n){return{numberFormat:Pt(n),timeFormat:Ot(n)}};var gc=ta.locale({decimal:".",thousands:",",grouping:[3],currency:["$",""],dateTime:"%a %b %e %X %Y",date:"%m/%d/%Y",time:"%H:%M:%S",periods:["AM","PM"],days:["Sunday","Monday","Tuesday","Wednesday","Thursday","Friday","Saturday"],shortDays:["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],months:["January","February","March","April","May","June","July","August","September","October","November","December"],shortMonths:["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]});ta.format=gc.numberFormat,ta.geo={},ce.prototype={s:0,t:0,add:function(n){le(n,this.t,pc),le(pc.s,this.s,this),this.s?this.t+=pc.t:this.s=pc.t +},reset:function(){this.s=this.t=0},valueOf:function(){return this.s}};var pc=new ce;ta.geo.stream=function(n,t){n&&vc.hasOwnProperty(n.type)?vc[n.type](n,t):se(n,t)};var vc={Feature:function(n,t){se(n.geometry,t)},FeatureCollection:function(n,t){for(var e=n.features,r=-1,u=e.length;++rn?4*qa+n:n,Mc.lineStart=Mc.lineEnd=Mc.point=b}};ta.geo.bounds=function(){function n(n,t){M.push(x=[s=n,h=n]),f>t&&(f=t),t>g&&(g=t)}function t(t,e){var r=pe([t*Da,e*Da]);if(m){var u=de(m,r),i=[u[1],-u[0],0],o=de(i,u);Me(o),o=xe(o);var c=t-p,l=c>0?1:-1,v=o[0]*Pa*l,d=ga(c)>180;if(d^(v>l*p&&l*t>v)){var y=o[1]*Pa;y>g&&(g=y)}else if(v=(v+360)%360-180,d^(v>l*p&&l*t>v)){var y=-o[1]*Pa;f>y&&(f=y)}else f>e&&(f=e),e>g&&(g=e);d?p>t?a(s,t)>a(s,h)&&(h=t):a(t,h)>a(s,h)&&(s=t):h>=s?(s>t&&(s=t),t>h&&(h=t)):t>p?a(s,t)>a(s,h)&&(h=t):a(t,h)>a(s,h)&&(s=t)}else n(t,e);m=r,p=t}function e(){b.point=t}function r(){x[0]=s,x[1]=h,b.point=n,m=null}function u(n,e){if(m){var r=n-p;y+=ga(r)>180?r+(r>0?360:-360):r}else v=n,d=e;Mc.point(n,e),t(n,e)}function i(){Mc.lineStart()}function o(){u(v,d),Mc.lineEnd(),ga(y)>Ca&&(s=-(h=180)),x[0]=s,x[1]=h,m=null}function a(n,t){return(t-=n)<0?t+360:t}function c(n,t){return n[0]-t[0]}function l(n,t){return t[0]<=t[1]?t[0]<=n&&n<=t[1]:nyc?(s=-(h=180),f=-(g=90)):y>Ca?g=90:-Ca>y&&(f=-90),x[0]=s,x[1]=h}};return function(n){g=h=-(s=f=1/0),M=[],ta.geo.stream(n,b);var t=M.length;if(t){M.sort(c);for(var e,r=1,u=M[0],i=[u];t>r;++r)e=M[r],l(e[0],u)||l(e[1],u)?(a(u[0],e[1])>a(u[0],u[1])&&(u[1]=e[1]),a(e[0],u[1])>a(u[0],u[1])&&(u[0]=e[0])):i.push(u=e);for(var o,e,p=-1/0,t=i.length-1,r=0,u=i[t];t>=r;u=e,++r)e=i[r],(o=a(u[1],e[0]))>p&&(p=o,s=e[0],h=u[1])}return M=x=null,1/0===s||1/0===f?[[0/0,0/0],[0/0,0/0]]:[[s,f],[h,g]]}}(),ta.geo.centroid=function(n){xc=bc=_c=wc=Sc=kc=Ec=Ac=Nc=Cc=zc=0,ta.geo.stream(n,qc);var t=Nc,e=Cc,r=zc,u=t*t+e*e+r*r;return za>u&&(t=kc,e=Ec,r=Ac,Ca>bc&&(t=_c,e=wc,r=Sc),u=t*t+e*e+r*r,za>u)?[0/0,0/0]:[Math.atan2(e,t)*Pa,tt(r/Math.sqrt(u))*Pa]};var xc,bc,_c,wc,Sc,kc,Ec,Ac,Nc,Cc,zc,qc={sphere:b,point:_e,lineStart:Se,lineEnd:ke,polygonStart:function(){qc.lineStart=Ee},polygonEnd:function(){qc.lineStart=Se}},Lc=Le(Ne,Pe,je,[-qa,-qa/2]),Tc=1e9;ta.geo.clipExtent=function(){var n,t,e,r,u,i,o={stream:function(n){return u&&(u.valid=!1),u=i(n),u.valid=!0,u},extent:function(a){return arguments.length?(i=Ie(n=+a[0][0],t=+a[0][1],e=+a[1][0],r=+a[1][1]),u&&(u.valid=!1,u=null),o):[[n,t],[e,r]]}};return o.extent([[0,0],[960,500]])},(ta.geo.conicEqualArea=function(){return Ye(Ze)}).raw=Ze,ta.geo.albers=function(){return ta.geo.conicEqualArea().rotate([96,0]).center([-.6,38.7]).parallels([29.5,45.5]).scale(1070)},ta.geo.albersUsa=function(){function n(n){var i=n[0],o=n[1];return t=null,e(i,o),t||(r(i,o),t)||u(i,o),t}var t,e,r,u,i=ta.geo.albers(),o=ta.geo.conicEqualArea().rotate([154,0]).center([-2,58.5]).parallels([55,65]),a=ta.geo.conicEqualArea().rotate([157,0]).center([-3,19.9]).parallels([8,18]),c={point:function(n,e){t=[n,e]}};return n.invert=function(n){var t=i.scale(),e=i.translate(),r=(n[0]-e[0])/t,u=(n[1]-e[1])/t;return(u>=.12&&.234>u&&r>=-.425&&-.214>r?o:u>=.166&&.234>u&&r>=-.214&&-.115>r?a:i).invert(n)},n.stream=function(n){var t=i.stream(n),e=o.stream(n),r=a.stream(n);return{point:function(n,u){t.point(n,u),e.point(n,u),r.point(n,u)},sphere:function(){t.sphere(),e.sphere(),r.sphere()},lineStart:function(){t.lineStart(),e.lineStart(),r.lineStart()},lineEnd:function(){t.lineEnd(),e.lineEnd(),r.lineEnd()},polygonStart:function(){t.polygonStart(),e.polygonStart(),r.polygonStart()},polygonEnd:function(){t.polygonEnd(),e.polygonEnd(),r.polygonEnd()}}},n.precision=function(t){return arguments.length?(i.precision(t),o.precision(t),a.precision(t),n):i.precision()},n.scale=function(t){return arguments.length?(i.scale(t),o.scale(.35*t),a.scale(t),n.translate(i.translate())):i.scale()},n.translate=function(t){if(!arguments.length)return i.translate();var l=i.scale(),s=+t[0],f=+t[1];return e=i.translate(t).clipExtent([[s-.455*l,f-.238*l],[s+.455*l,f+.238*l]]).stream(c).point,r=o.translate([s-.307*l,f+.201*l]).clipExtent([[s-.425*l+Ca,f+.12*l+Ca],[s-.214*l-Ca,f+.234*l-Ca]]).stream(c).point,u=a.translate([s-.205*l,f+.212*l]).clipExtent([[s-.214*l+Ca,f+.166*l+Ca],[s-.115*l-Ca,f+.234*l-Ca]]).stream(c).point,n},n.scale(1070)};var Rc,Dc,Pc,Uc,jc,Fc,Hc={point:b,lineStart:b,lineEnd:b,polygonStart:function(){Dc=0,Hc.lineStart=Ve},polygonEnd:function(){Hc.lineStart=Hc.lineEnd=Hc.point=b,Rc+=ga(Dc/2)}},Oc={point:Xe,lineStart:b,lineEnd:b,polygonStart:b,polygonEnd:b},Ic={point:We,lineStart:Je,lineEnd:Ge,polygonStart:function(){Ic.lineStart=Ke},polygonEnd:function(){Ic.point=We,Ic.lineStart=Je,Ic.lineEnd=Ge}};ta.geo.path=function(){function n(n){return n&&("function"==typeof a&&i.pointRadius(+a.apply(this,arguments)),o&&o.valid||(o=u(i)),ta.geo.stream(n,o)),i.result()}function t(){return o=null,n}var e,r,u,i,o,a=4.5;return n.area=function(n){return Rc=0,ta.geo.stream(n,u(Hc)),Rc},n.centroid=function(n){return _c=wc=Sc=kc=Ec=Ac=Nc=Cc=zc=0,ta.geo.stream(n,u(Ic)),zc?[Nc/zc,Cc/zc]:Ac?[kc/Ac,Ec/Ac]:Sc?[_c/Sc,wc/Sc]:[0/0,0/0]},n.bounds=function(n){return jc=Fc=-(Pc=Uc=1/0),ta.geo.stream(n,u(Oc)),[[Pc,Uc],[jc,Fc]]},n.projection=function(n){return arguments.length?(u=(e=n)?n.stream||tr(n):y,t()):e},n.context=function(n){return arguments.length?(i=null==(r=n)?new $e:new Qe(n),"function"!=typeof a&&i.pointRadius(a),t()):r},n.pointRadius=function(t){return arguments.length?(a="function"==typeof t?t:(i.pointRadius(+t),+t),n):a},n.projection(ta.geo.albersUsa()).context(null)},ta.geo.transform=function(n){return{stream:function(t){var e=new er(t);for(var r in n)e[r]=n[r];return e}}},er.prototype={point:function(n,t){this.stream.point(n,t)},sphere:function(){this.stream.sphere()},lineStart:function(){this.stream.lineStart()},lineEnd:function(){this.stream.lineEnd()},polygonStart:function(){this.stream.polygonStart()},polygonEnd:function(){this.stream.polygonEnd()}},ta.geo.projection=ur,ta.geo.projectionMutator=ir,(ta.geo.equirectangular=function(){return ur(ar)}).raw=ar.invert=ar,ta.geo.rotation=function(n){function t(t){return t=n(t[0]*Da,t[1]*Da),t[0]*=Pa,t[1]*=Pa,t}return n=lr(n[0]%360*Da,n[1]*Da,n.length>2?n[2]*Da:0),t.invert=function(t){return t=n.invert(t[0]*Da,t[1]*Da),t[0]*=Pa,t[1]*=Pa,t},t},cr.invert=ar,ta.geo.circle=function(){function n(){var n="function"==typeof r?r.apply(this,arguments):r,t=lr(-n[0]*Da,-n[1]*Da,0).invert,u=[];return e(null,null,1,{point:function(n,e){u.push(n=t(n,e)),n[0]*=Pa,n[1]*=Pa}}),{type:"Polygon",coordinates:[u]}}var t,e,r=[0,0],u=6;return n.origin=function(t){return arguments.length?(r=t,n):r},n.angle=function(r){return arguments.length?(e=gr((t=+r)*Da,u*Da),n):t},n.precision=function(r){return arguments.length?(e=gr(t*Da,(u=+r)*Da),n):u},n.angle(90)},ta.geo.distance=function(n,t){var e,r=(t[0]-n[0])*Da,u=n[1]*Da,i=t[1]*Da,o=Math.sin(r),a=Math.cos(r),c=Math.sin(u),l=Math.cos(u),s=Math.sin(i),f=Math.cos(i);return Math.atan2(Math.sqrt((e=f*o)*e+(e=l*s-c*f*a)*e),c*s+l*f*a)},ta.geo.graticule=function(){function n(){return{type:"MultiLineString",coordinates:t()}}function t(){return ta.range(Math.ceil(i/d)*d,u,d).map(h).concat(ta.range(Math.ceil(l/m)*m,c,m).map(g)).concat(ta.range(Math.ceil(r/p)*p,e,p).filter(function(n){return ga(n%d)>Ca}).map(s)).concat(ta.range(Math.ceil(a/v)*v,o,v).filter(function(n){return ga(n%m)>Ca}).map(f))}var e,r,u,i,o,a,c,l,s,f,h,g,p=10,v=p,d=90,m=360,y=2.5;return n.lines=function(){return t().map(function(n){return{type:"LineString",coordinates:n}})},n.outline=function(){return{type:"Polygon",coordinates:[h(i).concat(g(c).slice(1),h(u).reverse().slice(1),g(l).reverse().slice(1))]}},n.extent=function(t){return arguments.length?n.majorExtent(t).minorExtent(t):n.minorExtent()},n.majorExtent=function(t){return arguments.length?(i=+t[0][0],u=+t[1][0],l=+t[0][1],c=+t[1][1],i>u&&(t=i,i=u,u=t),l>c&&(t=l,l=c,c=t),n.precision(y)):[[i,l],[u,c]]},n.minorExtent=function(t){return arguments.length?(r=+t[0][0],e=+t[1][0],a=+t[0][1],o=+t[1][1],r>e&&(t=r,r=e,e=t),a>o&&(t=a,a=o,o=t),n.precision(y)):[[r,a],[e,o]]},n.step=function(t){return arguments.length?n.majorStep(t).minorStep(t):n.minorStep()},n.majorStep=function(t){return arguments.length?(d=+t[0],m=+t[1],n):[d,m]},n.minorStep=function(t){return arguments.length?(p=+t[0],v=+t[1],n):[p,v]},n.precision=function(t){return arguments.length?(y=+t,s=vr(a,o,90),f=dr(r,e,y),h=vr(l,c,90),g=dr(i,u,y),n):y},n.majorExtent([[-180,-90+Ca],[180,90-Ca]]).minorExtent([[-180,-80-Ca],[180,80+Ca]])},ta.geo.greatArc=function(){function n(){return{type:"LineString",coordinates:[t||r.apply(this,arguments),e||u.apply(this,arguments)]}}var t,e,r=mr,u=yr;return n.distance=function(){return ta.geo.distance(t||r.apply(this,arguments),e||u.apply(this,arguments))},n.source=function(e){return arguments.length?(r=e,t="function"==typeof e?null:e,n):r},n.target=function(t){return arguments.length?(u=t,e="function"==typeof t?null:t,n):u},n.precision=function(){return arguments.length?n:0},n},ta.geo.interpolate=function(n,t){return Mr(n[0]*Da,n[1]*Da,t[0]*Da,t[1]*Da)},ta.geo.length=function(n){return Yc=0,ta.geo.stream(n,Zc),Yc};var Yc,Zc={sphere:b,point:b,lineStart:xr,lineEnd:b,polygonStart:b,polygonEnd:b},Vc=br(function(n){return Math.sqrt(2/(1+n))},function(n){return 2*Math.asin(n/2)});(ta.geo.azimuthalEqualArea=function(){return ur(Vc)}).raw=Vc;var Xc=br(function(n){var t=Math.acos(n);return t&&t/Math.sin(t)},y);(ta.geo.azimuthalEquidistant=function(){return ur(Xc)}).raw=Xc,(ta.geo.conicConformal=function(){return Ye(_r)}).raw=_r,(ta.geo.conicEquidistant=function(){return Ye(wr)}).raw=wr;var $c=br(function(n){return 1/n},Math.atan);(ta.geo.gnomonic=function(){return ur($c)}).raw=$c,Sr.invert=function(n,t){return[n,2*Math.atan(Math.exp(t))-Ra]},(ta.geo.mercator=function(){return kr(Sr)}).raw=Sr;var Bc=br(function(){return 1},Math.asin);(ta.geo.orthographic=function(){return ur(Bc)}).raw=Bc;var Wc=br(function(n){return 1/(1+n)},function(n){return 2*Math.atan(n)});(ta.geo.stereographic=function(){return ur(Wc)}).raw=Wc,Er.invert=function(n,t){return[-t,2*Math.atan(Math.exp(n))-Ra]},(ta.geo.transverseMercator=function(){var n=kr(Er),t=n.center,e=n.rotate;return n.center=function(n){return n?t([-n[1],n[0]]):(n=t(),[n[1],-n[0]])},n.rotate=function(n){return n?e([n[0],n[1],n.length>2?n[2]+90:90]):(n=e(),[n[0],n[1],n[2]-90])},e([0,0,90])}).raw=Er,ta.geom={},ta.geom.hull=function(n){function t(n){if(n.length<3)return[];var t,u=Et(e),i=Et(r),o=n.length,a=[],c=[];for(t=0;o>t;t++)a.push([+u.call(this,n[t],t),+i.call(this,n[t],t),t]);for(a.sort(zr),t=0;o>t;t++)c.push([a[t][0],-a[t][1]]);var l=Cr(a),s=Cr(c),f=s[0]===l[0],h=s[s.length-1]===l[l.length-1],g=[];for(t=l.length-1;t>=0;--t)g.push(n[a[l[t]][2]]);for(t=+f;t=r&&l.x<=i&&l.y>=u&&l.y<=o?[[r,o],[i,o],[i,u],[r,u]]:[];s.point=n[a]}),t}function e(n){return n.map(function(n,t){return{x:Math.round(i(n,t)/Ca)*Ca,y:Math.round(o(n,t)/Ca)*Ca,i:t}})}var r=Ar,u=Nr,i=r,o=u,a=ul;return n?t(n):(t.links=function(n){return iu(e(n)).edges.filter(function(n){return n.l&&n.r}).map(function(t){return{source:n[t.l.i],target:n[t.r.i]}})},t.triangles=function(n){var t=[];return iu(e(n)).cells.forEach(function(e,r){for(var u,i,o=e.site,a=e.edges.sort(Yr),c=-1,l=a.length,s=a[l-1].edge,f=s.l===o?s.r:s.l;++c=l,h=r>=s,g=h<<1|f;n.leaf=!1,n=n.nodes[g]||(n.nodes[g]=su()),f?u=l:a=l,h?o=s:c=s,i(n,t,e,r,u,o,a,c)}var s,f,h,g,p,v,d,m,y,M=Et(a),x=Et(c);if(null!=t)v=t,d=e,m=r,y=u;else if(m=y=-(v=d=1/0),f=[],h=[],p=n.length,o)for(g=0;p>g;++g)s=n[g],s.xm&&(m=s.x),s.y>y&&(y=s.y),f.push(s.x),h.push(s.y);else for(g=0;p>g;++g){var b=+M(s=n[g],g),_=+x(s,g);v>b&&(v=b),d>_&&(d=_),b>m&&(m=b),_>y&&(y=_),f.push(b),h.push(_)}var w=m-v,S=y-d;w>S?y=d+w:m=v+S;var k=su();if(k.add=function(n){i(k,n,+M(n,++g),+x(n,g),v,d,m,y)},k.visit=function(n){fu(n,k,v,d,m,y)},k.find=function(n){return hu(k,n[0],n[1],v,d,m,y)},g=-1,null==t){for(;++g=0?n.slice(0,t):n,r=t>=0?n.slice(t+1):"in";return e=cl.get(e)||al,r=ll.get(r)||y,Mu(r(e.apply(null,ea.call(arguments,1))))},ta.interpolateHcl=Lu,ta.interpolateHsl=Tu,ta.interpolateLab=Ru,ta.interpolateRound=Du,ta.transform=function(n){var t=ua.createElementNS(ta.ns.prefix.svg,"g");return(ta.transform=function(n){if(null!=n){t.setAttribute("transform",n);var e=t.transform.baseVal.consolidate()}return new Pu(e?e.matrix:sl)})(n)},Pu.prototype.toString=function(){return"translate("+this.translate+")rotate("+this.rotate+")skewX("+this.skew+")scale("+this.scale+")"};var sl={a:1,b:0,c:0,d:1,e:0,f:0};ta.interpolateTransform=Hu,ta.layout={},ta.layout.bundle=function(){return function(n){for(var t=[],e=-1,r=n.length;++ea*a/d){if(p>c){var l=t.charge/c;n.px-=i*l,n.py-=o*l}return!0}if(t.point&&c&&p>c){var l=t.pointCharge/c;n.px-=i*l,n.py-=o*l}}return!t.charge}}function t(n){n.px=ta.event.x,n.py=ta.event.y,a.resume()}var e,r,u,i,o,a={},c=ta.dispatch("start","tick","end"),l=[1,1],s=.9,f=fl,h=hl,g=-30,p=gl,v=.1,d=.64,m=[],M=[];return a.tick=function(){if((r*=.99)<.005)return c.end({type:"end",alpha:r=0}),!0;var t,e,a,f,h,p,d,y,x,b=m.length,_=M.length;for(e=0;_>e;++e)a=M[e],f=a.source,h=a.target,y=h.x-f.x,x=h.y-f.y,(p=y*y+x*x)&&(p=r*i[e]*((p=Math.sqrt(p))-u[e])/p,y*=p,x*=p,h.x-=y*(d=f.weight/(h.weight+f.weight)),h.y-=x*d,f.x+=y*(d=1-d),f.y+=x*d);if((d=r*v)&&(y=l[0]/2,x=l[1]/2,e=-1,d))for(;++e0?n:0:n>0&&(c.start({type:"start",alpha:r=n}),ta.timer(a.tick)),a):r},a.start=function(){function n(n,r){if(!e){for(e=new Array(c),a=0;c>a;++a)e[a]=[];for(a=0;s>a;++a){var u=M[a];e[u.source.index].push(u.target),e[u.target.index].push(u.source)}}for(var i,o=e[t],a=-1,l=o.length;++at;++t)(r=m[t]).index=t,r.weight=0;for(t=0;s>t;++t)r=M[t],"number"==typeof r.source&&(r.source=m[r.source]),"number"==typeof r.target&&(r.target=m[r.target]),++r.source.weight,++r.target.weight;for(t=0;c>t;++t)r=m[t],isNaN(r.x)&&(r.x=n("x",p)),isNaN(r.y)&&(r.y=n("y",v)),isNaN(r.px)&&(r.px=r.x),isNaN(r.py)&&(r.py=r.y);if(u=[],"function"==typeof f)for(t=0;s>t;++t)u[t]=+f.call(this,M[t],t);else for(t=0;s>t;++t)u[t]=f;if(i=[],"function"==typeof h)for(t=0;s>t;++t)i[t]=+h.call(this,M[t],t);else for(t=0;s>t;++t)i[t]=h;if(o=[],"function"==typeof g)for(t=0;c>t;++t)o[t]=+g.call(this,m[t],t);else for(t=0;c>t;++t)o[t]=g;return a.resume()},a.resume=function(){return a.alpha(.1)},a.stop=function(){return a.alpha(0)},a.drag=function(){return e||(e=ta.behavior.drag().origin(y).on("dragstart.force",Xu).on("drag.force",t).on("dragend.force",$u)),arguments.length?void this.on("mouseover.force",Bu).on("mouseout.force",Wu).call(e):e},ta.rebind(a,c,"on")};var fl=20,hl=1,gl=1/0;ta.layout.hierarchy=function(){function n(u){var i,o=[u],a=[];for(u.depth=0;null!=(i=o.pop());)if(a.push(i),(l=e.call(n,i,i.depth))&&(c=l.length)){for(var c,l,s;--c>=0;)o.push(s=l[c]),s.parent=i,s.depth=i.depth+1;r&&(i.value=0),i.children=l}else r&&(i.value=+r.call(n,i,i.depth)||0),delete i.children;return Qu(u,function(n){var e,u;t&&(e=n.children)&&e.sort(t),r&&(u=n.parent)&&(u.value+=n.value)}),a}var t=ei,e=ni,r=ti;return n.sort=function(e){return arguments.length?(t=e,n):t},n.children=function(t){return arguments.length?(e=t,n):e},n.value=function(t){return arguments.length?(r=t,n):r},n.revalue=function(t){return r&&(Ku(t,function(n){n.children&&(n.value=0)}),Qu(t,function(t){var e;t.children||(t.value=+r.call(n,t,t.depth)||0),(e=t.parent)&&(e.value+=t.value)})),t},n},ta.layout.partition=function(){function n(t,e,r,u){var i=t.children;if(t.x=e,t.y=t.depth*u,t.dx=r,t.dy=u,i&&(o=i.length)){var o,a,c,l=-1;for(r=t.value?r/t.value:0;++lf?-1:1),p=(f-c*g)/ta.sum(l),v=ta.range(c),d=[];return null!=e&&v.sort(e===pl?function(n,t){return l[t]-l[n]}:function(n,t){return e(o[n],o[t])}),v.forEach(function(n){d[n]={data:o[n],value:a=l[n],startAngle:s,endAngle:s+=a*p+g,padAngle:h}}),d}var t=Number,e=pl,r=0,u=La,i=0;return n.value=function(e){return arguments.length?(t=e,n):t},n.sort=function(t){return arguments.length?(e=t,n):e},n.startAngle=function(t){return arguments.length?(r=t,n):r},n.endAngle=function(t){return arguments.length?(u=t,n):u},n.padAngle=function(t){return arguments.length?(i=t,n):i},n};var pl={};ta.layout.stack=function(){function n(a,c){if(!(h=a.length))return a;var l=a.map(function(e,r){return t.call(n,e,r)}),s=l.map(function(t){return t.map(function(t,e){return[i.call(n,t,e),o.call(n,t,e)]})}),f=e.call(n,s,c);l=ta.permute(l,f),s=ta.permute(s,f);var h,g,p,v,d=r.call(n,s,c),m=l[0].length;for(p=0;m>p;++p)for(u.call(n,l[0][p],v=d[p],s[0][p][1]),g=1;h>g;++g)u.call(n,l[g][p],v+=s[g-1][p][1],s[g][p][1]);return a}var t=y,e=ai,r=ci,u=oi,i=ui,o=ii;return n.values=function(e){return arguments.length?(t=e,n):t},n.order=function(t){return arguments.length?(e="function"==typeof t?t:vl.get(t)||ai,n):e},n.offset=function(t){return arguments.length?(r="function"==typeof t?t:dl.get(t)||ci,n):r},n.x=function(t){return arguments.length?(i=t,n):i},n.y=function(t){return arguments.length?(o=t,n):o},n.out=function(t){return arguments.length?(u=t,n):u},n};var vl=ta.map({"inside-out":function(n){var t,e,r=n.length,u=n.map(li),i=n.map(si),o=ta.range(r).sort(function(n,t){return u[n]-u[t]}),a=0,c=0,l=[],s=[];for(t=0;r>t;++t)e=o[t],c>a?(a+=i[e],l.push(e)):(c+=i[e],s.push(e));return s.reverse().concat(l)},reverse:function(n){return ta.range(n.length).reverse()},"default":ai}),dl=ta.map({silhouette:function(n){var t,e,r,u=n.length,i=n[0].length,o=[],a=0,c=[];for(e=0;i>e;++e){for(t=0,r=0;u>t;t++)r+=n[t][e][1];r>a&&(a=r),o.push(r)}for(e=0;i>e;++e)c[e]=(a-o[e])/2;return c},wiggle:function(n){var t,e,r,u,i,o,a,c,l,s=n.length,f=n[0],h=f.length,g=[];for(g[0]=c=l=0,e=1;h>e;++e){for(t=0,u=0;s>t;++t)u+=n[t][e][1];for(t=0,i=0,a=f[e][0]-f[e-1][0];s>t;++t){for(r=0,o=(n[t][e][1]-n[t][e-1][1])/(2*a);t>r;++r)o+=(n[r][e][1]-n[r][e-1][1])/a;i+=o*n[t][e][1]}g[e]=c-=u?i/u*a:0,l>c&&(l=c)}for(e=0;h>e;++e)g[e]-=l;return g},expand:function(n){var t,e,r,u=n.length,i=n[0].length,o=1/u,a=[];for(e=0;i>e;++e){for(t=0,r=0;u>t;t++)r+=n[t][e][1];if(r)for(t=0;u>t;t++)n[t][e][1]/=r;else for(t=0;u>t;t++)n[t][e][1]=o}for(e=0;i>e;++e)a[e]=0;return a},zero:ci});ta.layout.histogram=function(){function n(n,i){for(var o,a,c=[],l=n.map(e,this),s=r.call(this,l,i),f=u.call(this,s,l,i),i=-1,h=l.length,g=f.length-1,p=t?1:1/h;++i0)for(i=-1;++i=s[0]&&a<=s[1]&&(o=c[ta.bisect(f,a,1,g)-1],o.y+=p,o.push(n[i]));return c}var t=!0,e=Number,r=pi,u=hi;return n.value=function(t){return arguments.length?(e=t,n):e},n.range=function(t){return arguments.length?(r=Et(t),n):r},n.bins=function(t){return arguments.length?(u="number"==typeof t?function(n){return gi(n,t)}:Et(t),n):u},n.frequency=function(e){return arguments.length?(t=!!e,n):t},n},ta.layout.pack=function(){function n(n,i){var o=e.call(this,n,i),a=o[0],c=u[0],l=u[1],s=null==t?Math.sqrt:"function"==typeof t?t:function(){return t};if(a.x=a.y=0,Qu(a,function(n){n.r=+s(n.value)}),Qu(a,Mi),r){var f=r*(t?1:Math.max(2*a.r/c,2*a.r/l))/2;Qu(a,function(n){n.r+=f}),Qu(a,Mi),Qu(a,function(n){n.r-=f})}return _i(a,c/2,l/2,t?1:1/Math.max(2*a.r/c,2*a.r/l)),o}var t,e=ta.layout.hierarchy().sort(vi),r=0,u=[1,1];return n.size=function(t){return arguments.length?(u=t,n):u},n.radius=function(e){return arguments.length?(t=null==e||"function"==typeof e?e:+e,n):t},n.padding=function(t){return arguments.length?(r=+t,n):r},Gu(n,e)},ta.layout.tree=function(){function n(n,u){var s=o.call(this,n,u),f=s[0],h=t(f);if(Qu(h,e),h.parent.m=-h.z,Ku(h,r),l)Ku(f,i);else{var g=f,p=f,v=f;Ku(f,function(n){n.xp.x&&(p=n),n.depth>v.depth&&(v=n)});var d=a(g,p)/2-g.x,m=c[0]/(p.x+a(p,g)/2+d),y=c[1]/(v.depth||1);Ku(f,function(n){n.x=(n.x+d)*m,n.y=n.depth*y})}return s}function t(n){for(var t,e={A:null,children:[n]},r=[e];null!=(t=r.pop());)for(var u,i=t.children,o=0,a=i.length;a>o;++o)r.push((i[o]=u={_:i[o],parent:t,children:(u=i[o].children)&&u.slice()||[],A:null,a:null,z:0,m:0,c:0,s:0,t:null,i:o}).a=u);return e.children[0]}function e(n){var t=n.children,e=n.parent.children,r=n.i?e[n.i-1]:null;if(t.length){Ni(n);var i=(t[0].z+t[t.length-1].z)/2;r?(n.z=r.z+a(n._,r._),n.m=n.z-i):n.z=i}else r&&(n.z=r.z+a(n._,r._));n.parent.A=u(n,r,n.parent.A||e[0])}function r(n){n._.x=n.z+n.parent.m,n.m+=n.parent.m}function u(n,t,e){if(t){for(var r,u=n,i=n,o=t,c=u.parent.children[0],l=u.m,s=i.m,f=o.m,h=c.m;o=Ei(o),u=ki(u),o&&u;)c=ki(c),i=Ei(i),i.a=n,r=o.z+f-u.z-l+a(o._,u._),r>0&&(Ai(Ci(o,n,e),n,r),l+=r,s+=r),f+=o.m,l+=u.m,h+=c.m,s+=i.m;o&&!Ei(i)&&(i.t=o,i.m+=f-s),u&&!ki(c)&&(c.t=u,c.m+=l-h,e=n)}return e}function i(n){n.x*=c[0],n.y=n.depth*c[1]}var o=ta.layout.hierarchy().sort(null).value(null),a=Si,c=[1,1],l=null;return n.separation=function(t){return arguments.length?(a=t,n):a},n.size=function(t){return arguments.length?(l=null==(c=t)?i:null,n):l?null:c},n.nodeSize=function(t){return arguments.length?(l=null==(c=t)?null:i,n):l?c:null},Gu(n,o)},ta.layout.cluster=function(){function n(n,i){var o,a=t.call(this,n,i),c=a[0],l=0;Qu(c,function(n){var t=n.children;t&&t.length?(n.x=qi(t),n.y=zi(t)):(n.x=o?l+=e(n,o):0,n.y=0,o=n)});var s=Li(c),f=Ti(c),h=s.x-e(s,f)/2,g=f.x+e(f,s)/2;return Qu(c,u?function(n){n.x=(n.x-c.x)*r[0],n.y=(c.y-n.y)*r[1]}:function(n){n.x=(n.x-h)/(g-h)*r[0],n.y=(1-(c.y?n.y/c.y:1))*r[1]}),a}var t=ta.layout.hierarchy().sort(null).value(null),e=Si,r=[1,1],u=!1;return n.separation=function(t){return arguments.length?(e=t,n):e},n.size=function(t){return arguments.length?(u=null==(r=t),n):u?null:r},n.nodeSize=function(t){return arguments.length?(u=null!=(r=t),n):u?r:null},Gu(n,t)},ta.layout.treemap=function(){function n(n,t){for(var e,r,u=-1,i=n.length;++ut?0:t),e.area=isNaN(r)||0>=r?0:r}function t(e){var i=e.children;if(i&&i.length){var o,a,c,l=f(e),s=[],h=i.slice(),p=1/0,v="slice"===g?l.dx:"dice"===g?l.dy:"slice-dice"===g?1&e.depth?l.dy:l.dx:Math.min(l.dx,l.dy);for(n(h,l.dx*l.dy/e.value),s.area=0;(c=h.length)>0;)s.push(o=h[c-1]),s.area+=o.area,"squarify"!==g||(a=r(s,v))<=p?(h.pop(),p=a):(s.area-=s.pop().area,u(s,v,l,!1),v=Math.min(l.dx,l.dy),s.length=s.area=0,p=1/0);s.length&&(u(s,v,l,!0),s.length=s.area=0),i.forEach(t)}}function e(t){var r=t.children;if(r&&r.length){var i,o=f(t),a=r.slice(),c=[];for(n(a,o.dx*o.dy/t.value),c.area=0;i=a.pop();)c.push(i),c.area+=i.area,null!=i.z&&(u(c,i.z?o.dx:o.dy,o,!a.length),c.length=c.area=0);r.forEach(e)}}function r(n,t){for(var e,r=n.area,u=0,i=1/0,o=-1,a=n.length;++oe&&(i=e),e>u&&(u=e));return r*=r,t*=t,r?Math.max(t*u*p/r,r/(t*i*p)):1/0}function u(n,t,e,r){var u,i=-1,o=n.length,a=e.x,l=e.y,s=t?c(n.area/t):0;if(t==e.dx){for((r||s>e.dy)&&(s=e.dy);++ie.dx)&&(s=e.dx);++ie&&(t=1),1>e&&(n=0),function(){var e,r,u;do e=2*Math.random()-1,r=2*Math.random()-1,u=e*e+r*r;while(!u||u>1);return n+t*e*Math.sqrt(-2*Math.log(u)/u)}},logNormal:function(){var n=ta.random.normal.apply(ta,arguments);return function(){return Math.exp(n())}},bates:function(n){var t=ta.random.irwinHall(n);return function(){return t()/n}},irwinHall:function(n){return function(){for(var t=0,e=0;n>e;e++)t+=Math.random();return t}}},ta.scale={};var ml={floor:y,ceil:y};ta.scale.linear=function(){return Ii([0,1],[0,1],mu,!1)};var yl={s:1,g:1,p:1,r:1,e:1};ta.scale.log=function(){return Ji(ta.scale.linear().domain([0,1]),10,!0,[1,10])};var Ml=ta.format(".0e"),xl={floor:function(n){return-Math.ceil(-n)},ceil:function(n){return-Math.floor(-n)}};ta.scale.pow=function(){return Gi(ta.scale.linear(),1,[0,1])},ta.scale.sqrt=function(){return ta.scale.pow().exponent(.5)},ta.scale.ordinal=function(){return Qi([],{t:"range",a:[[]]})},ta.scale.category10=function(){return ta.scale.ordinal().range(bl)},ta.scale.category20=function(){return ta.scale.ordinal().range(_l)},ta.scale.category20b=function(){return ta.scale.ordinal().range(wl)},ta.scale.category20c=function(){return ta.scale.ordinal().range(Sl)};var bl=[2062260,16744206,2924588,14034728,9725885,9197131,14907330,8355711,12369186,1556175].map(Mt),_l=[2062260,11454440,16744206,16759672,2924588,10018698,14034728,16750742,9725885,12955861,9197131,12885140,14907330,16234194,8355711,13092807,12369186,14408589,1556175,10410725].map(Mt),wl=[3750777,5395619,7040719,10264286,6519097,9216594,11915115,13556636,9202993,12426809,15186514,15190932,8666169,11356490,14049643,15177372,8077683,10834324,13528509,14589654].map(Mt),Sl=[3244733,7057110,10406625,13032431,15095053,16616764,16625259,16634018,3253076,7652470,10607003,13101504,7695281,10394312,12369372,14342891,6513507,9868950,12434877,14277081].map(Mt);ta.scale.quantile=function(){return no([],[])},ta.scale.quantize=function(){return to(0,1,[0,1])},ta.scale.threshold=function(){return eo([.5],[0,1])},ta.scale.identity=function(){return ro([0,1])},ta.svg={},ta.svg.arc=function(){function n(){var n=Math.max(0,+e.apply(this,arguments)),l=Math.max(0,+r.apply(this,arguments)),s=o.apply(this,arguments)-Ra,f=a.apply(this,arguments)-Ra,h=Math.abs(f-s),g=s>f?0:1;if(n>l&&(p=l,l=n,n=p),h>=Ta)return t(l,g)+(n?t(n,1-g):"")+"Z";var p,v,d,m,y,M,x,b,_,w,S,k,E=0,A=0,N=[];if((m=(+c.apply(this,arguments)||0)/2)&&(d=i===kl?Math.sqrt(n*n+l*l):+i.apply(this,arguments),g||(A*=-1),l&&(A=tt(d/l*Math.sin(m))),n&&(E=tt(d/n*Math.sin(m)))),l){y=l*Math.cos(s+A),M=l*Math.sin(s+A),x=l*Math.cos(f-A),b=l*Math.sin(f-A);var C=Math.abs(f-s-2*A)<=qa?0:1;if(A&&so(y,M,x,b)===g^C){var z=(s+f)/2;y=l*Math.cos(z),M=l*Math.sin(z),x=b=null}}else y=M=0;if(n){_=n*Math.cos(f-E),w=n*Math.sin(f-E),S=n*Math.cos(s+E),k=n*Math.sin(s+E);var q=Math.abs(s-f+2*E)<=qa?0:1;if(E&&so(_,w,S,k)===1-g^q){var L=(s+f)/2;_=n*Math.cos(L),w=n*Math.sin(L),S=k=null}}else _=w=0;if((p=Math.min(Math.abs(l-n)/2,+u.apply(this,arguments)))>.001){v=l>n^g?0:1;var T=null==S?[_,w]:null==x?[y,M]:Lr([y,M],[S,k],[x,b],[_,w]),R=y-T[0],D=M-T[1],P=x-T[0],U=b-T[1],j=1/Math.sin(Math.acos((R*P+D*U)/(Math.sqrt(R*R+D*D)*Math.sqrt(P*P+U*U)))/2),F=Math.sqrt(T[0]*T[0]+T[1]*T[1]);if(null!=x){var H=Math.min(p,(l-F)/(j+1)),O=fo(null==S?[_,w]:[S,k],[y,M],l,H,g),I=fo([x,b],[_,w],l,H,g);p===H?N.push("M",O[0],"A",H,",",H," 0 0,",v," ",O[1],"A",l,",",l," 0 ",1-g^so(O[1][0],O[1][1],I[1][0],I[1][1]),",",g," ",I[1],"A",H,",",H," 0 0,",v," ",I[0]):N.push("M",O[0],"A",H,",",H," 0 1,",v," ",I[0])}else N.push("M",y,",",M);if(null!=S){var Y=Math.min(p,(n-F)/(j-1)),Z=fo([y,M],[S,k],n,-Y,g),V=fo([_,w],null==x?[y,M]:[x,b],n,-Y,g);p===Y?N.push("L",V[0],"A",Y,",",Y," 0 0,",v," ",V[1],"A",n,",",n," 0 ",g^so(V[1][0],V[1][1],Z[1][0],Z[1][1]),",",1-g," ",Z[1],"A",Y,",",Y," 0 0,",v," ",Z[0]):N.push("L",V[0],"A",Y,",",Y," 0 0,",v," ",Z[0])}else N.push("L",_,",",w)}else N.push("M",y,",",M),null!=x&&N.push("A",l,",",l," 0 ",C,",",g," ",x,",",b),N.push("L",_,",",w),null!=S&&N.push("A",n,",",n," 0 ",q,",",1-g," ",S,",",k);return N.push("Z"),N.join("")}function t(n,t){return"M0,"+n+"A"+n+","+n+" 0 1,"+t+" 0,"+-n+"A"+n+","+n+" 0 1,"+t+" 0,"+n}var e=io,r=oo,u=uo,i=kl,o=ao,a=co,c=lo;return n.innerRadius=function(t){return arguments.length?(e=Et(t),n):e},n.outerRadius=function(t){return arguments.length?(r=Et(t),n):r},n.cornerRadius=function(t){return arguments.length?(u=Et(t),n):u},n.padRadius=function(t){return arguments.length?(i=t==kl?kl:Et(t),n):i},n.startAngle=function(t){return arguments.length?(o=Et(t),n):o},n.endAngle=function(t){return arguments.length?(a=Et(t),n):a},n.padAngle=function(t){return arguments.length?(c=Et(t),n):c},n.centroid=function(){var n=(+e.apply(this,arguments)+ +r.apply(this,arguments))/2,t=(+o.apply(this,arguments)+ +a.apply(this,arguments))/2-Ra;return[Math.cos(t)*n,Math.sin(t)*n]},n};var kl="auto";ta.svg.line=function(){return ho(y)};var El=ta.map({linear:go,"linear-closed":po,step:vo,"step-before":mo,"step-after":yo,basis:So,"basis-open":ko,"basis-closed":Eo,bundle:Ao,cardinal:bo,"cardinal-open":Mo,"cardinal-closed":xo,monotone:To});El.forEach(function(n,t){t.key=n,t.closed=/-closed$/.test(n)});var Al=[0,2/3,1/3,0],Nl=[0,1/3,2/3,0],Cl=[0,1/6,2/3,1/6];ta.svg.line.radial=function(){var n=ho(Ro);return n.radius=n.x,delete n.x,n.angle=n.y,delete n.y,n},mo.reverse=yo,yo.reverse=mo,ta.svg.area=function(){return Do(y)},ta.svg.area.radial=function(){var n=Do(Ro);return n.radius=n.x,delete n.x,n.innerRadius=n.x0,delete n.x0,n.outerRadius=n.x1,delete n.x1,n.angle=n.y,delete n.y,n.startAngle=n.y0,delete n.y0,n.endAngle=n.y1,delete n.y1,n},ta.svg.chord=function(){function n(n,a){var c=t(this,i,n,a),l=t(this,o,n,a);return"M"+c.p0+r(c.r,c.p1,c.a1-c.a0)+(e(c,l)?u(c.r,c.p1,c.r,c.p0):u(c.r,c.p1,l.r,l.p0)+r(l.r,l.p1,l.a1-l.a0)+u(l.r,l.p1,c.r,c.p0))+"Z"}function t(n,t,e,r){var u=t.call(n,e,r),i=a.call(n,u,r),o=c.call(n,u,r)-Ra,s=l.call(n,u,r)-Ra;return{r:i,a0:o,a1:s,p0:[i*Math.cos(o),i*Math.sin(o)],p1:[i*Math.cos(s),i*Math.sin(s)]}}function e(n,t){return n.a0==t.a0&&n.a1==t.a1}function r(n,t,e){return"A"+n+","+n+" 0 "+ +(e>qa)+",1 "+t}function u(n,t,e,r){return"Q 0,0 "+r}var i=mr,o=yr,a=Po,c=ao,l=co;return n.radius=function(t){return arguments.length?(a=Et(t),n):a},n.source=function(t){return arguments.length?(i=Et(t),n):i},n.target=function(t){return arguments.length?(o=Et(t),n):o},n.startAngle=function(t){return arguments.length?(c=Et(t),n):c},n.endAngle=function(t){return arguments.length?(l=Et(t),n):l},n},ta.svg.diagonal=function(){function n(n,u){var i=t.call(this,n,u),o=e.call(this,n,u),a=(i.y+o.y)/2,c=[i,{x:i.x,y:a},{x:o.x,y:a},o];return c=c.map(r),"M"+c[0]+"C"+c[1]+" "+c[2]+" "+c[3]}var t=mr,e=yr,r=Uo;return n.source=function(e){return arguments.length?(t=Et(e),n):t},n.target=function(t){return arguments.length?(e=Et(t),n):e},n.projection=function(t){return arguments.length?(r=t,n):r},n},ta.svg.diagonal.radial=function(){var n=ta.svg.diagonal(),t=Uo,e=n.projection;return n.projection=function(n){return arguments.length?e(jo(t=n)):t},n},ta.svg.symbol=function(){function n(n,r){return(zl.get(t.call(this,n,r))||Oo)(e.call(this,n,r))}var t=Ho,e=Fo;return n.type=function(e){return arguments.length?(t=Et(e),n):t},n.size=function(t){return arguments.length?(e=Et(t),n):e},n};var zl=ta.map({circle:Oo,cross:function(n){var t=Math.sqrt(n/5)/2;return"M"+-3*t+","+-t+"H"+-t+"V"+-3*t+"H"+t+"V"+-t+"H"+3*t+"V"+t+"H"+t+"V"+3*t+"H"+-t+"V"+t+"H"+-3*t+"Z"},diamond:function(n){var t=Math.sqrt(n/(2*Ll)),e=t*Ll;return"M0,"+-t+"L"+e+",0 0,"+t+" "+-e+",0Z"},square:function(n){var t=Math.sqrt(n)/2;return"M"+-t+","+-t+"L"+t+","+-t+" "+t+","+t+" "+-t+","+t+"Z"},"triangle-down":function(n){var t=Math.sqrt(n/ql),e=t*ql/2;return"M0,"+e+"L"+t+","+-e+" "+-t+","+-e+"Z"},"triangle-up":function(n){var t=Math.sqrt(n/ql),e=t*ql/2;return"M0,"+-e+"L"+t+","+e+" "+-t+","+e+"Z"}});ta.svg.symbolTypes=zl.keys();var ql=Math.sqrt(3),Ll=Math.tan(30*Da);_a.transition=function(n){for(var t,e,r=Tl||++Ul,u=Xo(n),i=[],o=Rl||{time:Date.now(),ease:Su,delay:0,duration:250},a=-1,c=this.length;++ai;i++){u.push(t=[]);for(var e=this[i],a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return Yo(u,this.namespace,this.id)},Pl.tween=function(n,t){var e=this.id,r=this.namespace;return arguments.length<2?this.node()[r][e].tween.get(n):Y(this,null==t?function(t){t[r][e].tween.remove(n)}:function(u){u[r][e].tween.set(n,t)})},Pl.attr=function(n,t){function e(){this.removeAttribute(a)}function r(){this.removeAttributeNS(a.space,a.local)}function u(n){return null==n?e:(n+="",function(){var t,e=this.getAttribute(a);return e!==n&&(t=o(e,n),function(n){this.setAttribute(a,t(n))})})}function i(n){return null==n?r:(n+="",function(){var t,e=this.getAttributeNS(a.space,a.local);return e!==n&&(t=o(e,n),function(n){this.setAttributeNS(a.space,a.local,t(n))})})}if(arguments.length<2){for(t in n)this.attr(t,n[t]);return this}var o="transform"==n?Hu:mu,a=ta.ns.qualify(n);return Zo(this,"attr."+n,t,a.local?i:u)},Pl.attrTween=function(n,t){function e(n,e){var r=t.call(this,n,e,this.getAttribute(u));return r&&function(n){this.setAttribute(u,r(n))}}function r(n,e){var r=t.call(this,n,e,this.getAttributeNS(u.space,u.local));return r&&function(n){this.setAttributeNS(u.space,u.local,r(n))}}var u=ta.ns.qualify(n);return this.tween("attr."+n,u.local?r:e)},Pl.style=function(n,e,r){function u(){this.style.removeProperty(n)}function i(e){return null==e?u:(e+="",function(){var u,i=t(this).getComputedStyle(this,null).getPropertyValue(n);return i!==e&&(u=mu(i,e),function(t){this.style.setProperty(n,u(t),r)})})}var o=arguments.length;if(3>o){if("string"!=typeof n){2>o&&(e="");for(r in n)this.style(r,n[r],e);return this}r=""}return Zo(this,"style."+n,e,i)},Pl.styleTween=function(n,e,r){function u(u,i){var o=e.call(this,u,i,t(this).getComputedStyle(this,null).getPropertyValue(n));return o&&function(t){this.style.setProperty(n,o(t),r)}}return arguments.length<3&&(r=""),this.tween("style."+n,u)},Pl.text=function(n){return Zo(this,"text",n,Vo)},Pl.remove=function(){var n=this.namespace;return this.each("end.transition",function(){var t;this[n].count<2&&(t=this.parentNode)&&t.removeChild(this)})},Pl.ease=function(n){var t=this.id,e=this.namespace;return arguments.length<1?this.node()[e][t].ease:("function"!=typeof n&&(n=ta.ease.apply(ta,arguments)),Y(this,function(r){r[e][t].ease=n}))},Pl.delay=function(n){var t=this.id,e=this.namespace;return arguments.length<1?this.node()[e][t].delay:Y(this,"function"==typeof n?function(r,u,i){r[e][t].delay=+n.call(r,r.__data__,u,i)}:(n=+n,function(r){r[e][t].delay=n}))},Pl.duration=function(n){var t=this.id,e=this.namespace;return arguments.length<1?this.node()[e][t].duration:Y(this,"function"==typeof n?function(r,u,i){r[e][t].duration=Math.max(1,n.call(r,r.__data__,u,i))}:(n=Math.max(1,n),function(r){r[e][t].duration=n}))},Pl.each=function(n,t){var e=this.id,r=this.namespace;if(arguments.length<2){var u=Rl,i=Tl;try{Tl=e,Y(this,function(t,u,i){Rl=t[r][e],n.call(t,t.__data__,u,i)})}finally{Rl=u,Tl=i}}else Y(this,function(u){var i=u[r][e];(i.event||(i.event=ta.dispatch("start","end","interrupt"))).on(n,t)});return this},Pl.transition=function(){for(var n,t,e,r,u=this.id,i=++Ul,o=this.namespace,a=[],c=0,l=this.length;l>c;c++){a.push(n=[]);for(var t=this[c],s=0,f=t.length;f>s;s++)(e=t[s])&&(r=e[o][u],$o(e,s,o,i,{time:r.time,ease:r.ease,delay:r.delay+r.duration,duration:r.duration})),n.push(e)}return Yo(a,o,i)},ta.svg.axis=function(){function n(n){n.each(function(){var n,l=ta.select(this),s=this.__chart__||e,f=this.__chart__=e.copy(),h=null==c?f.ticks?f.ticks.apply(f,a):f.domain():c,g=null==t?f.tickFormat?f.tickFormat.apply(f,a):y:t,p=l.selectAll(".tick").data(h,f),v=p.enter().insert("g",".domain").attr("class","tick").style("opacity",Ca),d=ta.transition(p.exit()).style("opacity",Ca).remove(),m=ta.transition(p.order()).style("opacity",1),M=Math.max(u,0)+o,x=Ui(f),b=l.selectAll(".domain").data([0]),_=(b.enter().append("path").attr("class","domain"),ta.transition(b));v.append("line"),v.append("text");var w,S,k,E,A=v.select("line"),N=m.select("line"),C=p.select("text").text(g),z=v.select("text"),q=m.select("text"),L="top"===r||"left"===r?-1:1;if("bottom"===r||"top"===r?(n=Bo,w="x",k="y",S="x2",E="y2",C.attr("dy",0>L?"0em":".71em").style("text-anchor","middle"),_.attr("d","M"+x[0]+","+L*i+"V0H"+x[1]+"V"+L*i)):(n=Wo,w="y",k="x",S="y2",E="x2",C.attr("dy",".32em").style("text-anchor",0>L?"end":"start"),_.attr("d","M"+L*i+","+x[0]+"H0V"+x[1]+"H"+L*i)),A.attr(E,L*u),z.attr(k,L*M),N.attr(S,0).attr(E,L*u),q.attr(w,0).attr(k,L*M),f.rangeBand){var T=f,R=T.rangeBand()/2;s=f=function(n){return T(n)+R}}else s.rangeBand?s=f:d.call(n,f,s);v.call(n,s,f),m.call(n,f,f)})}var t,e=ta.scale.linear(),r=jl,u=6,i=6,o=3,a=[10],c=null;return n.scale=function(t){return arguments.length?(e=t,n):e},n.orient=function(t){return arguments.length?(r=t in Fl?t+"":jl,n):r},n.ticks=function(){return arguments.length?(a=arguments,n):a},n.tickValues=function(t){return arguments.length?(c=t,n):c},n.tickFormat=function(e){return arguments.length?(t=e,n):t},n.tickSize=function(t){var e=arguments.length;return e?(u=+t,i=+arguments[e-1],n):u},n.innerTickSize=function(t){return arguments.length?(u=+t,n):u},n.outerTickSize=function(t){return arguments.length?(i=+t,n):i},n.tickPadding=function(t){return arguments.length?(o=+t,n):o},n.tickSubdivide=function(){return arguments.length&&n},n};var jl="bottom",Fl={top:1,right:1,bottom:1,left:1};ta.svg.brush=function(){function n(t){t.each(function(){var t=ta.select(this).style("pointer-events","all").style("-webkit-tap-highlight-color","rgba(0,0,0,0)").on("mousedown.brush",i).on("touchstart.brush",i),o=t.selectAll(".background").data([0]);o.enter().append("rect").attr("class","background").style("visibility","hidden").style("cursor","crosshair"),t.selectAll(".extent").data([0]).enter().append("rect").attr("class","extent").style("cursor","move");var a=t.selectAll(".resize").data(v,y);a.exit().remove(),a.enter().append("g").attr("class",function(n){return"resize "+n}).style("cursor",function(n){return Hl[n]}).append("rect").attr("x",function(n){return/[ew]$/.test(n)?-3:null}).attr("y",function(n){return/^[ns]/.test(n)?-3:null}).attr("width",6).attr("height",6).style("visibility","hidden"),a.style("display",n.empty()?"none":null);var c,f=ta.transition(t),h=ta.transition(o);l&&(c=Ui(l),h.attr("x",c[0]).attr("width",c[1]-c[0]),r(f)),s&&(c=Ui(s),h.attr("y",c[0]).attr("height",c[1]-c[0]),u(f)),e(f)})}function e(n){n.selectAll(".resize").attr("transform",function(n){return"translate("+f[+/e$/.test(n)]+","+h[+/^s/.test(n)]+")"})}function r(n){n.select(".extent").attr("x",f[0]),n.selectAll(".extent,.n>rect,.s>rect").attr("width",f[1]-f[0])}function u(n){n.select(".extent").attr("y",h[0]),n.selectAll(".extent,.e>rect,.w>rect").attr("height",h[1]-h[0])}function i(){function i(){32==ta.event.keyCode&&(C||(M=null,q[0]-=f[1],q[1]-=h[1],C=2),S())}function v(){32==ta.event.keyCode&&2==C&&(q[0]+=f[1],q[1]+=h[1],C=0,S())}function d(){var n=ta.mouse(b),t=!1;x&&(n[0]+=x[0],n[1]+=x[1]),C||(ta.event.altKey?(M||(M=[(f[0]+f[1])/2,(h[0]+h[1])/2]),q[0]=f[+(n[0]s?(u=r,r=s):u=s),v[0]!=r||v[1]!=u?(e?a=null:o=null,v[0]=r,v[1]=u,!0):void 0}function y(){d(),k.style("pointer-events","all").selectAll(".resize").style("display",n.empty()?"none":null),ta.select("body").style("cursor",null),L.on("mousemove.brush",null).on("mouseup.brush",null).on("touchmove.brush",null).on("touchend.brush",null).on("keydown.brush",null).on("keyup.brush",null),z(),w({type:"brushend"})}var M,x,b=this,_=ta.select(ta.event.target),w=c.of(b,arguments),k=ta.select(b),E=_.datum(),A=!/^(n|s)$/.test(E)&&l,N=!/^(e|w)$/.test(E)&&s,C=_.classed("extent"),z=W(b),q=ta.mouse(b),L=ta.select(t(b)).on("keydown.brush",i).on("keyup.brush",v);if(ta.event.changedTouches?L.on("touchmove.brush",d).on("touchend.brush",y):L.on("mousemove.brush",d).on("mouseup.brush",y),k.interrupt().selectAll("*").interrupt(),C)q[0]=f[0]-q[0],q[1]=h[0]-q[1];else if(E){var T=+/w$/.test(E),R=+/^n/.test(E);x=[f[1-T]-q[0],h[1-R]-q[1]],q[0]=f[T],q[1]=h[R]}else ta.event.altKey&&(M=q.slice());k.style("pointer-events","none").selectAll(".resize").style("display",null),ta.select("body").style("cursor",_.style("cursor")),w({type:"brushstart"}),d()}var o,a,c=E(n,"brushstart","brush","brushend"),l=null,s=null,f=[0,0],h=[0,0],g=!0,p=!0,v=Ol[0];return n.event=function(n){n.each(function(){var n=c.of(this,arguments),t={x:f,y:h,i:o,j:a},e=this.__chart__||t;this.__chart__=t,Tl?ta.select(this).transition().each("start.brush",function(){o=e.i,a=e.j,f=e.x,h=e.y,n({type:"brushstart"})}).tween("brush:brush",function(){var e=yu(f,t.x),r=yu(h,t.y);return o=a=null,function(u){f=t.x=e(u),h=t.y=r(u),n({type:"brush",mode:"resize"})}}).each("end.brush",function(){o=t.i,a=t.j,n({type:"brush",mode:"resize"}),n({type:"brushend"})}):(n({type:"brushstart"}),n({type:"brush",mode:"resize"}),n({type:"brushend"}))})},n.x=function(t){return arguments.length?(l=t,v=Ol[!l<<1|!s],n):l},n.y=function(t){return arguments.length?(s=t,v=Ol[!l<<1|!s],n):s},n.clamp=function(t){return arguments.length?(l&&s?(g=!!t[0],p=!!t[1]):l?g=!!t:s&&(p=!!t),n):l&&s?[g,p]:l?g:s?p:null},n.extent=function(t){var e,r,u,i,c;return arguments.length?(l&&(e=t[0],r=t[1],s&&(e=e[0],r=r[0]),o=[e,r],l.invert&&(e=l(e),r=l(r)),e>r&&(c=e,e=r,r=c),(e!=f[0]||r!=f[1])&&(f=[e,r])),s&&(u=t[0],i=t[1],l&&(u=u[1],i=i[1]),a=[u,i],s.invert&&(u=s(u),i=s(i)),u>i&&(c=u,u=i,i=c),(u!=h[0]||i!=h[1])&&(h=[u,i])),n):(l&&(o?(e=o[0],r=o[1]):(e=f[0],r=f[1],l.invert&&(e=l.invert(e),r=l.invert(r)),e>r&&(c=e,e=r,r=c))),s&&(a?(u=a[0],i=a[1]):(u=h[0],i=h[1],s.invert&&(u=s.invert(u),i=s.invert(i)),u>i&&(c=u,u=i,i=c))),l&&s?[[e,u],[r,i]]:l?[e,r]:s&&[u,i])},n.clear=function(){return n.empty()||(f=[0,0],h=[0,0],o=a=null),n},n.empty=function(){return!!l&&f[0]==f[1]||!!s&&h[0]==h[1]},ta.rebind(n,c,"on")};var Hl={n:"ns-resize",e:"ew-resize",s:"ns-resize",w:"ew-resize",nw:"nwse-resize",ne:"nesw-resize",se:"nwse-resize",sw:"nesw-resize"},Ol=[["n","e","s","w","nw","ne","se","sw"],["e","w"],["n","s"],[]],Il=ac.format=gc.timeFormat,Yl=Il.utc,Zl=Yl("%Y-%m-%dT%H:%M:%S.%LZ");Il.iso=Date.prototype.toISOString&&+new Date("2000-01-01T00:00:00.000Z")?Jo:Zl,Jo.parse=function(n){var t=new Date(n);return isNaN(t)?null:t},Jo.toString=Zl.toString,ac.second=Ft(function(n){return new cc(1e3*Math.floor(n/1e3))},function(n,t){n.setTime(n.getTime()+1e3*Math.floor(t))},function(n){return n.getSeconds()}),ac.seconds=ac.second.range,ac.seconds.utc=ac.second.utc.range,ac.minute=Ft(function(n){return new cc(6e4*Math.floor(n/6e4))},function(n,t){n.setTime(n.getTime()+6e4*Math.floor(t))},function(n){return n.getMinutes()}),ac.minutes=ac.minute.range,ac.minutes.utc=ac.minute.utc.range,ac.hour=Ft(function(n){var t=n.getTimezoneOffset()/60;return new cc(36e5*(Math.floor(n/36e5-t)+t))},function(n,t){n.setTime(n.getTime()+36e5*Math.floor(t))},function(n){return n.getHours()}),ac.hours=ac.hour.range,ac.hours.utc=ac.hour.utc.range,ac.month=Ft(function(n){return n=ac.day(n),n.setDate(1),n},function(n,t){n.setMonth(n.getMonth()+t)},function(n){return n.getMonth()}),ac.months=ac.month.range,ac.months.utc=ac.month.utc.range;var Vl=[1e3,5e3,15e3,3e4,6e4,3e5,9e5,18e5,36e5,108e5,216e5,432e5,864e5,1728e5,6048e5,2592e6,7776e6,31536e6],Xl=[[ac.second,1],[ac.second,5],[ac.second,15],[ac.second,30],[ac.minute,1],[ac.minute,5],[ac.minute,15],[ac.minute,30],[ac.hour,1],[ac.hour,3],[ac.hour,6],[ac.hour,12],[ac.day,1],[ac.day,2],[ac.week,1],[ac.month,1],[ac.month,3],[ac.year,1]],$l=Il.multi([[".%L",function(n){return n.getMilliseconds()}],[":%S",function(n){return n.getSeconds()}],["%I:%M",function(n){return n.getMinutes()}],["%I %p",function(n){return n.getHours()}],["%a %d",function(n){return n.getDay()&&1!=n.getDate()}],["%b %d",function(n){return 1!=n.getDate()}],["%B",function(n){return n.getMonth()}],["%Y",Ne]]),Bl={range:function(n,t,e){return ta.range(Math.ceil(n/e)*e,+t,e).map(Ko)},floor:y,ceil:y};Xl.year=ac.year,ac.scale=function(){return Go(ta.scale.linear(),Xl,$l)};var Wl=Xl.map(function(n){return[n[0].utc,n[1]]}),Jl=Yl.multi([[".%L",function(n){return n.getUTCMilliseconds()}],[":%S",function(n){return n.getUTCSeconds()}],["%I:%M",function(n){return n.getUTCMinutes()}],["%I %p",function(n){return n.getUTCHours()}],["%a %d",function(n){return n.getUTCDay()&&1!=n.getUTCDate()}],["%b %d",function(n){return 1!=n.getUTCDate()}],["%B",function(n){return n.getUTCMonth()}],["%Y",Ne]]);Wl.year=ac.year.utc,ac.scale.utc=function(){return Go(ta.scale.linear(),Wl,Jl)},ta.text=At(function(n){return n.responseText}),ta.json=function(n,t){return Nt(n,"application/json",Qo,t)},ta.html=function(n,t){return Nt(n,"text/html",na,t)},ta.xml=At(function(n){return n.responseXML}),"function"==typeof define&&define.amd?define(ta):"object"==typeof module&&module.exports&&(module.exports=ta),this.d3=ta}(); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js new file mode 100644 index 000000000000..2d9262b972a5 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -0,0 +1,29 @@ +/* This is a custom version of dagre-d3 on top of v0.4.3. The full list of commits can be found at http://github.com/andrewor14/dagre-d3/ */!function(e){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=e();else if("function"==typeof define&&define.amd)define([],e);else{var f;"undefined"!=typeof window?f=window:"undefined"!=typeof global?f=global:"undefined"!=typeof self&&(f=self),f.dagreD3=e()}}(function(){var define,module,exports;return function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ +parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}; + +}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index1?arguments[1]:{},peg$FAILED={},peg$startRuleFunctions={start:peg$parsestart,graphStmt:peg$parsegraphStmt},peg$startRuleFunction=peg$parsestart,peg$c0=[],peg$c1=peg$FAILED,peg$c2=null,peg$c3="{",peg$c4={type:"literal",value:"{",description:'"{"'},peg$c5="}",peg$c6={type:"literal",value:"}",description:'"}"'},peg$c7=function(strict,type,id,stmts){return{type:type,id:id,strict:strict!==null,stmts:stmts}},peg$c8=";",peg$c9={type:"literal",value:";",description:'";"'},peg$c10=function(first,rest){var result=[first];for(var i=0;i",description:'"->"'},peg$c33=function(rhs,rest){var result=[rhs];if(rest){for(var i=0;ipos){peg$cachedPos=0;peg$cachedPosDetails={line:1,column:1,seenCR:false}}advance(peg$cachedPosDetails,peg$cachedPos,pos);peg$cachedPos=pos}return peg$cachedPosDetails}function peg$fail(expected){if(peg$currPospeg$maxFailPos){peg$maxFailPos=peg$currPos;peg$maxFailExpected=[]}peg$maxFailExpected.push(expected)}function peg$buildException(message,expected,pos){function cleanupExpected(expected){var i=1;expected.sort(function(a,b){if(a.descriptionb.description){return 1}else{return 0}});while(i1?expectedDescs.slice(0,-1).join(", ")+" or "+expectedDescs[expected.length-1]:expectedDescs[0];foundDesc=found?'"'+stringEscape(found)+'"':"end of input";return"Expected "+expectedDesc+" but "+foundDesc+" found."}var posDetails=peg$computePosDetails(pos),found=pospeg$currPos){s5=input.charAt(peg$currPos);peg$currPos++}else{s5=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c110)}}if(s5!==peg$FAILED){s4=[s4,s5];s3=s4}else{peg$currPos=s3;s3=peg$c1}}else{peg$currPos=s3;s3=peg$c1}while(s3!==peg$FAILED){s2.push(s3);s3=peg$currPos;s4=peg$currPos;peg$silentFails++;if(input.substr(peg$currPos,2)===peg$c108){s5=peg$c108;peg$currPos+=2}else{s5=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c109)}}peg$silentFails--;if(s5===peg$FAILED){s4=peg$c30}else{peg$currPos=s4;s4=peg$c1}if(s4!==peg$FAILED){if(input.length>peg$currPos){s5=input.charAt(peg$currPos);peg$currPos++}else{s5=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c110)}}if(s5!==peg$FAILED){s4=[s4,s5];s3=s4}else{peg$currPos=s3;s3=peg$c1}}else{peg$currPos=s3;s3=peg$c1}}if(s2!==peg$FAILED){if(input.substr(peg$currPos,2)===peg$c108){s3=peg$c108;peg$currPos+=2}else{s3=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c109)}}if(s3!==peg$FAILED){s1=[s1,s2,s3];s0=s1}else{peg$currPos=s0;s0=peg$c1}}else{peg$currPos=s0;s0=peg$c1}}else{peg$currPos=s0;s0=peg$c1}}peg$silentFails--;if(s0===peg$FAILED){s1=peg$FAILED;if(peg$silentFails===0){peg$fail(peg$c101)}}return s0}function peg$parse_(){var s0;s0=peg$parsewhitespace();if(s0===peg$FAILED){s0=peg$parsecomment()}return s0}var _=require("lodash");var directed;peg$result=peg$startRuleFunction();if(peg$result!==peg$FAILED&&peg$currPos===input.length){return peg$result}else{if(peg$result!==peg$FAILED&&peg$currPos":"--",writer=new Writer;if(!g.isMultigraph()){writer.write("strict ")}writer.writeLine((g.isDirected()?"digraph":"graph")+" {");writer.indent();var graphAttrs=g.graph();if(_.isObject(graphAttrs)){_.each(graphAttrs,function(v,k){writer.writeLine(id(k)+"="+id(v)+";")})}writeSubgraph(g,undefined,writer);g.edges().forEach(function(edge){writeEdge(g,edge,ec,writer)});writer.unindent();writer.writeLine("}");return writer.toString()}function writeSubgraph(g,v,writer){var children=g.isCompound()?g.children(v):g.nodes();_.each(children,function(w){if(!g.isCompound()||!g.children(w).length){writeNode(g,w,writer)}else{writer.writeLine("subgraph "+id(w)+" {");writer.indent();if(_.isObject(g.node(w))){_.map(g.node(w),function(val,key){writer.writeLine(id(key)+"="+id(val)+";")})}writeSubgraph(g,w,writer);writer.unindent();writer.writeLine("}")}})}function writeNode(g,v,writer){writer.write(id(v));writeAttrs(g.node(v),writer);writer.writeLine()}function writeEdge(g,edge,ec,writer){var v=edge.v,w=edge.w,attrs=g.edge(edge);writer.write(id(v)+" "+ec+" "+id(w));writeAttrs(attrs,writer);writer.writeLine()}function writeAttrs(attrs,writer){if(_.isObject(attrs)){var attrStrs=_.map(attrs,function(val,key){return id(key)+"="+id(val)});if(attrStrs.length){writer.write(" ["+attrStrs.join(",")+"]")}}}function id(obj){if(typeof obj==="number"||obj.toString().match(UNESCAPED_ID_PATTERN)){return obj}return'"'+obj.toString().replace(/"/g,'\\"')+'"'}function Writer(){this._indent="";this._content="";this._shouldIndent=true}Writer.prototype.INDENT=" ";Writer.prototype.indent=function(){this._indent+=this.INDENT};Writer.prototype.unindent=function(){this._indent=this._indent.slice(this.INDENT.length)};Writer.prototype.writeLine=function(line){this.write((line||"")+"\n");this._shouldIndent=true};Writer.prototype.write=function(str){if(this._shouldIndent){this._shouldIndent=false;this._content+=this._indent}this._content+=str};Writer.prototype.toString=function(){return this._content}},{lodash:28}],9:[function(require,module,exports){var _=require("lodash");module.exports=_.clone(require("./lib"));module.exports.json=require("./lib/json");module.exports.alg=require("./lib/alg")},{"./lib":25,"./lib/alg":16,"./lib/json":26,lodash:28}],10:[function(require,module,exports){var _=require("lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{lodash:28}],11:[function(require,module,exports){var _=require("lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{lodash:28}],12:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"./dijkstra":13,lodash:28}],13:[function(require,module,exports){var _=require("lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":23,lodash:28}],14:[function(require,module,exports){var _=require("lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"./tarjan":21,lodash:28}],15:[function(require,module,exports){var _=require("lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":23,"../graph":24,lodash:28}],21:[function(require,module,exports){var _=require("lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{lodash:28}],22:[function(require,module,exports){var _=require("lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{lodash:28}],23:[function(require,module,exports){var _=require("lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(v,w,value,name){var valueSpecified=arguments.length>2;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{lodash:28}],25:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":24,"./version":27}],26:[function(require,module,exports){var _=require("lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":24,lodash:28}],27:[function(require,module,exports){module.exports="0.8.1"},{}],28:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f "+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index5' : ' ▴'; + sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾'; this.appendChild(sortrevind); return; } @@ -113,7 +113,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortrevind')); sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); return; } @@ -134,7 +134,7 @@ sorttable = { this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); // build an array to sort. This is a Schwartzian transform thing, diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css new file mode 100644 index 000000000000..3b4ae2ed354b --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#dag-viz-graph a, #dag-viz-graph a:hover { + text-decoration: none; +} + +#dag-viz-graph .label { + font-weight: normal; + text-shadow: none; +} + +#dag-viz-graph svg path { + stroke: #444; + stroke-width: 1.5px; +} + +#dag-viz-graph svg g.cluster rect { + stroke-width: 1px; +} + +#dag-viz-graph div#empty-dag-viz-message { + margin: 15px; +} + +/* Job page specific styles */ + +#dag-viz-graph svg.job marker#marker-arrow path { + fill: #333; + stroke-width: 0px; +} + +#dag-viz-graph svg.job g.cluster rect { + fill: #A0DFFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.skipped rect { + fill: #D6D6D6; + stroke: #B7B7B7; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.stage rect { + fill: #FFFFFF; + stroke: #FF99AC; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.stage.skipped rect { + stroke: #ADADAD; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g#cross-stage-edges path { + fill: none; +} + +#dag-viz-graph svg.job g.cluster text { + fill: #333; +} + +#dag-viz-graph svg.job g.cluster.skipped text { + fill: #666; +} + +#dag-viz-graph svg.job g.node circle { + fill: #444; +} + +#dag-viz-graph svg.job g.node.cached circle { + fill: #A3F545; + stroke: #52C366; + stroke-width: 2px; +} + +/* Stage page specific styles */ + +#dag-viz-graph svg.stage g.cluster rect { + fill: #A0DFFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + +#dag-viz-graph svg.stage g.cluster.stage rect { + fill: #FFFFFF; + stroke: #FFA6B6; + stroke-width: 1px; +} + +#dag-viz-graph svg.stage g.node g.label text tspan { + fill: #333; +} + +#dag-viz-graph svg.stage g.cluster text { + fill: #333; +} + +#dag-viz-graph svg.stage g.node rect { + fill: #C3EBFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + +#dag-viz-graph svg.stage g.node.cached rect { + fill: #B3F5C5; + stroke: #52C366; + stroke-width: 2px; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js new file mode 100644 index 000000000000..83dbea40b63f --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -0,0 +1,520 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This file contains the logic to render the RDD DAG visualization in the UI. + * + * This DAG describes the relationships between + * (1) an RDD and its dependencies, + * (2) an RDD and its operation scopes, and + * (3) an RDD's operation scopes and the stage / job hierarchy + * + * An operation scope is a general, named code block that instantiates RDDs + * (e.g. filter, textFile, reduceByKey). An operation scope can be nested inside + * of other scopes if the corresponding RDD operation invokes other such operations + * (for more detail, see o.a.s.rdd.RDDOperationScope). + * + * A stage may include one or more operation scopes if the RDD operations are + * streamlined into one stage (e.g. rdd.map(...).filter(...).flatMap(...)). + * On the flip side, an operation scope may also include one or many stages, + * or even jobs if the RDD operation is higher level than Spark's scheduling + * primitives (e.g. take, any SQL query). + * + * In the visualization, an RDD is expressed as a node, and its dependencies + * as directed edges (from parent to child). operation scopes, stages, and + * jobs are expressed as clusters that may contain one or many nodes. These + * clusters may be nested inside of each other in the scenarios described + * above. + * + * The visualization is rendered in an SVG contained in "div#dag-viz-graph", + * and its input data is expected to be populated in "div#dag-viz-metadata" + * by Spark's UI code. This is currently used only on the stage page and on + * the job page. + * + * This requires jQuery, d3, and dagre-d3. Note that we use a custom release + * of dagre-d3 (http://github.com/andrewor14/dagre-d3) for some specific + * functionality. For more detail, please track the changes in that project + * since it was forked (commit 101503833a8ce5fe369547f6addf3e71172ce10b). + */ + +var VizConstants = { + svgMarginX: 16, + svgMarginY: 16, + stageSep: 40, + graphPrefix: "graph_", + nodePrefix: "node_", + clusterPrefix: "cluster_" +}; + +var JobPageVizConstants = { + clusterLabelSize: 12, + stageClusterLabelSize: 14, + rankSep: 40 +}; + +var StagePageVizConstants = { + clusterLabelSize: 14, + stageClusterLabelSize: 14, + rankSep: 40 +}; + +/* + * Return "expand-dag-viz-arrow-job" if forJob is true. + * Otherwise, return "expand-dag-viz-arrow-stage". + */ +function expandDagVizArrowKey(forJob) { + return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage"; +} + +/* + * Show or hide the RDD DAG visualization. + * + * The graph is only rendered the first time this is called. + * This is the narrow interface called from the Scala UI code. + */ +function toggleDagViz(forJob) { + var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true"; + status = !status; + + var arrowSelector = ".expand-dag-viz-arrow"; + $(arrowSelector).toggleClass('arrow-closed'); + $(arrowSelector).toggleClass('arrow-open'); + var shouldShow = $(arrowSelector).hasClass("arrow-open"); + if (shouldShow) { + var shouldRender = graphContainer().select("*").empty(); + if (shouldRender) { + renderDagViz(forJob); + } + graphContainer().style("display", "block"); + } else { + // Save the graph for later so we don't have to render it again + graphContainer().style("display", "none"); + } + + window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status); +} + +$(function (){ + if ($("#stage-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(false), "false"); + toggleDagViz(false); + } else if ($("#job-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(true), "false"); + toggleDagViz(true); + } +}); + +/* + * Render the RDD DAG visualization. + * + * Input DOM hierarchy: + * div#dag-viz-metadata > + * div.stage-metadata > + * div.[dot-file | incoming-edge | outgoing-edge] + * + * Output DOM hierarchy: + * div#dag-viz-graph > + * svg > + * g.cluster_stage_[stageId] + * + * Note that the input metadata is populated by o.a.s.ui.UIUtils.showDagViz. + * Any changes in the input format here must be reflected there. + */ +function renderDagViz(forJob) { + + // If there is not a dot file to render, fail fast and report error + var jobOrStage = forJob ? "job" : "stage"; + if (metadataContainer().empty() || + metadataContainer().selectAll("div").empty()) { + var message = + "No visualization information available for this " + jobOrStage + "!
" + + "If this is an old " + jobOrStage + ", its visualization metadata may have been " + + "cleaned up over time.
You may consider increasing the value of "; + if (forJob) { + message += "spark.ui.retainedJobs and spark.ui.retainedStages."; + } else { + message += "spark.ui.retainedStages"; + } + graphContainer().append("div").attr("id", "empty-dag-viz-message").html(message); + return; + } + + // Render + var svg = graphContainer().append("svg").attr("class", jobOrStage); + if (forJob) { + renderDagVizForJob(svg); + } else { + renderDagVizForStage(svg); + } + + // Find cached RDDs and mark them as such + metadataContainer().selectAll(".cached-rdd").each(function(v) { + var rddId = d3.select(this).text().trim(); + var nodeId = VizConstants.nodePrefix + rddId; + svg.selectAll("g." + nodeId).classed("cached", true); + }); + + resizeSvg(svg); +} + +/* Render the RDD DAG visualization on the stage page. */ +function renderDagVizForStage(svgContainer) { + var metadata = metadataContainer().select(".stage-metadata"); + var dot = metadata.select(".dot-file").text().trim(); + var containerId = VizConstants.graphPrefix + metadata.attr("stage-id"); + var container = svgContainer.append("g").attr("id", containerId); + renderDot(dot, container, false); + + // Round corners on rectangles + svgContainer + .selectAll("rect") + .attr("rx", "5") + .attr("ry", "5"); +} + +/* + * Render the RDD DAG visualization on the job page. + * + * Due to limitations in dagre-d3, each stage is rendered independently so that + * we have more control on how to position them. Unfortunately, this means we + * cannot rely on dagre-d3 to render edges that cross stages and must render + * these manually on our own. + */ +function renderDagVizForJob(svgContainer) { + var crossStageEdges = []; + + // Each div.stage-metadata contains the information needed to generate the graph + // for a stage. This includes the DOT file produced from the appropriate UI listener, + // any incoming and outgoing edges, and any cached RDDs that belong to this stage. + metadataContainer().selectAll(".stage-metadata").each(function(d, i) { + var metadata = d3.select(this); + var dot = metadata.select(".dot-file").text(); + var stageId = metadata.attr("stage-id"); + var containerId = VizConstants.graphPrefix + stageId; + var isSkipped = metadata.attr("skipped") == "true"; + var container; + if (isSkipped) { + container = svgContainer + .append("g") + .attr("id", containerId) + .attr("skipped", "true"); + } else { + // Link each graph to the corresponding stage page (TODO: handle stage attempts) + // Use the link from the stage table so it also works for the history server + var attemptId = 0 + var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) + .select("a.name-link") + .attr("href") + "&expandDagViz=true"; + container = svgContainer + .append("a") + .attr("xlink:href", stageLink) + .append("g") + .attr("id", containerId); + } + + // Now we need to shift the container for this stage so it doesn't overlap with + // existing ones, taking into account the position and width of the last stage's + // container. We do not need to do this for the first stage of this job. + if (i > 0) { + var existingStages = svgContainer.selectAll("g.cluster.stage") + if (!existingStages.empty()) { + var lastStage = d3.select(existingStages[0].pop()); + var lastStageWidth = toFloat(lastStage.select("rect").attr("width")); + var lastStagePosition = getAbsolutePosition(lastStage); + var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep; + container.attr("transform", "translate(" + offset + ", 0)"); + } + } + + // Actually render the stage + renderDot(dot, container, true); + + // Mark elements as skipped if appropriate. Unfortunately we need to mark all + // elements instead of the parent container because of CSS override rules. + if (isSkipped) { + container.selectAll("g").classed("skipped", true); + } + + // Round corners on rectangles + container + .selectAll("rect") + .attr("rx", "4") + .attr("ry", "4"); + + // If there are any incoming edges into this graph, keep track of them to render + // them separately later. Note that we cannot draw them now because we need to + // put these edges in a separate container that is on top of all stage graphs. + metadata.selectAll(".incoming-edge").each(function(v) { + var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4] + crossStageEdges.push(edge); + }); + }); + + addTooltipsForRDDs(svgContainer); + drawCrossStageEdges(crossStageEdges, svgContainer); +} + +/* Render the dot file as an SVG in the given container. */ +function renderDot(dot, container, forJob) { + var escaped_dot = dot + .replace(/</g, "<") + .replace(/>/g, ">") + .replace(/"/g, "\""); + var g = graphlibDot.read(escaped_dot); + var renderer = new dagreD3.render(); + preprocessGraphLayout(g, forJob); + renderer(container, g); + + // Find the stage cluster and mark it for styling and post-processing + container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); +} + +/* -------------------- * + * | Helper functions | * + * -------------------- */ + +// Helper d3 accessors +function graphContainer() { return d3.select("#dag-viz-graph"); } +function metadataContainer() { return d3.select("#dag-viz-metadata"); } + +/* + * Helper function to pre-process the graph layout. + * This step is necessary for certain styles that affect the positioning + * and sizes of graph elements, e.g. padding, font style, shape. + */ +function preprocessGraphLayout(g, forJob) { + var nodes = g.nodes(); + for (var i = 0; i < nodes.length; i++) { + var isCluster = g.children(nodes[i]).length > 0; + if (!isCluster) { + var node = g.node(nodes[i]); + if (forJob) { + // Do not display RDD name on job page + node.shape = "circle"; + node.labelStyle = "font-size: 0px"; + } else { + node.labelStyle = "font-size: 12px"; + } + node.padding = "5"; + } + } + // Curve the edges + var edges = g.edges(); + for (var j = 0; j < edges.length; j++) { + var edge = g.edge(edges[j]); + edge.lineInterpolate = "basis"; + } + // Adjust vertical separation between nodes + if (forJob) { + g.graph().rankSep = JobPageVizConstants.rankSep; + } else { + g.graph().rankSep = StagePageVizConstants.rankSep; + } +} + +/* + * Helper function to size the SVG appropriately such that all elements are displyed. + * This assumes that all outermost elements are clusters (rectangles). + */ +function resizeSvg(svg) { + var allClusters = svg.selectAll("g.cluster rect")[0]; + var startX = -VizConstants.svgMarginX + + toFloat(d3.min(allClusters, function(e) { + return getAbsolutePosition(d3.select(e)).x; + })); + var startY = -VizConstants.svgMarginY + + toFloat(d3.min(allClusters, function(e) { + return getAbsolutePosition(d3.select(e)).y; + })); + var endX = VizConstants.svgMarginX + + toFloat(d3.max(allClusters, function(e) { + var t = d3.select(e); + return getAbsolutePosition(t).x + toFloat(t.attr("width")); + })); + var endY = VizConstants.svgMarginY + + toFloat(d3.max(allClusters, function(e) { + var t = d3.select(e); + return getAbsolutePosition(t).y + toFloat(t.attr("height")); + })); + var width = endX - startX; + var height = endY - startY; + svg.attr("viewBox", startX + " " + startY + " " + width + " " + height) + .attr("width", width) + .attr("height", height); +} + +/* + * (Job page only) Helper function to draw edges that cross stage boundaries. + * We need to do this manually because we render each stage separately in dagre-d3. + */ +function drawCrossStageEdges(edges, svgContainer) { + if (edges.length == 0) { + return; + } + // Draw the paths first + var edgesContainer = svgContainer.append("g").attr("id", "cross-stage-edges"); + for (var i = 0; i < edges.length; i++) { + var fromRDDId = edges[i][0]; + var toRDDId = edges[i][1]; + connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer); + } + // Now draw the arrows by borrowing the arrow marker generated by dagre-d3 + var dagreD3Marker = svgContainer.select("g.edgePaths marker"); + if (!dagreD3Marker.empty()) { + svgContainer + .append(function() { return dagreD3Marker.node().cloneNode(true); }) + .attr("id", "marker-arrow"); + svgContainer.selectAll("g > path").attr("marker-end", "url(#marker-arrow)"); + svgContainer.selectAll("g.edgePaths def").remove(); // We no longer need these + } +} + +/* + * (Job page only) Helper function to compute the absolute + * position of the specified element in our graph. + */ +function getAbsolutePosition(d3selection) { + if (d3selection.empty()) { + throw "Attempted to get absolute position of an empty selection."; + } + var obj = d3selection; + var _x = toFloat(obj.attr("x")) || 0; + var _y = toFloat(obj.attr("y")) || 0; + while (!obj.empty()) { + var transformText = obj.attr("transform"); + if (transformText) { + var translate = d3.transform(transformText).translate; + _x += toFloat(translate[0]); + _y += toFloat(translate[1]); + } + // Climb upwards to find how our parents are translated + obj = d3.select(obj.node().parentNode); + // Stop when we've reached the graph container itself + if (obj.node() == graphContainer().node()) { + break; + } + } + return { x: _x, y: _y }; +} + +/* (Job page only) Helper function to connect two RDDs with a curved edge. */ +function connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer) { + var fromNodeId = VizConstants.nodePrefix + fromRDDId; + var toNodeId = VizConstants.nodePrefix + toRDDId; + var fromPos = getAbsolutePosition(svgContainer.select("g." + fromNodeId)); + var toPos = getAbsolutePosition(svgContainer.select("g." + toNodeId)); + + // On the job page, RDDs are rendered as dots (circles). When rendering the path, + // we need to account for the radii of these circles. Otherwise the arrow heads + // will bleed into the circle itself. + var delta = toFloat(svgContainer + .select("g.node." + toNodeId) + .select("circle") + .attr("r")); + if (fromPos.x < toPos.x) { + fromPos.x += delta; + toPos.x -= delta; + } else if (fromPos.x > toPos.x) { + fromPos.x -= delta; + toPos.x += delta; + } + + var points; + if (fromPos.y == toPos.y) { + // If they are on the same rank, curve the middle part of the edge + // upward a little to avoid interference with things in between + // e.g. _______ + // _____/ \_____ + points = [ + [fromPos.x, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.2, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.3, fromPos.y - 20], + [fromPos.x + (toPos.x - fromPos.x) * 0.7, fromPos.y - 20], + [fromPos.x + (toPos.x - fromPos.x) * 0.8, toPos.y], + [toPos.x, toPos.y] + ]; + } else { + // Otherwise, draw a curved edge that flattens out on both ends + // e.g. _____ + // / + // | + // _____/ + points = [ + [fromPos.x, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.4, fromPos.y], + [fromPos.x + (toPos.x - fromPos.x) * 0.6, toPos.y], + [toPos.x, toPos.y] + ]; + } + + var line = d3.svg.line().interpolate("basis"); + edgesContainer.append("path").datum(points).attr("d", line); +} + +/* (Job page only) Helper function to add tooltips for RDDs. */ +function addTooltipsForRDDs(svgContainer) { + svgContainer.selectAll("g.node").each(function() { + var node = d3.select(this); + var tooltipText = node.attr("name"); + if (tooltipText) { + node.select("circle") + .attr("data-toggle", "tooltip") + .attr("data-placement", "bottom") + .attr("title", tooltipText); + } + // Link tooltips for all nodes that belong to the same RDD + node.on("mouseenter", function() { triggerTooltipForRDD(node, true); }); + node.on("mouseleave", function() { triggerTooltipForRDD(node, false); }); + }); + + $("[data-toggle=tooltip]") + .filter("g.node circle") + .tooltip({ container: "body", trigger: "manual" }); +} + +/* + * (Job page only) Helper function to show or hide tooltips for all nodes + * in the graph that refer to the same RDD the specified node represents. + */ +function triggerTooltipForRDD(d3node, show) { + var classes = d3node.node().classList; + for (var i = 0; i < classes.length; i++) { + var clazz = classes[i]; + var isRDDClass = clazz.indexOf(VizConstants.nodePrefix) == 0; + if (isRDDClass) { + graphContainer().selectAll("g." + clazz).each(function() { + var circle = d3.select(this).select("circle").node(); + var showOrHide = show ? "show" : "hide"; + $(circle).tooltip(showOrHide); + }); + } + } +} + +/* Helper function to convert attributes to numeric values. */ +function toFloat(f) { + if (f) { + return parseFloat(f.toString().replace(/px$/, "")); + } else { + return f; + } +} + diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css new file mode 100644 index 000000000000..0f400461c529 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +div#application-timeline, div#job-timeline { + margin-bottom: 30px; +} + +#application-timeline div.legend-area, +#job-timeline div.legend-area { + margin-top: 5px; +} + +#task-assignment-timeline div.legend-area { + width: 574px; +} + +#task-assignment-timeline .legend-area > svg { + width: 100%; + height: 55px; +} + +#task-assignment-timeline div.item.range { + padding: 0px; + height: 26px; + border-width: 0; +} + +.task-assignment-timeline-content { + width: 100%; +} + +.task-assignment-timeline-duration-bar { + width: 100%; + height: 26px; +} + +rect.scheduler-delay-proportion { + fill: #80B1D3; + stroke: #6B94B0; +} + +rect.deserialization-time-proportion { + fill: #FB8072; + stroke: #D26B5F; +} + +rect.shuffle-read-time-proportion { + fill: #FDB462; + stroke: #D39651; +} + +rect.executor-runtime-proportion { + fill: #B3DE69; + stroke: #95B957; +} + +rect.shuffle-write-time-proportion { + fill: #FFED6F; + stroke: #D5C65C; +} + +rect.serialization-time-proportion { + fill: #BC80BD; + stroke: #9D6B9E; +} + +rect.getting-result-time-proportion { + fill: #8DD3C7; + stroke: #75B0A6; +} + +.vis.timeline { + line-height: 14px; +} + +.vis.timeline div.content { + width: 100%; +} + +.vis.timeline .item.stage { + cursor: pointer; +} + +.vis.timeline .item.stage.succeeded { + background-color: #A0DFFF; + border-color: #3EC0FF; +} + +.vis.timeline .item.stage.succeeded.selected { + background-color: #A0DFFF; + border-color: #3EC0FF; + z-index: auto; +} + +.legend-area rect.completed-stage-legend { + fill: #A0DFFF; + stroke: #3EC0FF; +} + +.vis.timeline .item.stage.failed { + background-color: #FFA1B0; + border-color: #FF4D6D; +} + +.vis.timeline .item.stage.failed.selected { + background-color: #FFA1B0; + border-color: #FF4D6D; + z-index: auto; +} + +.legend-area rect.failed-stage-legend { + fill: #FFA1B0; + stroke: #FF4D6D; +} + +.vis.timeline .item.stage.running { + background-color: #A2FCC0; + border-color: #36F572; +} + +.vis.timeline .item.stage.running.selected { + background-color: #A2FCC0; + border-color: #36F572; + z-index: auto; +} + +.legend-area rect.active-stage-legend { + fill: #A2FCC0; + stroke: #36F572; +} + +.vis.timeline .foreground { + cursor: move; +} + +.vis.timeline .item.job { + cursor: pointer; +} + +.vis.timeline .item.job.succeeded { + background-color: #A0DFFF; + border-color: #3EC0FF; +} + +.vis.timeline .item.job.succeeded.selected { + background-color: #A0DFFF; + border-color: #3EC0FF; + z-index: auto; +} + +.legend-area rect.succeeded-job-legend { + fill: #A0DFFF; + stroke: #3EC0FF; +} + +.vis.timeline .item.job.failed { + background-color: #FFA1B0; + border-color: #FF4D6D; +} + +.vis.timeline .item.job.failed.selected { + background-color: #FFA1B0; + border-color: #FF4D6D; + z-index: auto; +} + +.legend-area rect.failed-job-legend { + fill: #FFA1B0; + stroke: #FF4D6D; +} + +.vis.timeline .item.job.running { + background-color: #A2FCC0; + border-color: #36F572; +} + +.vis.timeline .item.job.running.selected { + background-color: #A2FCC0; + border-color: #36F572; + z-index: auto; +} + +.legend-area rect.running-job-legend { + fill: #A2FCC0; + stroke: #36F572; +} + +.vis.timeline .item.executor.added { + background-color: #A0DFFF; + border-color: #3EC0FF; +} + +.legend-area rect.executor-added-legend { + fill: #A0DFFF; + stroke: #3EC0FF; +} + +.vis.timeline .item.executor.removed { + background-color: #FFA1B0; + border-color: #FF4D6D; +} + +.legend-area rect.executor-removed-legend { + fill: #FFA1B0; + stroke: #FF4D6D; +} + +.vis.timeline .item.executor.selected { + background-color: #A2FCC0; + border-color: #36F572; + z-index: 2; +} + +tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { + background-color: #D6FFE4 !important; +} + +#application-timeline.collapsed { + display: none; +} + +#job-timeline.collapsed { + display: none; +} + +#task-assignment-timeline.collapsed { + display: none; +} + +.control-panel { + margin-bottom: 5px; +} + +.control-panel input[type="checkbox"] { + margin: 0; +} + +span.expand-application-timeline, span.expand-job-timeline, +span.expand-task-assignment-timeline { + cursor: pointer; +} + +.control-panel input + span { + cursor: pointer; +} + +.vis.timeline .item.range .content { + position: unset; +} + +.vis.timeline .item .tooltip-inner { + max-width: unset !important; +} + +.vispanel.center { + font-size: 12px; + line-height: 12px; +} + +.legend-area text { + fill: #4D4D4D; +} + +.additional-metrics ul { + list-style: none; + margin-left: 15px; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js new file mode 100644 index 000000000000..f4453c71df1e --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +function drawApplicationTimeline(groupArray, eventObjArray, startTime) { + var groups = new vis.DataSet(groupArray); + var items = new vis.DataSet(eventObjArray); + var container = $("#application-timeline")[0]; + var options = { + groupOrder: function(a, b) { + return a.value - b.value + }, + editable: false, + showCurrentTime: false, + min: startTime, + zoomable: false + }; + + var applicationTimeline = new vis.Timeline(container); + applicationTimeline.setOptions(options); + applicationTimeline.setGroups(groups); + applicationTimeline.setItems(items); + + setupZoomable("#application-timeline-zoom-lock", applicationTimeline); + setupExecutorEventAction(); + + function setupJobEventAction() { + $(".item.range.job.application-timeline-object").each(function() { + var getSelectorForJobEntry = function(baseElem) { + var jobIdText = $($(baseElem).find(".application-timeline-content")[0]).text(); + var jobId = jobIdText.match("\\(Job (\\d+)\\)")[1]; + return "#job-" + jobId; + }; + + $(this).click(function() { + var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href") + window.location.href = jobPagePath + }); + + $(this).hover( + function() { + $(getSelectorForJobEntry(this)).addClass("corresponding-item-hover"); + $($(this).find("div.application-timeline-content")[0]).tooltip("show"); + }, + function() { + $(getSelectorForJobEntry(this)).removeClass("corresponding-item-hover"); + $($(this).find("div.application-timeline-content")[0]).tooltip("hide"); + } + ); + }); + } + + setupJobEventAction(); + + $("span.expand-application-timeline").click(function() { + var status = window.localStorage.getItem("expand-application-timeline") == "true"; + status = !status; + + $("#application-timeline").toggleClass('collapsed'); + + // Switch the class of the arrow from open to closed. + $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open'); + $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-application-timeline", "" + status); + }); +} + +$(function (){ + if (window.localStorage.getItem("expand-application-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-application-timeline", "false"); + $("span.expand-application-timeline").trigger('click'); + } +}); + +function drawJobTimeline(groupArray, eventObjArray, startTime) { + var groups = new vis.DataSet(groupArray); + var items = new vis.DataSet(eventObjArray); + var container = $('#job-timeline')[0]; + var options = { + groupOrder: function(a, b) { + return a.value - b.value; + }, + editable: false, + showCurrentTime: false, + min: startTime, + zoomable: false, + }; + + var jobTimeline = new vis.Timeline(container); + jobTimeline.setOptions(options); + jobTimeline.setGroups(groups); + jobTimeline.setItems(items); + + setupZoomable("#job-timeline-zoom-lock", jobTimeline); + setupExecutorEventAction(); + + function setupStageEventAction() { + $(".item.range.stage.job-timeline-object").each(function() { + var getSelectorForStageEntry = function(baseElem) { + var stageIdText = $($(baseElem).find(".job-timeline-content")[0]).text(); + var stageIdAndAttempt = stageIdText.match("\\(Stage (\\d+\\.\\d+)\\)")[1].split("."); + return "#stage-" + stageIdAndAttempt[0] + "-" + stageIdAndAttempt[1]; + }; + + $(this).click(function() { + var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href") + window.location.href = stagePagePath + }); + + $(this).hover( + function() { + $(getSelectorForStageEntry(this)).addClass("corresponding-item-hover"); + $($(this).find("div.job-timeline-content")[0]).tooltip("show"); + }, + function() { + $(getSelectorForStageEntry(this)).removeClass("corresponding-item-hover"); + $($(this).find("div.job-timeline-content")[0]).tooltip("hide"); + } + ); + }); + } + + setupStageEventAction(); + + $("span.expand-job-timeline").click(function() { + var status = window.localStorage.getItem("expand-job-timeline") == "true"; + status = !status; + + $("#job-timeline").toggleClass('collapsed'); + + // Switch the class of the arrow from open to closed. + $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open'); + $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-job-timeline", "" + status); + }); +} + +$(function (){ + if (window.localStorage.getItem("expand-job-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-job-timeline", "false"); + $("span.expand-job-timeline").trigger('click'); + } +}); + +function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { + var groups = new vis.DataSet(groupArray); + var items = new vis.DataSet(eventObjArray); + var container = $("#task-assignment-timeline")[0] + var options = { + groupOrder: function(a, b) { + return a.value - b.value + }, + editable: false, + align: 'left', + selectable: false, + showCurrentTime: false, + min: minLaunchTime, + max: maxFinishTime, + zoomable: false + }; + + var taskTimeline = new vis.Timeline(container) + taskTimeline.setOptions(options); + taskTimeline.setGroups(groups); + taskTimeline.setItems(items); + + // If a user zooms while a tooltip is displayed, the user may zoom such that the cursor is no + // longer over the task that the tooltip corresponds to. So, when a user zooms, we should hide + // any currently displayed tooltips. + var currentDisplayedTooltip = null; + $("#task-assignment-timeline").on({ + "mouseenter": function() { + currentDisplayedTooltip = this; + }, + "mouseleave": function() { + currentDisplayedTooltip = null; + } + }, ".task-assignment-timeline-content"); + taskTimeline.on("rangechange", function(prop) { + if (currentDisplayedTooltip !== null) { + $(currentDisplayedTooltip).tooltip("hide"); + } + }); + + setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); + + $("span.expand-task-assignment-timeline").click(function() { + var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true"; + status = !status; + + $("#task-assignment-timeline").toggleClass("collapsed"); + + // Switch the class of the arrow from open to closed. + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + + window.localStorage.setItem("expand-task-assignment-timeline", "" + status); + }); +} + +$(function (){ + if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-task-assignment-timeline", "false"); + $("span.expand-task-assignment-timeline").trigger('click'); + } +}); + +function setupExecutorEventAction() { + $(".item.box.executor").each(function () { + $(this).hover( + function() { + $($(this).find(".executor-event-content")[0]).tooltip("show"); + }, + function() { + $($(this).find(".executor-event-content")[0]).tooltip("hide"); + } + ); + }); +} + +function setupZoomable(id, timeline) { + $(id + ' > input[type="checkbox"]').click(function() { + if (this.checked) { + timeline.setOptions({zoomable: true}); + } else { + timeline.setOptions({zoomable: false}); + } + }); + + $(id + " > span").click(function() { + $(this).parent().find('input:checkbox').trigger('click'); + }); +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/vis.min.css b/core/src/main/resources/org/apache/spark/ui/static/vis.min.css new file mode 100644 index 000000000000..a390c40d6757 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/vis.min.css @@ -0,0 +1 @@ +.vis .overlay{position:absolute;top:0;left:0;width:100%;height:100%;z-index:10}.vis-active{box-shadow:0 0 10px #86d5f8}.vis [class*=span]{min-height:0;width:auto}.vis.timeline.root{position:relative;border:1px solid #bfbfbf;overflow:hidden;padding:0;margin:0;box-sizing:border-box}.vis.timeline .vispanel{position:absolute;padding:0;margin:0;box-sizing:border-box}.vis.timeline .vispanel.bottom,.vis.timeline .vispanel.center,.vis.timeline .vispanel.left,.vis.timeline .vispanel.right,.vis.timeline .vispanel.top{border:1px #bfbfbf}.vis.timeline .vispanel.center,.vis.timeline .vispanel.left,.vis.timeline .vispanel.right{border-top-style:solid;border-bottom-style:solid;overflow:hidden}.vis.timeline .vispanel.bottom,.vis.timeline .vispanel.center,.vis.timeline .vispanel.top{border-left-style:solid;border-right-style:solid}.vis.timeline .background{overflow:hidden}.vis.timeline .vispanel>.content{position:relative}.vis.timeline .vispanel .shadow{position:absolute;width:100%;height:1px;box-shadow:0 0 10px rgba(0,0,0,.8)}.vis.timeline .vispanel .shadow.top{top:-1px;left:0}.vis.timeline .vispanel .shadow.bottom{bottom:-1px;left:0}.vis.timeline .labelset{position:relative;overflow:hidden;box-sizing:border-box}.vis.timeline .labelset .vlabel{position:relative;left:0;top:0;width:100%;color:#4d4d4d;box-sizing:border-box;border-bottom:1px solid #bfbfbf}.vis.timeline .labelset .vlabel:last-child{border-bottom:none}.vis.timeline .labelset .vlabel .inner{display:inline-block;padding:5px}.vis.timeline .labelset .vlabel .inner.hidden{padding:0}.vis.timeline .itemset{position:relative;padding:0;margin:0;box-sizing:border-box}.vis.timeline .itemset .background,.vis.timeline .itemset .foreground{position:absolute;width:100%;height:100%;overflow:visible}.vis.timeline .axis{position:absolute;width:100%;height:0;left:0;z-index:1}.vis.timeline .foreground .group{position:relative;box-sizing:border-box;border-bottom:1px solid #bfbfbf}.vis.timeline .foreground .group:last-child{border-bottom:none}.vis.timeline .item{position:absolute;color:#1A1A1A;border-color:#97B0F8;border-width:1px;background-color:#D5DDF6;display:inline-block;padding:5px}.vis.timeline .item.selected{border-color:#FFC200;background-color:#FFF785;z-index:2}.vis.timeline .editable .item.selected{cursor:move}.vis.timeline .item.point.selected{background-color:#FFF785}.vis.timeline .item.box{text-align:center;border-style:solid;border-radius:2px}.vis.timeline .item.point{background:0 0}.vis.timeline .item.dot{position:absolute;padding:0;border-width:4px;border-style:solid;border-radius:4px}.vis.timeline .item.range{border-style:solid;border-radius:2px;box-sizing:border-box}.vis.timeline .item.background{overflow:hidden;border:none;background-color:rgba(213,221,246,.4);box-sizing:border-box;padding:0;margin:0}.vis.timeline .item.range .content{position:relative;display:inline-block;max-width:100%;overflow:hidden}.vis.timeline .item.background .content{position:absolute;display:inline-block;overflow:hidden;max-width:100%;margin:5px}.vis.timeline .item.line{padding:0;position:absolute;width:0;border-left-width:1px;border-left-style:solid}.vis.timeline .item .content{white-space:nowrap;overflow:hidden}.vis.timeline .item .delete{background:url(img/timeline/delete.png) top center no-repeat;position:absolute;width:24px;height:24px;top:0;right:-24px;cursor:pointer}.vis.timeline .item.range .drag-left{position:absolute;width:24px;height:100%;top:0;left:-4px;cursor:w-resize}.vis.timeline .item.range .drag-right{position:absolute;width:24px;height:100%;top:0;right:-4px;cursor:e-resize}.vis.timeline .timeaxis{position:relative;overflow:hidden}.vis.timeline .timeaxis.foreground{top:0;left:0;width:100%}.vis.timeline .timeaxis.background{position:absolute;top:0;left:0;width:100%;height:100%}.vis.timeline .timeaxis .text{position:absolute;color:#4d4d4d;padding:3px;white-space:nowrap}.vis.timeline .timeaxis .text.measure{position:absolute;padding-left:0;padding-right:0;margin-left:0;margin-right:0;visibility:hidden}.vis.timeline .timeaxis .grid.vertical{position:absolute;border-left:1px solid}.vis.timeline .timeaxis .grid.minor{border-color:#e5e5e5}.vis.timeline .timeaxis .grid.major{border-color:#bfbfbf}.vis.timeline .currenttime{background-color:#FF7F6E;width:2px;z-index:1}.vis.timeline .customtime{background-color:#6E94FF;width:2px;cursor:move;z-index:1}.vis.timeline .vispanel.background.horizontal .grid.horizontal{position:absolute;width:100%;height:0;border-bottom:1px solid}.vis.timeline .vispanel.background.horizontal .grid.minor{border-color:#e5e5e5}.vis.timeline .vispanel.background.horizontal .grid.major{border-color:#bfbfbf}.vis.timeline .dataaxis .yAxis.major{width:100%;position:absolute;color:#4d4d4d;white-space:nowrap}.vis.timeline .dataaxis .yAxis.major.measure{padding:0;margin:0;border:0;visibility:hidden;width:auto}.vis.timeline .dataaxis .yAxis.minor{position:absolute;width:100%;color:#bebebe;white-space:nowrap}.vis.timeline .dataaxis .yAxis.minor.measure{padding:0;margin:0;border:0;visibility:hidden;width:auto}.vis.timeline .dataaxis .yAxis.title{position:absolute;color:#4d4d4d;white-space:nowrap;bottom:20px;text-align:center}.vis.timeline .dataaxis .yAxis.title.measure{padding:0;margin:0;visibility:hidden;width:auto}.vis.timeline .dataaxis .yAxis.title.left{bottom:0;-webkit-transform-origin:left top;-moz-transform-origin:left top;-ms-transform-origin:left top;-o-transform-origin:left top;transform-origin:left bottom;-webkit-transform:rotate(-90deg);-moz-transform:rotate(-90deg);-ms-transform:rotate(-90deg);-o-transform:rotate(-90deg);transform:rotate(-90deg)}.vis.timeline .dataaxis .yAxis.title.right{bottom:0;-webkit-transform-origin:right bottom;-moz-transform-origin:right bottom;-ms-transform-origin:right bottom;-o-transform-origin:right bottom;transform-origin:right bottom;-webkit-transform:rotate(90deg);-moz-transform:rotate(90deg);-ms-transform:rotate(90deg);-o-transform:rotate(90deg);transform:rotate(90deg)}.vis.timeline .legend{background-color:rgba(247,252,255,.65);padding:5px;border-color:#b3b3b3;border-style:solid;border-width:1px;box-shadow:2px 2px 10px rgba(154,154,154,.55)}.vis.timeline .legendText{white-space:nowrap;display:inline-block}.vis.timeline .graphGroup0{fill:#4f81bd;fill-opacity:0;stroke-width:2px;stroke:#4f81bd}.vis.timeline .graphGroup1{fill:#f79646;fill-opacity:0;stroke-width:2px;stroke:#f79646}.vis.timeline .graphGroup2{fill:#8c51cf;fill-opacity:0;stroke-width:2px;stroke:#8c51cf}.vis.timeline .graphGroup3{fill:#75c841;fill-opacity:0;stroke-width:2px;stroke:#75c841}.vis.timeline .graphGroup4{fill:#ff0100;fill-opacity:0;stroke-width:2px;stroke:#ff0100}.vis.timeline .graphGroup5{fill:#37d8e6;fill-opacity:0;stroke-width:2px;stroke:#37d8e6}.vis.timeline .graphGroup6{fill:#042662;fill-opacity:0;stroke-width:2px;stroke:#042662}.vis.timeline .graphGroup7{fill:#00ff26;fill-opacity:0;stroke-width:2px;stroke:#00ff26}.vis.timeline .graphGroup8{fill:#f0f;fill-opacity:0;stroke-width:2px;stroke:#f0f}.vis.timeline .graphGroup9{fill:#8f3938;fill-opacity:0;stroke-width:2px;stroke:#8f3938}.vis.timeline .fill{fill-opacity:.1;stroke:none}.vis.timeline .bar{fill-opacity:.5;stroke-width:1px}.vis.timeline .point{stroke-width:2px;fill-opacity:1}.vis.timeline .legendBackground{stroke-width:1px;fill-opacity:.9;fill:#fff;stroke:#c2c2c2}.vis.timeline .outline{stroke-width:1px;fill-opacity:1;fill:#fff;stroke:#e5e5e5}.vis.timeline .iconFill{fill-opacity:.3;stroke:none}div.network-manipulationDiv{border-width:0;border-bottom:1px;border-style:solid;border-color:#d6d9d8;background:#fff;background:-moz-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#fff),color-stop(48%,#fcfcfc),color-stop(50%,#fafafa),color-stop(100%,#fcfcfc));background:-webkit-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-o-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-ms-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:linear-gradient(to bottom,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffffff', endColorstr='#fcfcfc', GradientType=0);position:absolute;left:0;top:0;width:100%;height:30px}div.network-manipulation-editMode{position:absolute;left:0;top:0;height:30px;margin-top:20px}div.network-manipulation-closeDiv{position:absolute;right:0;top:0;width:30px;height:30px;background-position:20px 3px;background-repeat:no-repeat;background-image:url(img/network/cross.png);cursor:pointer;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}div.network-manipulation-closeDiv:hover{opacity:.6}span.network-manipulationUI{font-family:verdana;font-size:12px;-moz-border-radius:15px;border-radius:15px;display:inline-block;background-position:0 0;background-repeat:no-repeat;height:24px;margin:-14px 0 0 10px;vertical-align:middle;cursor:pointer;padding:0 8px;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}span.network-manipulationUI:hover{box-shadow:1px 1px 8px rgba(0,0,0,.2)}span.network-manipulationUI:active{box-shadow:1px 1px 8px rgba(0,0,0,.5)}span.network-manipulationUI.back{background-image:url(img/network/backIcon.png)}span.network-manipulationUI.none:hover{box-shadow:1px 1px 8px transparent;cursor:default}span.network-manipulationUI.none:active{box-shadow:1px 1px 8px transparent}span.network-manipulationUI.none{padding:0}span.network-manipulationUI.notification{margin:2px;font-weight:700}span.network-manipulationUI.add{background-image:url(img/network/addNodeIcon.png)}span.network-manipulationUI.edit{background-image:url(img/network/editIcon.png)}span.network-manipulationUI.edit.editmode{background-color:#fcfcfc;border-style:solid;border-width:1px;border-color:#ccc}span.network-manipulationUI.connect{background-image:url(img/network/connectIcon.png)}span.network-manipulationUI.delete{background-image:url(img/network/deleteIcon.png)}span.network-manipulationLabel{margin:0 0 0 23px;line-height:25px}div.network-seperatorLine{display:inline-block;width:1px;height:20px;background-color:#bdbdbd;margin:5px 7px 0 15px}div.network-navigation_wrapper{position:absolute;left:0;top:0;width:100%;height:100%}div.network-navigation{width:34px;height:34px;-moz-border-radius:17px;border-radius:17px;position:absolute;display:inline-block;background-position:2px 2px;background-repeat:no-repeat;cursor:pointer;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}div.network-navigation:hover{box-shadow:0 0 3px 3px rgba(56,207,21,.3)}div.network-navigation:active{box-shadow:0 0 1px 3px rgba(56,207,21,.95)}div.network-navigation.up{background-image:url(img/network/upArrow.png);bottom:50px;left:55px}div.network-navigation.down{background-image:url(img/network/downArrow.png);bottom:10px;left:55px}div.network-navigation.left{background-image:url(img/network/leftArrow.png);bottom:10px;left:15px}div.network-navigation.right{background-image:url(img/network/rightArrow.png);bottom:10px;left:95px}div.network-navigation.zoomIn{background-image:url(img/network/plus.png);bottom:10px;right:15px}div.network-navigation.zoomOut{background-image:url(img/network/minus.png);bottom:10px;right:55px}div.network-navigation.zoomExtends{background-image:url(img/network/zoomExtends.png);bottom:50px;right:15px} \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/vis.min.js b/core/src/main/resources/org/apache/spark/ui/static/vis.min.js new file mode 100644 index 000000000000..2b3b1d60463f --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/vis.min.js @@ -0,0 +1,38 @@ +/** + * vis.js + * https://github.com/almende/vis + * + * A dynamic, browser-based visualization library. + * + * @version 3.9.0 + * @date 2015-01-16 + * + * @license + * Copyright (C) 2011-2014 Almende B.V, http://almende.com + * + * Vis.js is dual licensed under both + * + * * The Apache 2.0 License + * http://www.apache.org/licenses/LICENSE-2.0 + * + * and + * + * * The MIT License + * http://opensource.org/licenses/MIT + * + * Vis.js may be distributed under either license. + */ +"use strict";!function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define(e):"object"==typeof exports?exports.vis=e():t.vis=e()}(this,function(){return function(t){function e(s){if(i[s])return i[s].exports;var o=i[s]={exports:{},id:s,loaded:!1};return t[s].call(o.exports,o,o.exports,e),o.loaded=!0,o.exports}var i={};return e.m=t,e.c=i,e.p="",e(0)}([function(t,e,i){e.util=i(1),e.DOMutil=i(2),e.DataSet=i(3),e.DataView=i(4),e.Queue=i(5),e.Graph3d=i(6),e.graph3d={Camera:i(7),Filter:i(8),Point2d:i(9),Point3d:i(10),Slider:i(11),StepNumber:i(12)},e.Timeline=i(13),e.Graph2d=i(14),e.timeline={DateUtil:i(15),DataStep:i(16),Range:i(17),stack:i(18),TimeStep:i(19),components:{items:{Item:i(31),BackgroundItem:i(32),BoxItem:i(33),PointItem:i(34),RangeItem:i(35)},Component:i(20),CurrentTime:i(21),CustomTime:i(22),DataAxis:i(23),GraphGroup:i(24),Group:i(25),BackgroundGroup:i(26),ItemSet:i(27),Legend:i(28),LineGraph:i(29),TimeAxis:i(30)}},e.Network=i(36),e.network={Edge:i(37),Groups:i(38),Images:i(39),Node:i(40),Popup:i(41),dotparser:i(42),gephiParser:i(43)},e.Graph=function(){throw new Error("Graph is renamed to Network. Please create a graph as new vis.Network(...)")},e.moment=i(44),e.hammer=i(45),e.Hammer=i(45)},function(t,e,i){var s=i(44);e.isNumber=function(t){return t instanceof Number||"number"==typeof t},e.isString=function(t){return t instanceof String||"string"==typeof t},e.isDate=function(t){if(t instanceof Date)return!0;if(e.isString(t)){var i=o.exec(t);if(i)return!0;if(!isNaN(Date.parse(t)))return!0}return!1},e.isDataTable=function(t){return"undefined"!=typeof google&&google.visualization&&google.visualization.DataTable&&t instanceof google.visualization.DataTable},e.randomUUID=function(){var t=function(){return Math.floor(65536*Math.random()).toString(16)};return t()+t()+"-"+t()+"-"+t()+"-"+t()+"-"+t()+t()+t()},e.extend=function(t){for(var e=1,i=arguments.length;i>e;e++){var s=arguments[e];for(var o in s)s.hasOwnProperty(o)&&(t[o]=s[o])}return t},e.selectiveExtend=function(t,e){if(!Array.isArray(t))throw new Error("Array with property names expected as first argument");for(var i=2;ii;i++)if(t[i]!=e[i])return!1;return!0},e.convert=function(t,i){var n;if(void 0===t)return void 0;if(null===t)return null;if(!i)return t;if("string"!=typeof i&&!(i instanceof String))throw new Error("Type must be a string");switch(i){case"boolean":case"Boolean":return Boolean(t);case"number":case"Number":return Number(t.valueOf());case"string":case"String":return String(t);case"Date":if(e.isNumber(t))return new Date(t);if(t instanceof Date)return new Date(t.valueOf());if(s.isMoment(t))return new Date(t.valueOf());if(e.isString(t))return n=o.exec(t),n?new Date(Number(n[1])):s(t).toDate();throw new Error("Cannot convert object of type "+e.getType(t)+" to type Date");case"Moment":if(e.isNumber(t))return s(t);if(t instanceof Date)return s(t.valueOf());if(s.isMoment(t))return s(t);if(e.isString(t))return n=o.exec(t),s(n?Number(n[1]):t);throw new Error("Cannot convert object of type "+e.getType(t)+" to type Date");case"ISODate":if(e.isNumber(t))return new Date(t);if(t instanceof Date)return t.toISOString();if(s.isMoment(t))return t.toDate().toISOString();if(e.isString(t))return n=o.exec(t),n?new Date(Number(n[1])).toISOString():new Date(t).toISOString();throw new Error("Cannot convert object of type "+e.getType(t)+" to type ISODate");case"ASPDate":if(e.isNumber(t))return"/Date("+t+")/";if(t instanceof Date)return"/Date("+t.valueOf()+")/";if(e.isString(t)){n=o.exec(t);var r;return r=n?new Date(Number(n[1])).valueOf():new Date(t).valueOf(),"/Date("+r+")/"}throw new Error("Cannot convert object of type "+e.getType(t)+" to type ASPDate");default:throw new Error('Unknown type "'+i+'"')}};var o=/^\/?Date\((\-?\d+)/i;e.getType=function(t){var e=typeof t;return"object"==e?null==t?"null":t instanceof Boolean?"Boolean":t instanceof Number?"Number":t instanceof String?"String":Array.isArray(t)?"Array":t instanceof Date?"Date":"Object":"number"==e?"Number":"boolean"==e?"Boolean":"string"==e?"String":e},e.getAbsoluteLeft=function(t){return t.getBoundingClientRect().left},e.getAbsoluteTop=function(t){return t.getBoundingClientRect().top},e.addClassName=function(t,e){var i=t.className.split(" ");-1==i.indexOf(e)&&(i.push(e),t.className=i.join(" "))},e.removeClassName=function(t,e){var i=t.className.split(" "),s=i.indexOf(e);-1!=s&&(i.splice(s,1),t.className=i.join(" "))},e.forEach=function(t,e){var i,s;if(Array.isArray(t))for(i=0,s=t.length;s>i;i++)e(t[i],i,t);else for(i in t)t.hasOwnProperty(i)&&e(t[i],i,t)},e.toArray=function(t){var e=[];for(var i in t)t.hasOwnProperty(i)&&e.push(t[i]);return e},e.updateProperty=function(t,e,i){return t[e]!==i?(t[e]=i,!0):!1},e.addEventListener=function(t,e,i,s){t.addEventListener?(void 0===s&&(s=!1),"mousewheel"===e&&navigator.userAgent.indexOf("Firefox")>=0&&(e="DOMMouseScroll"),t.addEventListener(e,i,s)):t.attachEvent("on"+e,i)},e.removeEventListener=function(t,e,i,s){t.removeEventListener?(void 0===s&&(s=!1),"mousewheel"===e&&navigator.userAgent.indexOf("Firefox")>=0&&(e="DOMMouseScroll"),t.removeEventListener(e,i,s)):t.detachEvent("on"+e,i)},e.preventDefault=function(t){t||(t=window.event),t.preventDefault?t.preventDefault():t.returnValue=!1},e.getTarget=function(t){t||(t=window.event);var e;return t.target?e=t.target:t.srcElement&&(e=t.srcElement),void 0!=e.nodeType&&3==e.nodeType&&(e=e.parentNode),e},e.option={},e.option.asBoolean=function(t,e){return"function"==typeof t&&(t=t()),null!=t?0!=t:e||null},e.option.asNumber=function(t,e){return"function"==typeof t&&(t=t()),null!=t?Number(t)||e||null:e||null},e.option.asString=function(t,e){return"function"==typeof t&&(t=t()),null!=t?String(t):e||null},e.option.asSize=function(t,i){return"function"==typeof t&&(t=t()),e.isString(t)?t:e.isNumber(t)?t+"px":i||null},e.option.asElement=function(t,e){return"function"==typeof t&&(t=t()),t||e||null},e.hexToRGB=function(t){var e=/^#?([a-f\d])([a-f\d])([a-f\d])$/i;t=t.replace(e,function(t,e,i,s){return e+e+i+i+s+s});var i=/^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(t);return i?{r:parseInt(i[1],16),g:parseInt(i[2],16),b:parseInt(i[3],16)}:null},e.RGBToHex=function(t,e,i){return"#"+((1<<24)+(t<<16)+(e<<8)+i).toString(16).slice(1)},e.parseColor=function(t){var i;if(e.isString(t)){if(e.isValidRGB(t)){var s=t.substr(4).substr(0,t.length-5).split(",");t=e.RGBToHex(s[0],s[1],s[2])}if(e.isValidHex(t)){var o=e.hexToHSV(t),n={h:o.h,s:.45*o.s,v:Math.min(1,1.05*o.v)},r={h:o.h,s:Math.min(1,1.25*o.v),v:.6*o.v},a=e.HSVToHex(r.h,r.h,r.v),h=e.HSVToHex(n.h,n.s,n.v);i={background:t,border:a,highlight:{background:h,border:a},hover:{background:h,border:a}}}else i={background:t,border:t,highlight:{background:t,border:t},hover:{background:t,border:t}}}else i={},i.background=t.background||"white",i.border=t.border||i.background,e.isString(t.highlight)?i.highlight={border:t.highlight,background:t.highlight}:(i.highlight={},i.highlight.background=t.highlight&&t.highlight.background||i.background,i.highlight.border=t.highlight&&t.highlight.border||i.border),e.isString(t.hover)?i.hover={border:t.hover,background:t.hover}:(i.hover={},i.hover.background=t.hover&&t.hover.background||i.background,i.hover.border=t.hover&&t.hover.border||i.border);return i},e.RGBToHSV=function(t,e,i){t/=255,e/=255,i/=255;var s=Math.min(t,Math.min(e,i)),o=Math.max(t,Math.max(e,i));if(s==o)return{h:0,s:0,v:s};var n=t==s?e-i:i==s?t-e:i-t,r=t==s?3:i==s?1:5,a=60*(r-n/(o-s))/360,h=(o-s)/o,d=o;return{h:a,s:h,v:d}};var n={split:function(t){var e={};return t.split(";").forEach(function(t){if(""!=t.trim()){var i=t.split(":"),s=i[0].trim(),o=i[1].trim();e[s]=o}}),e},join:function(t){return Object.keys(t).map(function(e){return e+": "+t[e]}).join("; ")}};e.addCssText=function(t,i){var s=n.split(t.style.cssText),o=n.split(i),r=e.extend(s,o);t.style.cssText=n.join(r)},e.removeCssText=function(t,e){var i=n.split(t.style.cssText),s=n.split(e);for(var o in s)s.hasOwnProperty(o)&&delete i[o];t.style.cssText=n.join(i)},e.HSVToRGB=function(t,e,i){var s,o,n,r=Math.floor(6*t),a=6*t-r,h=i*(1-e),d=i*(1-a*e),l=i*(1-(1-a)*e);switch(r%6){case 0:s=i,o=l,n=h;break;case 1:s=d,o=i,n=h;break;case 2:s=h,o=i,n=l;break;case 3:s=h,o=d,n=i;break;case 4:s=l,o=h,n=i;break;case 5:s=i,o=h,n=d}return{r:Math.floor(255*s),g:Math.floor(255*o),b:Math.floor(255*n)}},e.HSVToHex=function(t,i,s){var o=e.HSVToRGB(t,i,s);return e.RGBToHex(o.r,o.g,o.b)},e.hexToHSV=function(t){var i=e.hexToRGB(t);return e.RGBToHSV(i.r,i.g,i.b)},e.isValidHex=function(t){var e=/(^#[0-9A-F]{6}$)|(^#[0-9A-F]{3}$)/i.test(t);return e},e.isValidRGB=function(t){t=t.replace(" ","");var e=/rgb\((\d{1,3}),(\d{1,3}),(\d{1,3})\)/i.test(t);return e},e.selectiveBridgeObject=function(t,i){if("object"==typeof i){for(var s=Object.create(i),o=0;o=r&&o>n;){var h=Math.floor((r+a)/2),d=t[h],l=void 0===s?d[i]:d[i][s],c=e(l);if(0==c)return h;-1==c?r=h+1:a=h-1,n++}return-1},e.binarySearchValue=function(t,e,i,s){for(var o,n,r,a,h=1e4,d=0,l=0,c=t.length-1;c>=l&&h>d;){if(a=Math.floor(.5*(c+l)),o=t[Math.max(0,a-1)][i],n=t[a][i],r=t[Math.min(t.length-1,a+1)][i],n==e)return a;if(e>o&&n>e)return"before"==s?Math.max(0,a-1):a;if(e>n&&r>e)return"before"==s?a:Math.min(t.length-1,a+1);e>n?l=a+1:c=a-1,d++}return-1},e.easeInOutQuad=function(t,e,i,s){var o=i-e;return t/=s/2,1>t?o/2*t*t+e:(t--,-o/2*(t*(t-2)-1)+e)},e.easingFunctions={linear:function(t){return t},easeInQuad:function(t){return t*t},easeOutQuad:function(t){return t*(2-t)},easeInOutQuad:function(t){return.5>t?2*t*t:-1+(4-2*t)*t},easeInCubic:function(t){return t*t*t},easeOutCubic:function(t){return--t*t*t+1},easeInOutCubic:function(t){return.5>t?4*t*t*t:(t-1)*(2*t-2)*(2*t-2)+1},easeInQuart:function(t){return t*t*t*t},easeOutQuart:function(t){return 1- --t*t*t*t},easeInOutQuart:function(t){return.5>t?8*t*t*t*t:1-8*--t*t*t*t},easeInQuint:function(t){return t*t*t*t*t},easeOutQuint:function(t){return 1+--t*t*t*t*t},easeInOutQuint:function(t){return.5>t?16*t*t*t*t*t:1+16*--t*t*t*t*t}}},function(t,e){e.prepareElements=function(t){for(var e in t)t.hasOwnProperty(e)&&(t[e].redundant=t[e].used,t[e].used=[])},e.cleanupElements=function(t){for(var e in t)if(t.hasOwnProperty(e)&&t[e].redundant){for(var i=0;i0?(s=e[t].redundant[0],e[t].redundant.shift()):(s=document.createElementNS("http://www.w3.org/2000/svg",t),i.appendChild(s)):(s=document.createElementNS("http://www.w3.org/2000/svg",t),e[t]={used:[],redundant:[]},i.appendChild(s)),e[t].used.push(s),s},e.getDOMElement=function(t,e,i,s){var o;return e.hasOwnProperty(t)?e[t].redundant.length>0?(o=e[t].redundant[0],e[t].redundant.shift()):(o=document.createElement(t),void 0!==s?i.insertBefore(o,s):i.appendChild(o)):(o=document.createElement(t),e[t]={used:[],redundant:[]},void 0!==s?i.insertBefore(o,s):i.appendChild(o)),e[t].used.push(o),o},e.drawPoint=function(t,i,s,o,n){var r;return"circle"==s.options.drawPoints.style?(r=e.getSVGElement("circle",o,n),r.setAttributeNS(null,"cx",t),r.setAttributeNS(null,"cy",i),r.setAttributeNS(null,"r",.5*s.options.drawPoints.size)):(r=e.getSVGElement("rect",o,n),r.setAttributeNS(null,"x",t-.5*s.options.drawPoints.size),r.setAttributeNS(null,"y",i-.5*s.options.drawPoints.size),r.setAttributeNS(null,"width",s.options.drawPoints.size),r.setAttributeNS(null,"height",s.options.drawPoints.size)),void 0!==s.options.drawPoints.styles&&r.setAttributeNS(null,"style",s.group.options.drawPoints.styles),r.setAttributeNS(null,"class",s.className+" point"),r},e.drawBar=function(t,i,s,o,n,r,a){if(0!=o){0>o&&(o*=-1,i-=o);var h=e.getSVGElement("rect",r,a);h.setAttributeNS(null,"x",t-.5*s),h.setAttributeNS(null,"y",i),h.setAttributeNS(null,"width",s),h.setAttributeNS(null,"height",o),h.setAttributeNS(null,"class",n)}}},function(t,e,i){function s(t,e){if(!t||Array.isArray(t)||o.isDataTable(t)||(e=t,t=null),this._options=e||{},this._data={},this._fieldId=this._options.fieldId||"id",this._type={},this._options.type)for(var i in this._options.type)if(this._options.type.hasOwnProperty(i)){var s=this._options.type[i];this._type[i]="Date"==s||"ISODate"==s||"ASPDate"==s?"Date":s}if(this._options.convert)throw new Error('Option "convert" is deprecated. Use "type" instead.');this._subscribers={},t&&this.add(t),this.setOptions(e)}var o=i(1),n=i(5);s.prototype.setOptions=function(t){t&&void 0!==t.queue&&(t.queue===!1?this._queue&&(this._queue.destroy(),delete this._queue):(this._queue||(this._queue=n.extend(this,{replace:["add","update","remove"]})),"object"==typeof t.queue&&this._queue.setOptions(t.queue)))},s.prototype.on=function(t,e){var i=this._subscribers[t];i||(i=[],this._subscribers[t]=i),i.push({callback:e})},s.prototype.subscribe=s.prototype.on,s.prototype.off=function(t,e){var i=this._subscribers[t];i&&(this._subscribers[t]=i.filter(function(t){return t.callback!=e}))},s.prototype.unsubscribe=s.prototype.off,s.prototype._trigger=function(t,e,i){if("*"==t)throw new Error("Cannot trigger event *");var s=[];t in this._subscribers&&(s=s.concat(this._subscribers[t])),"*"in this._subscribers&&(s=s.concat(this._subscribers["*"]));for(var o=0;or;r++)i=n._addItem(t[r]),s.push(i);else if(o.isDataTable(t))for(var h=this._getColumnNames(t),d=0,l=t.getNumberOfRows();l>d;d++){for(var c={},p=0,u=h.length;u>p;p++){var m=h[p];c[m]=t.getValue(d,p)}i=n._addItem(c),s.push(i)}else{if(!(t instanceof Object))throw new Error("Unknown dataType");i=n._addItem(t),s.push(i)}return s.length&&this._trigger("add",{items:s},e),s},s.prototype.update=function(t,e){var i=[],s=[],n=[],r=this,a=r._fieldId,h=function(t){var e=t[a];r._data[e]?(e=r._updateItem(t),s.push(e),n.push(t)):(e=r._addItem(t),i.push(e))};if(Array.isArray(t))for(var d=0,l=t.length;l>d;d++)h(t[d]);else if(o.isDataTable(t))for(var c=this._getColumnNames(t),p=0,u=t.getNumberOfRows();u>p;p++){for(var m={},f=0,g=c.length;g>f;f++){var v=c[f];m[v]=t.getValue(p,f)}h(m)}else{if(!(t instanceof Object))throw new Error("Unknown dataType");h(t)}return i.length&&this._trigger("add",{items:i},e),s.length&&this._trigger("update",{items:s,data:n},e),i.concat(s)},s.prototype.get=function(){var t,e,i,s,n=this,r=o.getType(arguments[0]);"String"==r||"Number"==r?(t=arguments[0],i=arguments[1],s=arguments[2]):"Array"==r?(e=arguments[0],i=arguments[1],s=arguments[2]):(i=arguments[0],s=arguments[1]);var a;if(i&&i.returnType){var h=["DataTable","Array","Object"];if(a=-1==h.indexOf(i.returnType)?"Array":i.returnType,s&&a!=o.getType(s))throw new Error('Type of parameter "data" ('+o.getType(s)+") does not correspond with specified options.type ("+i.type+")");if("DataTable"==a&&!o.isDataTable(s))throw new Error('Parameter "data" must be a DataTable when options.type is "DataTable"')}else a=s&&"DataTable"==o.getType(s)?"DataTable":"Array";var d,l,c,p,u=i&&i.type||this._options.type,m=i&&i.filter,f=[];if(void 0!=t)d=n._getItem(t,u),m&&!m(d)&&(d=null);else if(void 0!=e)for(c=0,p=e.length;p>c;c++)d=n._getItem(e[c],u),(!m||m(d))&&f.push(d);else for(l in this._data)this._data.hasOwnProperty(l)&&(d=n._getItem(l,u),(!m||m(d))&&f.push(d));if(i&&i.order&&void 0==t&&this._sort(f,i.order),i&&i.fields){var g=i.fields;if(void 0!=t)d=this._filterFields(d,g);else for(c=0,p=f.length;p>c;c++)f[c]=this._filterFields(f[c],g)}if("DataTable"==a){var v=this._getColumnNames(s);if(void 0!=t)n._appendRow(s,v,d);else for(c=0;cc;c++)s.push(f[c]);return s}return f},s.prototype.getIds=function(t){var e,i,s,o,n,r=this._data,a=t&&t.filter,h=t&&t.order,d=t&&t.type||this._options.type,l=[];if(a)if(h){n=[];for(s in r)r.hasOwnProperty(s)&&(o=this._getItem(s,d),a(o)&&n.push(o));for(this._sort(n,h),e=0,i=n.length;i>e;e++)l[e]=n[e][this._fieldId]}else for(s in r)r.hasOwnProperty(s)&&(o=this._getItem(s,d),a(o)&&l.push(o[this._fieldId]));else if(h){n=[];for(s in r)r.hasOwnProperty(s)&&n.push(r[s]);for(this._sort(n,h),e=0,i=n.length;i>e;e++)l[e]=n[e][this._fieldId]}else for(s in r)r.hasOwnProperty(s)&&(o=r[s],l.push(o[this._fieldId]));return l},s.prototype.getDataSet=function(){return this},s.prototype.forEach=function(t,e){var i,s,o=e&&e.filter,n=e&&e.type||this._options.type,r=this._data;if(e&&e.order)for(var a=this.get(e),h=0,d=a.length;d>h;h++)i=a[h],s=i[this._fieldId],t(i,s);else for(s in r)r.hasOwnProperty(s)&&(i=this._getItem(s,n),(!o||o(i))&&t(i,s))},s.prototype.map=function(t,e){var i,s=e&&e.filter,o=e&&e.type||this._options.type,n=[],r=this._data;for(var a in r)r.hasOwnProperty(a)&&(i=this._getItem(a,o),(!s||s(i))&&n.push(t(i,a)));return e&&e.order&&this._sort(n,e.order),n},s.prototype._filterFields=function(t,e){var i={};for(var s in t)t.hasOwnProperty(s)&&-1!=e.indexOf(s)&&(i[s]=t[s]);return i},s.prototype._sort=function(t,e){if(o.isString(e)){var i=e;t.sort(function(t,e){var s=t[i],o=e[i];return s>o?1:o>s?-1:0})}else{if("function"!=typeof e)throw new TypeError("Order must be a function or a string");t.sort(e)}},s.prototype.remove=function(t,e){var i,s,o,n=[];if(Array.isArray(t))for(i=0,s=t.length;s>i;i++)o=this._remove(t[i]),null!=o&&n.push(o);else o=this._remove(t),null!=o&&n.push(o);return n.length&&this._trigger("remove",{items:n},e),n},s.prototype._remove=function(t){if(o.isNumber(t)||o.isString(t)){if(this._data[t])return delete this._data[t],t}else if(t instanceof Object){var e=t[this._fieldId];if(e&&this._data[e])return delete this._data[e],e}return null},s.prototype.clear=function(t){var e=Object.keys(this._data);return this._data={},this._trigger("remove",{items:e},t),e},s.prototype.max=function(t){var e=this._data,i=null,s=null;for(var o in e)if(e.hasOwnProperty(o)){var n=e[o],r=n[t];null!=r&&(!i||r>s)&&(i=n,s=r)}return i},s.prototype.min=function(t){var e=this._data,i=null,s=null;for(var o in e)if(e.hasOwnProperty(o)){var n=e[o],r=n[t];null!=r&&(!i||s>r)&&(i=n,s=r)}return i},s.prototype.distinct=function(t){var e,i=this._data,s=[],n=this._options.type&&this._options.type[t]||null,r=0;for(var a in i)if(i.hasOwnProperty(a)){var h=i[a],d=h[t],l=!1;for(e=0;r>e;e++)if(s[e]==d){l=!0;break}l||void 0===d||(s[r]=d,r++)}if(n)for(e=0;ei;i++)e[i]=t.getColumnId(i)||t.getColumnLabel(i);return e},s.prototype._appendRow=function(t,e,i){for(var s=t.addRow(),o=0,n=e.length;n>o;o++){var r=e[o];t.setValue(s,o,i[r])}},t.exports=s},function(t,e,i){function s(t,e){this._data=null,this._ids={},this._options=e||{},this._fieldId="id",this._subscribers={};var i=this;this.listener=function(){i._onEvent.apply(i,arguments)},this.setData(t)}var o=i(1),n=i(3);s.prototype.setData=function(t){var e,i,s;if(this._data){this._data.unsubscribe&&this._data.unsubscribe("*",this.listener),e=[];for(var o in this._ids)this._ids.hasOwnProperty(o)&&e.push(o);this._ids={},this._trigger("remove",{items:e})}if(this._data=t,this._data){for(this._fieldId=this._options.fieldId||this._data&&this._data.options&&this._data.options.fieldId||"id",e=this._data.getIds({filter:this._options&&this._options.filter}),i=0,s=e.length;s>i;i++)o=e[i],this._ids[o]=!0;this._trigger("add",{items:e}),this._data.on&&this._data.on("*",this.listener)}},s.prototype.get=function(){var t,e,i,s=this,n=o.getType(arguments[0]);"String"==n||"Number"==n||"Array"==n?(t=arguments[0],e=arguments[1],i=arguments[2]):(e=arguments[0],i=arguments[1]);var r=o.extend({},this._options,e);this._options.filter&&e&&e.filter&&(r.filter=function(t){return s._options.filter(t)&&e.filter(t)});var a=[];return void 0!=t&&a.push(t),a.push(r),a.push(i),this._data&&this._data.get.apply(this._data,a)},s.prototype.getIds=function(t){var e;if(this._data){var i,s=this._options.filter;i=t&&t.filter?s?function(e){return s(e)&&t.filter(e)}:t.filter:s,e=this._data.getIds({filter:i,order:t&&t.order})}else e=[];return e},s.prototype.getDataSet=function(){for(var t=this;t instanceof s;)t=t._data;return t||null},s.prototype._onEvent=function(t,e,i){var s,o,n,r,a=e&&e.items,h=this._data,d=[],l=[],c=[];if(a&&h){switch(t){case"add":for(s=0,o=a.length;o>s;s++)n=a[s],r=this.get(n),r&&(this._ids[n]=!0,d.push(n));break;case"update":for(s=0,o=a.length;o>s;s++)n=a[s],r=this.get(n),r?this._ids[n]?l.push(n):(this._ids[n]=!0,d.push(n)):this._ids[n]&&(delete this._ids[n],c.push(n));break;case"remove":for(s=0,o=a.length;o>s;s++)n=a[s],this._ids[n]&&(delete this._ids[n],c.push(n))}d.length&&this._trigger("add",{items:d},i),l.length&&this._trigger("update",{items:l},i),c.length&&this._trigger("remove",{items:c},i)}},s.prototype.on=n.prototype.on,s.prototype.off=n.prototype.off,s.prototype._trigger=n.prototype._trigger,s.prototype.subscribe=s.prototype.on,s.prototype.unsubscribe=s.prototype.off,t.exports=s},function(t){function e(t){this.delay=null,this.max=1/0,this._queue=[],this._timeout=null,this._extended=null,this.setOptions(t)}e.prototype.setOptions=function(t){t&&"undefined"!=typeof t.delay&&(this.delay=t.delay),t&&"undefined"!=typeof t.max&&(this.max=t.max),this._flushIfNeeded()},e.extend=function(t,i){var s=new e(i);if(void 0!==t.flush)throw new Error("Target object already has a property flush");t.flush=function(){s.flush()};var o=[{name:"flush",original:void 0}];if(i&&i.replace)for(var n=0;nthis.max&&this.flush(),clearTimeout(this._timeout),this.queue.length>0&&"number"==typeof this.delay){var t=this;this._timeout=setTimeout(function(){t.flush()},this.delay)}},e.prototype.flush=function(){for(;this._queue.length>0;){var t=this._queue.shift();t.fn.apply(t.context||t.fn,t.args||[])}},t.exports=e},function(t,e,i){function s(t,e,i){if(!(this instanceof s))throw new SyntaxError("Constructor must be called with the new operator");this.containerElement=t,this.width="400px",this.height="400px",this.margin=10,this.defaultXCenter="55%",this.defaultYCenter="50%",this.xLabel="x",this.yLabel="y",this.zLabel="z";var o=function(t){return t};this.xValueLabel=o,this.yValueLabel=o,this.zValueLabel=o,this.filterLabel="time",this.legendLabel="value",this.style=s.STYLE.DOT,this.showPerspective=!0,this.showGrid=!0,this.keepAspectRatio=!0,this.showShadow=!1,this.showGrayBottom=!1,this.showTooltip=!1,this.verticalRatio=.5,this.animationInterval=1e3,this.animationPreload=!1,this.camera=new p,this.eye=new l(0,0,-1),this.dataTable=null,this.dataPoints=null,this.colX=void 0,this.colY=void 0,this.colZ=void 0,this.colValue=void 0,this.colFilter=void 0,this.xMin=0,this.xStep=void 0,this.xMax=1,this.yMin=0,this.yStep=void 0,this.yMax=1,this.zMin=0,this.zStep=void 0,this.zMax=1,this.valueMin=0,this.valueMax=1,this.xBarWidth=1,this.yBarWidth=1,this.colorAxis="#4D4D4D",this.colorGrid="#D3D3D3",this.colorDot="#7DC1FF",this.colorDotBorder="#3267D2",this.create(),this.setOptions(i),e&&this.setData(e)}function o(t){return"clientX"in t?t.clientX:t.targetTouches[0]&&t.targetTouches[0].clientX||0}function n(t){return"clientY"in t?t.clientY:t.targetTouches[0]&&t.targetTouches[0].clientY||0}var r=i(56),a=i(3),h=i(4),d=i(1),l=i(10),c=i(9),p=i(7),u=i(8),m=i(11),f=i(12);r(s.prototype),s.prototype._setScale=function(){this.scale=new l(1/(this.xMax-this.xMin),1/(this.yMax-this.yMin),1/(this.zMax-this.zMin)),this.keepAspectRatio&&(this.scale.x3&&(this.colFilter=3);else{if(this.style!==s.STYLE.DOTCOLOR&&this.style!==s.STYLE.DOTSIZE&&this.style!==s.STYLE.BARCOLOR&&this.style!==s.STYLE.BARSIZE)throw'Unknown style "'+this.style+'"';this.colX=0,this.colY=1,this.colZ=2,this.colValue=3,t.getNumberOfColumns()>4&&(this.colFilter=4)}},s.prototype.getNumberOfRows=function(t){return t.length},s.prototype.getNumberOfColumns=function(t){var e=0;for(var i in t[0])t[0].hasOwnProperty(i)&&e++;return e},s.prototype.getDistinctValues=function(t,e){for(var i=[],s=0;st[s][e]&&(i.min=t[s][e]),i.maxt;t++){var m=(t-p)/(u-p),g=240*m,v=this._hsv2rgb(g,1,1);c.strokeStyle=v,c.beginPath(),c.moveTo(h,r+t),c.lineTo(a,r+t),c.stroke()}c.strokeStyle=this.colorAxis,c.strokeRect(h,r,i,n)}if(this.style===s.STYLE.DOTSIZE&&(c.strokeStyle=this.colorAxis,c.fillStyle=this.colorDot,c.beginPath(),c.moveTo(h,r),c.lineTo(a,r),c.lineTo(a-i+e,d),c.lineTo(h,d),c.closePath(),c.fill(),c.stroke()),this.style===s.STYLE.DOTCOLOR||this.style===s.STYLE.DOTSIZE){var y=5,b=new f(this.valueMin,this.valueMax,(this.valueMax-this.valueMin)/5,!0);for(b.start(),b.getCurrent()0?this.yMin:this.yMax,o=this._convert3Dto2D(new l(x,r,this.zMin)),Math.cos(2*_)>0?(g.textAlign="center",g.textBaseline="top",o.y+=b):Math.sin(2*_)<0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(" "+this.xValueLabel(i.getCurrent())+" ",o.x,o.y),i.next()}for(g.lineWidth=1,s=void 0===this.defaultYStep,i=new f(this.yMin,this.yMax,this.yStep,s),i.start(),i.getCurrent()0?this.xMin:this.xMax,o=this._convert3Dto2D(new l(n,i.getCurrent(),this.zMin)),Math.cos(2*_)<0?(g.textAlign="center",g.textBaseline="top",o.y+=b):Math.sin(2*_)>0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(" "+this.yValueLabel(i.getCurrent())+" ",o.x,o.y),i.next();for(g.lineWidth=1,s=void 0===this.defaultZStep,i=new f(this.zMin,this.zMax,this.zStep,s),i.start(),i.getCurrent()0?this.xMin:this.xMax,r=Math.sin(_)<0?this.yMin:this.yMax;!i.end();)t=this._convert3Dto2D(new l(n,r,i.getCurrent())),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(t.x-b,t.y),g.stroke(),g.textAlign="right",g.textBaseline="middle",g.fillStyle=this.colorAxis,g.fillText(this.zValueLabel(i.getCurrent())+" ",t.x-5,t.y),i.next();g.lineWidth=1,t=this._convert3Dto2D(new l(n,r,this.zMin)),e=this._convert3Dto2D(new l(n,r,this.zMax)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(e.x,e.y),g.stroke(),g.lineWidth=1,p=this._convert3Dto2D(new l(this.xMin,this.yMin,this.zMin)),u=this._convert3Dto2D(new l(this.xMax,this.yMin,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(p.x,p.y),g.lineTo(u.x,u.y),g.stroke(),p=this._convert3Dto2D(new l(this.xMin,this.yMax,this.zMin)),u=this._convert3Dto2D(new l(this.xMax,this.yMax,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(p.x,p.y),g.lineTo(u.x,u.y),g.stroke(),g.lineWidth=1,t=this._convert3Dto2D(new l(this.xMin,this.yMin,this.zMin)),e=this._convert3Dto2D(new l(this.xMin,this.yMax,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(e.x,e.y),g.stroke(),t=this._convert3Dto2D(new l(this.xMax,this.yMin,this.zMin)),e=this._convert3Dto2D(new l(this.xMax,this.yMax,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(e.x,e.y),g.stroke();var w=this.xLabel;w.length>0&&(c=.1/this.scale.y,n=(this.xMin+this.xMax)/2,r=Math.cos(_)>0?this.yMin-c:this.yMax+c,o=this._convert3Dto2D(new l(n,r,this.zMin)),Math.cos(2*_)>0?(g.textAlign="center",g.textBaseline="top"):Math.sin(2*_)<0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(w,o.x,o.y));var S=this.yLabel;S.length>0&&(d=.1/this.scale.x,n=Math.sin(_)>0?this.xMin-d:this.xMax+d,r=(this.yMin+this.yMax)/2,o=this._convert3Dto2D(new l(n,r,this.zMin)),Math.cos(2*_)<0?(g.textAlign="center",g.textBaseline="top"):Math.sin(2*_)>0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(S,o.x,o.y));var M=this.zLabel;M.length>0&&(h=30,n=Math.cos(_)>0?this.xMin:this.xMax,r=Math.sin(_)<0?this.yMin:this.yMax,a=(this.zMin+this.zMax)/2,o=this._convert3Dto2D(new l(n,r,a)),g.textAlign="right",g.textBaseline="middle",g.fillStyle=this.colorAxis,g.fillText(M,o.x-h,o.y))},s.prototype._hsv2rgb=function(t,e,i){var s,o,n,r,a,h;switch(r=i*e,a=Math.floor(t/60),h=r*(1-Math.abs(t/60%2-1)),a){case 0:s=r,o=h,n=0;break;case 1:s=h,o=r,n=0;break;case 2:s=0,o=r,n=h;break;case 3:s=0,o=h,n=r;break;case 4:s=h,o=0,n=r;break;case 5:s=r,o=0,n=h;break;default:s=0,o=0,n=0}return"RGB("+parseInt(255*s)+","+parseInt(255*o)+","+parseInt(255*n)+")"},s.prototype._redrawDataGrid=function(){var t,e,i,o,n,r,a,h,d,c,p,u,m,f=this.frame.canvas,g=f.getContext("2d");if(!(void 0===this.dataPoints||this.dataPoints.length<=0)){for(n=0;n0}else r=!0;r?(m=(t.point.z+e.point.z+i.point.z+o.point.z)/4,c=240*(1-(m-this.zMin)*this.scale.z/this.verticalRatio),p=1,this.showShadow?(u=Math.min(1+S.x/M/2,1),a=this._hsv2rgb(c,p,u),h=a):(u=1,a=this._hsv2rgb(c,p,u),h=this.colorAxis)):(a="gray",h=this.colorAxis),d=.5,g.lineWidth=d,g.fillStyle=a,g.strokeStyle=h,g.beginPath(),g.moveTo(t.screen.x,t.screen.y),g.lineTo(e.screen.x,e.screen.y),g.lineTo(o.screen.x,o.screen.y),g.lineTo(i.screen.x,i.screen.y),g.closePath(),g.fill(),g.stroke()}}else for(n=0;np&&(p=0);var u,m,f;this.style===s.STYLE.DOTCOLOR?(u=240*(1-(d.point.value-this.valueMin)*this.scale.value),m=this._hsv2rgb(u,1,1),f=this._hsv2rgb(u,1,.8)):this.style===s.STYLE.DOTSIZE?(m=this.colorDot,f=this.colorDotBorder):(u=240*(1-(d.point.z-this.zMin)*this.scale.z/this.verticalRatio),m=this._hsv2rgb(u,1,1),f=this._hsv2rgb(u,1,.8)),i.lineWidth=1,i.strokeStyle=f,i.fillStyle=m,i.beginPath(),i.arc(d.screen.x,d.screen.y,p,0,2*Math.PI,!0),i.fill(),i.stroke()}}},s.prototype._redrawDataBar=function(){var t,e,i,o,n=this.frame.canvas,r=n.getContext("2d");if(!(void 0===this.dataPoints||this.dataPoints.length<=0)){for(t=0;t0&&(t=this.dataPoints[0],s.lineWidth=1,s.strokeStyle="blue",s.beginPath(),s.moveTo(t.screen.x,t.screen.y)),e=1;e0&&s.stroke()}},s.prototype._onMouseDown=function(t){if(t=t||window.event,this.leftButtonDown&&this._onMouseUp(t),this.leftButtonDown=t.which?1===t.which:1===t.button,this.leftButtonDown||this.touchDown){this.startMouseX=o(t),this.startMouseY=n(t),this.startStart=new Date(this.start),this.startEnd=new Date(this.end),this.startArmRotation=this.camera.getArmRotation(),this.frame.style.cursor="move";var e=this;this.onmousemove=function(t){e._onMouseMove(t)},this.onmouseup=function(t){e._onMouseUp(t)},d.addEventListener(document,"mousemove",e.onmousemove),d.addEventListener(document,"mouseup",e.onmouseup),d.preventDefault(t)}},s.prototype._onMouseMove=function(t){t=t||window.event;var e=parseFloat(o(t))-this.startMouseX,i=parseFloat(n(t))-this.startMouseY,s=this.startArmRotation.horizontal+e/200,r=this.startArmRotation.vertical+i/200,a=4,h=Math.sin(a/360*2*Math.PI);Math.abs(Math.sin(s))0?1:0>t?-1:0}var s=e[0],o=e[1],n=e[2],r=i((o.x-s.x)*(t.y-s.y)-(o.y-s.y)*(t.x-s.x)),a=i((n.x-o.x)*(t.y-o.y)-(n.y-o.y)*(t.x-o.x)),h=i((s.x-n.x)*(t.y-n.y)-(s.y-n.y)*(t.x-n.x));return!(0!=r&&0!=a&&r!=a||0!=a&&0!=h&&a!=h||0!=r&&0!=h&&r!=h)},s.prototype._dataPointFromXY=function(t,e){var i,o=100,n=null,r=null,a=null,h=new c(t,e);if(this.style===s.STYLE.BAR||this.style===s.STYLE.BARCOLOR||this.style===s.STYLE.BARSIZE)for(i=this.dataPoints.length-1;i>=0;i--){n=this.dataPoints[i];var d=n.surfaces;if(d)for(var l=d.length-1;l>=0;l--){var p=d[l],u=p.corners,m=[u[0].screen,u[1].screen,u[2].screen],f=[u[2].screen,u[3].screen,u[0].screen];if(this._insideTriangle(h,m)||this._insideTriangle(h,f))return n}}else for(i=0;ib)&&o>b&&(a=b,r=n)}}return r},s.prototype._showTooltip=function(t){var e,i,s;this.tooltip?(e=this.tooltip.dom.content,i=this.tooltip.dom.line,s=this.tooltip.dom.dot):(e=document.createElement("div"),e.style.position="absolute",e.style.padding="10px",e.style.border="1px solid #4d4d4d",e.style.color="#1a1a1a",e.style.background="rgba(255,255,255,0.7)",e.style.borderRadius="2px",e.style.boxShadow="5px 5px 10px rgba(128,128,128,0.5)",i=document.createElement("div"),i.style.position="absolute",i.style.height="40px",i.style.width="0",i.style.borderLeft="1px solid #4d4d4d",s=document.createElement("div"),s.style.position="absolute",s.style.height="0",s.style.width="0",s.style.border="5px solid #4d4d4d",s.style.borderRadius="5px",this.tooltip={dataPoint:null,dom:{content:e,line:i,dot:s}}),this._hideTooltip(),this.tooltip.dataPoint=t,e.innerHTML="function"==typeof this.showTooltip?this.showTooltip(t.point):"
x:"+t.point.x+"
y:"+t.point.y+"
z:"+t.point.z+"
",e.style.left="0",e.style.top="0",this.frame.appendChild(e),this.frame.appendChild(i),this.frame.appendChild(s);var o=e.offsetWidth,n=e.offsetHeight,r=i.offsetHeight,a=s.offsetWidth,h=s.offsetHeight,d=t.screen.x-o/2;d=Math.min(Math.max(d,10),this.frame.clientWidth-10-o),i.style.left=t.screen.x+"px",i.style.top=t.screen.y-r+"px",e.style.left=d+"px",e.style.top=t.screen.y-r-n+"px",s.style.left=t.screen.x-a/2+"px",s.style.top=t.screen.y-h/2+"px"},s.prototype._hideTooltip=function(){if(this.tooltip){this.tooltip.dataPoint=null;for(var t in this.tooltip.dom)if(this.tooltip.dom.hasOwnProperty(t)){var e=this.tooltip.dom[t];e&&e.parentNode&&e.parentNode.removeChild(e)}}},t.exports=s},function(t,e,i){function s(){this.armLocation=new o,this.armRotation={},this.armRotation.horizontal=0,this.armRotation.vertical=0,this.armLength=1.7,this.cameraLocation=new o,this.cameraRotation=new o(.5*Math.PI,0,0),this.calculateCameraOrientation()}var o=i(10);s.prototype.setArmLocation=function(t,e,i){this.armLocation.x=t,this.armLocation.y=e,this.armLocation.z=i,this.calculateCameraOrientation()},s.prototype.setArmRotation=function(t,e){void 0!==t&&(this.armRotation.horizontal=t),void 0!==e&&(this.armRotation.vertical=e,this.armRotation.vertical<0&&(this.armRotation.vertical=0),this.armRotation.vertical>.5*Math.PI&&(this.armRotation.vertical=.5*Math.PI)),(void 0!==t||void 0!==e)&&this.calculateCameraOrientation()},s.prototype.getArmRotation=function(){var t={};return t.horizontal=this.armRotation.horizontal,t.vertical=this.armRotation.vertical,t},s.prototype.setArmLength=function(t){void 0!==t&&(this.armLength=t,this.armLength<.71&&(this.armLength=.71),this.armLength>5&&(this.armLength=5),this.calculateCameraOrientation())},s.prototype.getArmLength=function(){return this.armLength},s.prototype.getCameraLocation=function(){return this.cameraLocation},s.prototype.getCameraRotation=function(){return this.cameraRotation},s.prototype.calculateCameraOrientation=function(){this.cameraLocation.x=this.armLocation.x-this.armLength*Math.sin(this.armRotation.horizontal)*Math.cos(this.armRotation.vertical),this.cameraLocation.y=this.armLocation.y-this.armLength*Math.cos(this.armRotation.horizontal)*Math.cos(this.armRotation.vertical),this.cameraLocation.z=this.armLocation.z+this.armLength*Math.sin(this.armRotation.vertical),this.cameraRotation.x=Math.PI/2-this.armRotation.vertical,this.cameraRotation.y=0,this.cameraRotation.z=-this.armRotation.horizontal},t.exports=s},function(t,e,i){function s(t,e,i){this.data=t,this.column=e,this.graph=i,this.index=void 0,this.value=void 0,this.values=i.getDistinctValues(t.get(),this.column),this.values.sort(function(t,e){return t>e?1:e>t?-1:0}),this.values.length>0&&this.selectValue(0),this.dataPoints=[],this.loaded=!1,this.onLoadCallback=void 0,i.animationPreload?(this.loaded=!1,this.loadInBackground()):this.loaded=!0}var o=i(4);s.prototype.isLoaded=function(){return this.loaded},s.prototype.getLoadedProgress=function(){for(var t=this.values.length,e=0;this.dataPoints[e];)e++;return Math.round(e/t*100)},s.prototype.getLabel=function(){return this.graph.filterLabel},s.prototype.getColumn=function(){return this.column},s.prototype.getSelectedValue=function(){return void 0===this.index?void 0:this.values[this.index]},s.prototype.getValues=function(){return this.values},s.prototype.getValue=function(t){if(t>=this.values.length)throw"Error: index out of range";return this.values[t]},s.prototype._getDataPoints=function(t){if(void 0===t&&(t=this.index),void 0===t)return[];var e;if(this.dataPoints[t])e=this.dataPoints[t];else{var i={};i.column=this.column,i.value=this.values[t];var s=new o(this.data,{filter:function(t){return t[i.column]==i.value}}).get();e=this.graph._getDataPoints(s),this.dataPoints[t]=e}return e},s.prototype.setOnLoadCallback=function(t){this.onLoadCallback=t},s.prototype.selectValue=function(t){if(t>=this.values.length)throw"Error: index out of range";this.index=t,this.value=this.values[t]},s.prototype.loadInBackground=function(t){void 0===t&&(t=0);var e=this.graph.frame;if(t0&&(t--,this.setIndex(t))},s.prototype.next=function(){var t=this.getIndex();t0?this.setIndex(0):this.index=void 0},s.prototype.setIndex=function(t){if(!(ts&&(s=0),s>this.values.length-1&&(s=this.values.length-1),s},s.prototype.indexToLeft=function(t){var e=parseFloat(this.frame.bar.style.width)-this.frame.slide.clientWidth-10,i=t/(this.values.length-1)*e,s=i+3;return s},s.prototype._onMouseMove=function(t){var e=t.clientX-this.startClientX,i=this.startSlideX+e,s=this.leftToIndex(i);this.setIndex(s),o.preventDefault()},s.prototype._onMouseUp=function(){this.frame.style.cursor="auto",o.removeEventListener(document,"mousemove",this.onmousemove),o.removeEventListener(document,"mouseup",this.onmouseup),o.preventDefault()},t.exports=s},function(t){function e(t,e,i,s){this._start=0,this._end=0,this._step=1,this.prettyStep=!0,this.precision=5,this._current=0,this.setRange(t,e,i,s)}e.prototype.setRange=function(t,e,i,s){this._start=t?t:0,this._end=e?e:0,this.setStep(i,s)},e.prototype.setStep=function(t,i){void 0===t||0>=t||(void 0!==i&&(this.prettyStep=i),this._step=this.prettyStep===!0?e.calculatePrettyStep(t):t)},e.calculatePrettyStep=function(t){var e=function(t){return Math.log(t)/Math.LN10},i=Math.pow(10,Math.round(e(t))),s=2*Math.pow(10,Math.round(e(t/2))),o=5*Math.pow(10,Math.round(e(t/5))),n=i;return Math.abs(s-t)<=Math.abs(n-t)&&(n=s),Math.abs(o-t)<=Math.abs(n-t)&&(n=o),0>=n&&(n=1),n},e.prototype.getCurrent=function(){return parseFloat(this._current.toPrecision(this.precision))},e.prototype.getStep=function(){return this._step},e.prototype.start=function(){this._current=this._start-this._start%this._step},e.prototype.next=function(){this._current+=this._step},e.prototype.end=function(){return this._current>this._end},t.exports=e},function(t,e,i){function s(t,e,i,r){if(!(this instanceof s))throw new SyntaxError("Constructor must be called with the new operator");if(!(Array.isArray(i)||i instanceof n)&&i instanceof Object){var h=r;r=i,i=h}var u=this;this.defaultOptions={start:null,end:null,autoResize:!0,orientation:"bottom",width:null,height:null,maxHeight:null,minHeight:null},this.options=o.deepExtend({},this.defaultOptions),this._create(t),this.components=[],this.body={dom:this.dom,domProps:this.props,emitter:{on:this.on.bind(this),off:this.off.bind(this),emit:this.emit.bind(this)},hiddenDates:[],util:{snap:null,toScreen:u._toScreen.bind(u),toGlobalScreen:u._toGlobalScreen.bind(u),toTime:u._toTime.bind(u),toGlobalTime:u._toGlobalTime.bind(u)}},this.range=new a(this.body),this.components.push(this.range),this.body.range=this.range,this.timeAxis=new d(this.body),this.components.push(this.timeAxis),this.body.util.snap=this.timeAxis.snap.bind(this.timeAxis),this.currentTime=new l(this.body),this.components.push(this.currentTime),this.customTime=new c(this.body),this.components.push(this.customTime),this.itemSet=new p(this.body),this.components.push(this.itemSet),this.itemsData=null,this.groupsData=null,r&&this.setOptions(r),i&&this.setGroups(i),e?this.setItems(e):this.redraw()}var o=(i(56),i(45),i(1)),n=i(3),r=i(4),a=i(17),h=i(46),d=i(30),l=i(21),c=i(22),p=i(27);s.prototype=new h,s.prototype.setItems=function(t){var e,i=null==this.itemsData;if(e=t?t instanceof n||t instanceof r?t:new n(t,{type:{start:"Date",end:"Date"}}):null,this.itemsData=e,this.itemSet&&this.itemSet.setItems(e),i)if(void 0!=this.options.start||void 0!=this.options.end){if(void 0==this.options.start||void 0==this.options.end)var s=this._getDataRange();var o=void 0!=this.options.start?this.options.start:s.start,a=void 0!=this.options.end?this.options.end:s.end;this.setWindow(o,a,{animate:!1})}else this.fit({animate:!1})},s.prototype.setGroups=function(t){var e;e=t?t instanceof n||t instanceof r?t:new n(t):null,this.groupsData=e,this.itemSet.setGroups(e)},s.prototype.setSelection=function(t,e){this.itemSet&&this.itemSet.setSelection(t),e&&e.focus&&this.focus(t,e)},s.prototype.getSelection=function(){return this.itemSet&&this.itemSet.getSelection()||[]},s.prototype.focus=function(t,e){if(this.itemsData&&void 0!=t){var i=Array.isArray(t)?t:[t],s=this.itemsData.getDataSet().get(i,{type:{start:"Date",end:"Date"}}),o=null,n=null;if(s.forEach(function(t){var e=t.start.valueOf(),i="end"in t?t.end.valueOf():t.start.valueOf();(null===o||o>e)&&(o=e),(null===n||i>n)&&(n=i)}),null!==o&&null!==n){var r=(o+n)/2,a=Math.max(this.range.end-this.range.start,1.1*(n-o)),h=e&&void 0!==e.animate?e.animate:!0;this.range.setRange(r-a/2,r+a/2,h)}}},s.prototype.getItemRange=function(){var t=this.itemsData.getDataSet(),e=null,i=null;if(t){var s=t.min("start");e=s?o.convert(s.start,"Date").valueOf():null;var n=t.max("start");n&&(i=o.convert(n.start,"Date").valueOf());var r=t.max("end");r&&(i=null==i?o.convert(r.end,"Date").valueOf():Math.max(i,o.convert(r.end,"Date").valueOf()))}return{min:null!=e?new Date(e):null,max:null!=i?new Date(i):null}},t.exports=s},function(t,e,i){function s(t,e,i,s){if(!(Array.isArray(i)||i instanceof n)&&i instanceof Object){var r=s;s=i,i=r}var h=this;this.defaultOptions={start:null,end:null,autoResize:!0,orientation:"bottom",width:null,height:null,maxHeight:null,minHeight:null},this.options=o.deepExtend({},this.defaultOptions),this._create(t),this.components=[],this.body={dom:this.dom,domProps:this.props,emitter:{on:this.on.bind(this),off:this.off.bind(this),emit:this.emit.bind(this)},hiddenDates:[],util:{snap:null,toScreen:h._toScreen.bind(h),toGlobalScreen:h._toGlobalScreen.bind(h),toTime:h._toTime.bind(h),toGlobalTime:h._toGlobalTime.bind(h)}},this.range=new a(this.body),this.components.push(this.range),this.body.range=this.range,this.timeAxis=new d(this.body),this.components.push(this.timeAxis),this.body.util.snap=this.timeAxis.snap.bind(this.timeAxis),this.currentTime=new l(this.body),this.components.push(this.currentTime),this.customTime=new c(this.body),this.components.push(this.customTime),this.linegraph=new p(this.body),this.components.push(this.linegraph),this.itemsData=null,this.groupsData=null,s&&this.setOptions(s),i&&this.setGroups(i),e?this.setItems(e):this.redraw()}var o=(i(56),i(45),i(1)),n=i(3),r=i(4),a=i(17),h=i(46),d=i(30),l=i(21),c=i(22),p=i(29);s.prototype=new h,s.prototype.setItems=function(t){var e,i=null==this.itemsData;if(e=t?t instanceof n||t instanceof r?t:new n(t,{type:{start:"Date",end:"Date"}}):null,this.itemsData=e,this.linegraph&&this.linegraph.setItems(e),i)if(void 0!=this.options.start||void 0!=this.options.end){var s=void 0!=this.options.start?this.options.start:null,o=void 0!=this.options.end?this.options.end:null;this.setWindow(s,o,{animate:!1})}else this.fit({animate:!1})},s.prototype.setGroups=function(t){var e;e=t?t instanceof n||t instanceof r?t:new n(t):null,this.groupsData=e,this.linegraph.setGroups(e)},s.prototype.getLegend=function(t,e,i){return void 0===e&&(e=15),void 0===i&&(i=15),void 0!==this.linegraph.groups[t]?this.linegraph.groups[t].getLegend(e,i):"cannot find group:"+t},s.prototype.isGroupVisible=function(t){return void 0!==this.linegraph.groups[t]?this.linegraph.groups[t].visible&&(void 0===this.linegraph.options.groups.visibility[t]||1==this.linegraph.options.groups.visibility[t]):!1},s.prototype.getItemRange=function(){var t=null,e=null;for(var i in this.linegraph.groups)if(this.linegraph.groups.hasOwnProperty(i)&&1==this.linegraph.groups[i].visible)for(var s=0;sr?r:t,e=null==e?r:r>e?r:e}return{min:null!=t?new Date(t):null,max:null!=e?new Date(e):null}},t.exports=s},function(t,e,i){var s=i(44);e.convertHiddenOptions=function(t,e){if(t.hiddenDates=[],e&&1==Array.isArray(e)){for(var i=0;i=4*a){var p=0,u=n.clone();switch(i[h].repeat){case"daily":d.day()!=l.day()&&(p=1),d.dayOfYear(o.dayOfYear()),d.year(o.year()),d.subtract(7,"days"),l.dayOfYear(o.dayOfYear()),l.year(o.year()),l.subtract(7-p,"days"),u.add(1,"weeks");break;case"weekly":var m=l.diff(d,"days"),f=d.day();d.date(o.date()),d.month(o.month()),d.year(o.year()),l=d.clone(),d.day(f),l.day(f),l.add(m,"days"),d.subtract(1,"weeks"),l.subtract(1,"weeks"),u.add(1,"weeks");break;case"monthly":d.month()!=l.month()&&(p=1),d.month(o.month()),d.year(o.year()),d.subtract(1,"months"),l.month(o.month()),l.year(o.year()),l.subtract(1,"months"),l.add(p,"months"),u.add(1,"months");break;case"yearly":d.year()!=l.year()&&(p=1),d.year(o.year()),d.subtract(1,"years"),l.year(o.year()),l.subtract(1,"years"),l.add(p,"years"),u.add(1,"years");break;default:return void console.log("Wrong repeat format, allowed are: daily, weekly, monthly, yearly. Given:",i[h].repeat)}for(;u>d;)switch(t.hiddenDates.push({start:d.valueOf(),end:l.valueOf()}),i[h].repeat){case"daily":d.add(1,"days"),l.add(1,"days");break;case"weekly":d.add(1,"weeks"),l.add(1,"weeks");break;case"monthly":d.add(1,"months"),l.add(1,"months");break;case"yearly":d.add(1,"y"),l.add(1,"y");break;default:return void console.log("Wrong repeat format, allowed are: daily, weekly, monthly, yearly. Given:",i[h].repeat)}t.hiddenDates.push({start:d.valueOf(),end:l.valueOf()})}}e.removeDuplicates(t);var g=e.isHidden(t.range.start,t.hiddenDates),v=e.isHidden(t.range.end,t.hiddenDates),y=t.range.start,b=t.range.end;1==g.hidden&&(y=1==t.range.startToFront?g.startDate-1:g.endDate+1),1==v.hidden&&(b=1==t.range.endToFront?v.startDate-1:v.endDate+1),(1==g.hidden||1==v.hidden)&&t.range._applyRange(y,b)}},e.removeDuplicates=function(t){for(var e=t.hiddenDates,i=[],s=0;s=e[s].start&&e[o].end<=e[s].end?e[o].remove=!0:e[o].start>=e[s].start&&e[o].start<=e[s].end?(e[s].end=e[o].end,e[o].remove=!0):e[o].end>=e[s].start&&e[o].end<=e[s].end&&(e[s].start=e[o].start,e[o].remove=!0));for(var s=0;s=r&&a>o){i=!0;break}}if(1==i&&o=e&&i>r&&(s+=r-n)}return s},e.correctTimeForHidden=function(t,i,o){return o=s(o).toDate().valueOf(),o-=e.getHiddenDurationBefore(t,i,o)},e.getHiddenDurationBefore=function(t,e,i){var o=0;i=s(i).toDate().valueOf();for(var n=0;n=e.start&&a=a&&(o+=a-r)}return o},e.getAccumulatedHiddenDuration=function(t,e,i){for(var s=0,o=0,n=e.start,r=0;r=e.start&&h=i)break;s+=h-a}}return s},e.snapAwayFromHidden=function(t,i,s,o){var n=e.isHidden(i,t);return 1==n.hidden?0>s?1==o?n.startDate-(n.endDate-i)-1:n.startDate-1:1==o?n.endDate+(i-n.startDate)+1:n.endDate+1:i},e.isHidden=function(t,e){for(var i=0;i=s&&o>t)return{hidden:!0,startDate:s,endDate:o}}return{hidden:!1,startDate:s,endDate:o}}},function(t){function e(t,e,i,s,o,n){this.current=0,this.autoScale=!0,this.stepIndex=0,this.step=1,this.scale=1,this.marginStart,this.marginEnd,this.deadSpace=0,this.majorSteps=[1,2,5,10],this.minorSteps=[.25,.5,1,2],this.alignZeros=n,this.setRange(t,e,i,s,o)}e.prototype.setRange=function(t,e,i,s,o){this._start=void 0===o.min?t:o.min,this._end=void 0===o.max?e:o.max,this._start==this._end&&(this._start-=.75,this._end+=1),1==this.autoScale&&this.setMinimumStep(i,s),this.setFirst(o)},e.prototype.setMinimumStep=function(t,e){var i=this._end-this._start,s=1.2*i,o=t*(s/e),n=Math.round(Math.log(s)/Math.LN10),r=-1,a=Math.pow(10,n),h=0;0>n&&(h=n);for(var d=!1,l=h;Math.abs(l)<=Math.abs(n);l++){a=Math.pow(10,l);for(var c=0;c=o){d=!0,r=c;break}}if(1==d)break}this.stepIndex=r,this.scale=a,this.step=a*this.minorSteps[r]},e.prototype.setFirst=function(t){void 0===t&&(t={});var e=void 0===t.min?this._start-2*this.scale*this.minorSteps[this.stepIndex]:t.min,i=void 0===t.max?this._end+this.scale*this.minorSteps[this.stepIndex]:t.max;this.marginEnd=void 0===t.max?this.roundToMinor(i):t.max,this.marginStart=void 0===t.min?this.roundToMinor(e):t.min,1==this.alignZeros&&(this.marginEnd-this.marginStart)%this.step!=0&&(this.marginEnd+=this.marginEnd%this.step),this.deadSpace=this.roundToMinor(i)-i+this.roundToMinor(e)-e,this.marginRange=this.marginEnd-this.marginStart,this.current=this.marginEnd},e.prototype.roundToMinor=function(t){var e=t-t%(this.scale*this.minorSteps[this.stepIndex]);return t%(this.scale*this.minorSteps[this.stepIndex])>.5*this.scale*this.minorSteps[this.stepIndex]?e+this.scale*this.minorSteps[this.stepIndex]:e},e.prototype.hasNext=function(){return this.current>=this.marginStart},e.prototype.next=function(){var t=this.current;this.current-=this.step,this.current==t&&(this.current=this._end)},e.prototype.previous=function(){this.current+=this.step,this.marginEnd+=this.step,this.marginRange=this.marginEnd-this.marginStart},e.prototype.getCurrent=function(t){var e=Math.abs(this.current)0;s--){if("0"!=i[s]){if("."==i[s]||","==i[s]){i=i.slice(0,s);break}break}i=i.slice(0,s)}}else{var o="",n=i.indexOf("e");if(-1!=n&&(o=i.slice(n),i=i.slice(0,n)),n=Math.max(i.indexOf(","),i.indexOf(".")),-1===n?(0!==t&&(i+="."),n=i.length+t):0!==t&&(n+=t+1),n>i.length)for(var r=n-i.length;r>0;r--)i+="0";else i=i.slice(0,n);i+=o}return i},e.prototype.snap=function(){},e.prototype.isMajor=function(){return this.current%(this.scale*this.majorSteps[this.stepIndex])==0},t.exports=e},function(t,e,i){function s(t,e){var i=a().hours(0).minutes(0).seconds(0).milliseconds(0);this.start=i.clone().add(-3,"days").valueOf(),this.end=i.clone().add(4,"days").valueOf(),this.body=t,this.deltaDifference=0,this.scaleOffset=0,this.startToFront=!1,this.endToFront=!0,this.defaultOptions={start:null,end:null,direction:"horizontal",moveable:!0,zoomable:!0,min:null,max:null,zoomMin:10,zoomMax:31536e10},this.options=r.extend({},this.defaultOptions),this.props={touch:{}},this.animateTimer=null,this.body.emitter.on("panstart",this._onDragStart.bind(this)),this.body.emitter.on("panmove",this._onDrag.bind(this)),this.body.emitter.on("panend",this._onDragEnd.bind(this)),this.body.emitter.on("press",this._onHold.bind(this)),this.body.emitter.on("mousewheel",this._onMouseWheel.bind(this)),this.body.emitter.on("touch",this._onTouch.bind(this)),this.body.emitter.on("pinch",this._onPinch.bind(this)),this.setOptions(e)}function o(t){if("horizontal"!=t&&"vertical"!=t)throw new TypeError('Unknown direction "'+t+'". Choose "horizontal" or "vertical".')}function n(t,e){return{x:t.x-r.getAbsoluteLeft(e),y:t.y-r.getAbsoluteTop(e)}}var r=i(1),a=(i(47),i(44)),h=i(20),d=i(15);s.prototype=new h,s.prototype.setOptions=function(t){if(t){var e=["direction","min","max","zoomMin","zoomMax","moveable","zoomable","activate","hiddenDates"];r.selectiveExtend(e,this.options,t),("start"in t||"end"in t)&&this.setRange(t.start,t.end)}},s.prototype.setRange=function(t,e,i,s){s!==!0&&(s=!1);var o=void 0!=t?r.convert(t,"Date").valueOf():null,n=void 0!=e?r.convert(e,"Date").valueOf():null;if(this._cancelAnimation(),i){var a=this,h=this.start,l=this.end,c="number"==typeof i?i:500,p=(new Date).valueOf(),u=!1,m=function(){if(!a.props.touch.dragging){var t=(new Date).valueOf(),e=t-p,i=e>c,g=i||null===o?o:r.easeInOutQuad(e,h,o,c),v=i||null===n?n:r.easeInOutQuad(e,l,n,c);f=a._applyRange(g,v),d.updateHiddenDates(a.body,a.options.hiddenDates),u=u||f,f&&a.body.emitter.emit("rangechange",{start:new Date(a.start),end:new Date(a.end),byUser:s}),i?u&&a.body.emitter.emit("rangechanged",{start:new Date(a.start),end:new Date(a.end),byUser:s}):a.animateTimer=setTimeout(m,20)}};return m()}var f=this._applyRange(o,n);if(d.updateHiddenDates(this.body,this.options.hiddenDates),f){var g={start:new Date(this.start),end:new Date(this.end),byUser:s};this.body.emitter.emit("rangechange",g),this.body.emitter.emit("rangechanged",g)}},s.prototype._cancelAnimation=function(){this.animateTimer&&(clearTimeout(this.animateTimer),this.animateTimer=null)},s.prototype._applyRange=function(t,e){var i,s=null!=t?r.convert(t,"Date").valueOf():this.start,o=null!=e?r.convert(e,"Date").valueOf():this.end,n=null!=this.options.max?r.convert(this.options.max,"Date").valueOf():null,a=null!=this.options.min?r.convert(this.options.min,"Date").valueOf():null;if(isNaN(s)||null===s)throw new Error('Invalid start "'+t+'"');if(isNaN(o)||null===o)throw new Error('Invalid end "'+e+'"');if(s>o&&(o=s),null!==a&&a>s&&(i=a-s,s+=i,o+=i,null!=n&&o>n&&(o=n)),null!==n&&o>n&&(i=o-n,s-=i,o-=i,null!=a&&a>s&&(s=a)),null!==this.options.zoomMin){var h=parseFloat(this.options.zoomMin);0>h&&(h=0),h>o-s&&(this.end-this.start===h?(s=this.start,o=this.end):(i=h-(o-s),s-=i/2,o+=i/2))}if(null!==this.options.zoomMax){var d=parseFloat(this.options.zoomMax);0>d&&(d=0),o-s>d&&(this.end-this.start===d?(s=this.start,o=this.end):(i=o-s-d,s+=i/2,o-=i/2))}var l=this.start!=s||this.end!=o;return s>=this.start&&s<=this.end||o>=this.start&&o<=this.end||this.start>=s&&this.start<=o||this.end>=s&&this.end<=o||this.body.emitter.emit("checkRangedItems"),this.start=s,this.end=o,l},s.prototype.getRange=function(){return{start:this.start,end:this.end}},s.prototype.conversion=function(t,e){return s.conversion(this.start,this.end,t,e)},s.conversion=function(t,e,i,s){return void 0===s&&(s=0),0!=i&&e-t!=0?{offset:t,scale:i/(e-t-s)}:{offset:0,scale:1}},s.prototype._onDragStart=function(t){this.deltaDifference=0,this.previousDelta=0,this.options.moveable&&this.props.touch.allowDragging&&(this.props.touch.start=this.start,this.props.touch.end=this.end,this.props.touch.dragging=!0,this.body.dom.root&&(this.body.dom.root.style.cursor="move"),t.preventDefault())},s.prototype._onDrag=function(t){if(this.options.moveable&&this.props.touch.allowDragging){var e=this.options.direction;o(e);var i="horizontal"==e?t.deltaX:t.deltaY;i-=this.deltaDifference;var s=this.props.touch.end-this.props.touch.start,n=d.getHiddenDurationBetween(this.body.hiddenDates,this.start,this.end);s-=n;var r="horizontal"==e?this.body.domProps.center.width:this.body.domProps.center.height,a=-i/r*s,h=this.props.touch.start+a,l=this.props.touch.end+a,c=d.snapAwayFromHidden(this.body.hiddenDates,h,this.previousDelta-i,!0),p=d.snapAwayFromHidden(this.body.hiddenDates,l,this.previousDelta-i,!0);if(c!=h||p!=l)return this.deltaDifference+=i,this.props.touch.start=c,this.props.touch.end=p,void this._onDrag(t);this.previousDelta=i,this._applyRange(h,l),this.body.emitter.emit("rangechange",{start:new Date(this.start),end:new Date(this.end),byUser:!0}),t.preventDefault()}},s.prototype._onDragEnd=function(){this.options.moveable&&this.props.touch.allowDragging&&(this.props.touch.dragging=!1,this.body.dom.root&&(this.body.dom.root.style.cursor="auto"),this.body.emitter.emit("rangechanged",{start:new Date(this.start),end:new Date(this.end),byUser:!0}))},s.prototype._onMouseWheel=function(t){if(this.options.zoomable&&this.options.moveable){var e=0;if(t.wheelDelta?e=t.wheelDelta/120:t.detail&&(e=-t.detail/3),e){var i;i=0>e?1-e/5:1/(1+e/5);var s=n({x:t.pageX,y:t.pageY},this.body.dom.center),o=this._pointerToDate(s);this.zoom(i,o,e)}t.preventDefault()}},s.prototype._onTouch=function(){this.props.touch.start=this.start,this.props.touch.end=this.end,this.props.touch.allowDragging=!0,this.props.touch.center=null,this.scaleOffset=0,this.deltaDifference=0},s.prototype._onHold=function(){this.props.touch.allowDragging=!1},s.prototype._onPinch=function(t){if(this.options.zoomable&&this.options.moveable){this.props.touch.allowDragging=!1,this.props.touch.center||(this.props.touch.center=n(t.center,this.body.dom.center));var e=1/(t.scale+this.scaleOffset),i=this._pointerToDate(this.props.touch.center),s=d.getHiddenDurationBetween(this.body.hiddenDates,this.start,this.end),o=d.getHiddenDurationBefore(this.body.hiddenDates,this,i),r=s-o,a=i-o+(this.props.touch.start-(i-o))*e,h=i+r+(this.props.touch.end-(i+r))*e;this.startToFront=0>=1-e,this.endToFront=0>=e-1;var l=d.snapAwayFromHidden(this.body.hiddenDates,a,1-e,!0),c=d.snapAwayFromHidden(this.body.hiddenDates,h,e-1,!0);(l!=a||c!=h)&&(this.props.touch.start=l,this.props.touch.end=c,this.scaleOffset=1-t.scale,a=l,h=c),this.setRange(a,h,!1,!0),this.startToFront=!1,this.endToFront=!0,t.preventDefault()}},s.prototype._pointerToDate=function(t){var e,i=this.options.direction;if(o(i),"horizontal"==i)return this.body.util.toTime(t.x).valueOf();var s=this.body.domProps.center.height;return e=this.conversion(s),t.y/e.scale+e.offset},s.prototype.zoom=function(t,e,i){null==e&&(e=(this.start+this.end)/2);var s=d.getHiddenDurationBetween(this.body.hiddenDates,this.start,this.end),o=d.getHiddenDurationBefore(this.body.hiddenDates,this,e),n=s-o,r=e-o+(this.start-(e-o))*t,a=e+n+(this.end-(e+n))*t;this.startToFront=i>0?!1:!0,this.endToFront=-i>0?!1:!0;var h=d.snapAwayFromHidden(this.body.hiddenDates,r,i,!0),l=d.snapAwayFromHidden(this.body.hiddenDates,a,-i,!0);(h!=r||l!=a)&&(r=h,a=l),this.setRange(r,a,!1,!0),this.startToFront=!1,this.endToFront=!0},s.prototype.move=function(t){var e=this.end-this.start,i=this.start+e*t,s=this.end+e*t;this.start=i,this.end=s},s.prototype.moveTo=function(t){var e=(this.start+this.end)/2,i=e-t,s=this.start-i,o=this.end-i;this.setRange(s,o)},t.exports=s},function(t,e){var i=.001;e.orderByStart=function(t){t.sort(function(t,e){return t.data.start-e.data.start})},e.orderByEnd=function(t){t.sort(function(t,e){var i="end"in t.data?t.data.end:t.data.start,s="end"in e.data?e.data.end:e.data.start;return i-s})},e.stack=function(t,i,s){var o,n;if(s)for(o=0,n=t.length;n>o;o++)t[o].top=null;for(o=0,n=t.length;n>o;o++){var r=t[o];if(r.stack&&null===r.top){r.top=i.axis;do{for(var a=null,h=0,d=t.length;d>h;h++){var l=t[h];if(null!==l.top&&l!==r&&l.stack&&e.collision(r,l,i.item)){a=l;break}}null!=a&&(r.top=a.top+a.height+i.item.vertical)}while(a)}}},e.nostack=function(t,e,i){var s,o,n;for(s=0,o=t.length;o>s;s++)if(void 0!==t[s].data.subgroup){n=e.axis;for(var r in i)i.hasOwnProperty(r)&&1==i[r].visible&&i[r].indexe.left&&t.top-s.vertical+ie.top}},function(t,e,i){function s(t,e,i,o){this.current=new Date,this._start=new Date,this._end=new Date,this.autoScale=!0,this.scale="day",this.step=1,this.setRange(t,e,i),this.switchedDay=!1,this.switchedMonth=!1,this.switchedYear=!1,this.hiddenDates=o,void 0===o&&(this.hiddenDates=[]),this.format=s.FORMAT}var o=i(44),n=i(15),r=i(1);s.FORMAT={minorLabels:{millisecond:"SSS",second:"s",minute:"HH:mm",hour:"HH:mm",weekday:"ddd D",day:"D",month:"MMM",year:"YYYY"},majorLabels:{millisecond:"HH:mm:ss",second:"D MMMM HH:mm",minute:"ddd D MMMM",hour:"ddd D MMMM",weekday:"MMMM YYYY",day:"MMMM YYYY",month:"YYYY",year:""}},s.prototype.setFormat=function(t){var e=r.deepExtend({},s.FORMAT);this.format=r.deepExtend(e,t)},s.prototype.setRange=function(t,e,i){if(!(t instanceof Date&&e instanceof Date))throw"No legal start or end date in method setRange";this._start=void 0!=t?new Date(t.valueOf()):new Date,this._end=void 0!=e?new Date(e.valueOf()):new Date,this.autoScale&&this.setMinimumStep(i)},s.prototype.first=function(){this.current=new Date(this._start.valueOf()),this.roundToMinor()},s.prototype.roundToMinor=function(){switch(this.scale){case"year":this.current.setFullYear(this.step*Math.floor(this.current.getFullYear()/this.step)),this.current.setMonth(0);case"month":this.current.setDate(1);case"day":case"weekday":this.current.setHours(0);case"hour":this.current.setMinutes(0);case"minute":this.current.setSeconds(0);case"second":this.current.setMilliseconds(0)}if(1!=this.step)switch(this.scale){case"millisecond":this.current.setMilliseconds(this.current.getMilliseconds()-this.current.getMilliseconds()%this.step);break;case"second":this.current.setSeconds(this.current.getSeconds()-this.current.getSeconds()%this.step);break;case"minute":this.current.setMinutes(this.current.getMinutes()-this.current.getMinutes()%this.step);break;case"hour":this.current.setHours(this.current.getHours()-this.current.getHours()%this.step);break;case"weekday":case"day":this.current.setDate(this.current.getDate()-1-(this.current.getDate()-1)%this.step+1);break;case"month":this.current.setMonth(this.current.getMonth()-this.current.getMonth()%this.step);break;case"year":this.current.setFullYear(this.current.getFullYear()-this.current.getFullYear()%this.step)}},s.prototype.hasNext=function(){return this.current.valueOf()<=this._end.valueOf()},s.prototype.next=function(){var t=this.current.valueOf();if(this.current.getMonth()<6)switch(this.scale){case"millisecond":this.current=new Date(this.current.valueOf()+this.step);break;case"second":this.current=new Date(this.current.valueOf()+1e3*this.step);break;case"minute":this.current=new Date(this.current.valueOf()+1e3*this.step*60);break;case"hour":this.current=new Date(this.current.valueOf()+1e3*this.step*60*60);var e=this.current.getHours();this.current.setHours(e-e%this.step);break;case"weekday":case"day":this.current.setDate(this.current.getDate()+this.step);break;case"month":this.current.setMonth(this.current.getMonth()+this.step);break;case"year":this.current.setFullYear(this.current.getFullYear()+this.step)}else switch(this.scale){case"millisecond":this.current=new Date(this.current.valueOf()+this.step);break;case"second":this.current.setSeconds(this.current.getSeconds()+this.step);break;case"minute":this.current.setMinutes(this.current.getMinutes()+this.step); +break;case"hour":this.current.setHours(this.current.getHours()+this.step);break;case"weekday":case"day":this.current.setDate(this.current.getDate()+this.step);break;case"month":this.current.setMonth(this.current.getMonth()+this.step);break;case"year":this.current.setFullYear(this.current.getFullYear()+this.step)}if(1!=this.step)switch(this.scale){case"millisecond":this.current.getMilliseconds()0&&(this.step=e),this.autoScale=!1},s.prototype.setAutoScale=function(t){this.autoScale=t},s.prototype.setMinimumStep=function(t){if(void 0!=t){var e=31104e6,i=2592e6,s=864e5,o=36e5,n=6e4,r=1e3,a=1;1e3*e>t&&(this.scale="year",this.step=1e3),500*e>t&&(this.scale="year",this.step=500),100*e>t&&(this.scale="year",this.step=100),50*e>t&&(this.scale="year",this.step=50),10*e>t&&(this.scale="year",this.step=10),5*e>t&&(this.scale="year",this.step=5),e>t&&(this.scale="year",this.step=1),3*i>t&&(this.scale="month",this.step=3),i>t&&(this.scale="month",this.step=1),5*s>t&&(this.scale="day",this.step=5),2*s>t&&(this.scale="day",this.step=2),s>t&&(this.scale="day",this.step=1),s/2>t&&(this.scale="weekday",this.step=1),4*o>t&&(this.scale="hour",this.step=4),o>t&&(this.scale="hour",this.step=1),15*n>t&&(this.scale="minute",this.step=15),10*n>t&&(this.scale="minute",this.step=10),5*n>t&&(this.scale="minute",this.step=5),n>t&&(this.scale="minute",this.step=1),15*r>t&&(this.scale="second",this.step=15),10*r>t&&(this.scale="second",this.step=10),5*r>t&&(this.scale="second",this.step=5),r>t&&(this.scale="second",this.step=1),200*a>t&&(this.scale="millisecond",this.step=200),100*a>t&&(this.scale="millisecond",this.step=100),50*a>t&&(this.scale="millisecond",this.step=50),10*a>t&&(this.scale="millisecond",this.step=10),5*a>t&&(this.scale="millisecond",this.step=5),a>t&&(this.scale="millisecond",this.step=1)}},s.prototype.snap=function(t){var e=new Date(t.valueOf());if("year"==this.scale){var i=e.getFullYear()+Math.round(e.getMonth()/12);e.setFullYear(Math.round(i/this.step)*this.step),e.setMonth(0),e.setDate(0),e.setHours(0),e.setMinutes(0),e.setSeconds(0),e.setMilliseconds(0)}else if("month"==this.scale)e.getDate()>15?(e.setDate(1),e.setMonth(e.getMonth()+1)):e.setDate(1),e.setHours(0),e.setMinutes(0),e.setSeconds(0),e.setMilliseconds(0);else if("day"==this.scale){switch(this.step){case 5:case 2:e.setHours(24*Math.round(e.getHours()/24));break;default:e.setHours(12*Math.round(e.getHours()/12))}e.setMinutes(0),e.setSeconds(0),e.setMilliseconds(0)}else if("weekday"==this.scale){switch(this.step){case 5:case 2:e.setHours(12*Math.round(e.getHours()/12));break;default:e.setHours(6*Math.round(e.getHours()/6))}e.setMinutes(0),e.setSeconds(0),e.setMilliseconds(0)}else if("hour"==this.scale){switch(this.step){case 4:e.setMinutes(60*Math.round(e.getMinutes()/60));break;default:e.setMinutes(30*Math.round(e.getMinutes()/30))}e.setSeconds(0),e.setMilliseconds(0)}else if("minute"==this.scale){switch(this.step){case 15:case 10:e.setMinutes(5*Math.round(e.getMinutes()/5)),e.setSeconds(0);break;case 5:e.setSeconds(60*Math.round(e.getSeconds()/60));break;default:e.setSeconds(30*Math.round(e.getSeconds()/30))}e.setMilliseconds(0)}else if("second"==this.scale)switch(this.step){case 15:case 10:e.setSeconds(5*Math.round(e.getSeconds()/5)),e.setMilliseconds(0);break;case 5:e.setMilliseconds(1e3*Math.round(e.getMilliseconds()/1e3));break;default:e.setMilliseconds(500*Math.round(e.getMilliseconds()/500))}else if("millisecond"==this.scale){var s=this.step>5?this.step/2:1;e.setMilliseconds(Math.round(e.getMilliseconds()/s)*s)}return e},s.prototype.isMajor=function(){if(1==this.switchedYear)switch(this.switchedYear=!1,this.scale){case"year":case"month":case"weekday":case"day":case"hour":case"minute":case"second":case"millisecond":return!0;default:return!1}else if(1==this.switchedMonth)switch(this.switchedMonth=!1,this.scale){case"weekday":case"day":case"hour":case"minute":case"second":case"millisecond":return!0;default:return!1}else if(1==this.switchedDay)switch(this.switchedDay=!1,this.scale){case"millisecond":case"second":case"minute":case"hour":return!0;default:return!1}switch(this.scale){case"millisecond":return 0==this.current.getMilliseconds();case"second":return 0==this.current.getSeconds();case"minute":return 0==this.current.getHours()&&0==this.current.getMinutes();case"hour":return 0==this.current.getHours();case"weekday":case"day":return 1==this.current.getDate();case"month":return 0==this.current.getMonth();case"year":return!1;default:return!1}},s.prototype.getLabelMinor=function(t){void 0==t&&(t=this.current);var e=this.format.minorLabels[this.scale];return e&&e.length>0?o(t).format(e):""},s.prototype.getLabelMajor=function(t){void 0==t&&(t=this.current);var e=this.format.majorLabels[this.scale];return e&&e.length>0?o(t).format(e):""},s.prototype.getClassName=function(){function t(t){return t/h%2==0?" even":" odd"}function e(t){return t.isSame(new Date,"day")?" today":t.isSame(o().add(1,"day"),"day")?" tomorrow":t.isSame(o().add(-1,"day"),"day")?" yesterday":""}function i(t){return t.isSame(new Date,"week")?" current-week":""}function s(t){return t.isSame(new Date,"month")?" current-month":""}function n(t){return t.isSame(new Date,"year")?" current-year":""}var r=o(this.current),a=r.locale?r.locale("en"):r.lang("en"),h=this.step;switch(this.scale){case"millisecond":return t(a.milliseconds()).trim();case"second":return t(a.seconds()).trim();case"minute":return t(a.minutes()).trim();case"hour":var d=a.hours();return 4==this.step&&(d=d+"-"+(d+4)),d+"h"+e(a)+t(a.hours());case"weekday":return a.format("dddd").toLowerCase()+e(a)+i(a)+t(a.date());case"day":var l=a.date(),c=a.format("MMMM").toLowerCase();return"day"+l+" "+c+s(a)+t(l-1);case"month":return a.format("MMMM").toLowerCase()+s(a)+t(a.month());case"year":var p=a.year();return"year"+p+n(a)+t(p);default:return""}},t.exports=s},function(t){function e(){this.options=null,this.props=null}e.prototype.setOptions=function(t){t&&util.extend(this.options,t)},e.prototype.redraw=function(){return!1},e.prototype.destroy=function(){},e.prototype._isResized=function(){var t=this.props._previousWidth!==this.props.width||this.props._previousHeight!==this.props.height;return this.props._previousWidth=this.props.width,this.props._previousHeight=this.props.height,t},t.exports=e},function(t,e,i){function s(t,e){this.body=t,this.defaultOptions={showCurrentTime:!0,locales:a,locale:"en"},this.options=o.extend({},this.defaultOptions),this.offset=0,this._create(),this.setOptions(e)}var o=i(1),n=i(20),r=i(44),a=i(48);s.prototype=new n,s.prototype._create=function(){var t=document.createElement("div");t.className="currenttime",t.style.position="absolute",t.style.top="0px",t.style.height="100%",this.bar=t},s.prototype.destroy=function(){this.options.showCurrentTime=!1,this.redraw(),this.body=null},s.prototype.setOptions=function(t){t&&o.selectiveExtend(["showCurrentTime","locale","locales"],this.options,t)},s.prototype.redraw=function(){if(this.options.showCurrentTime){var t=this.body.dom.backgroundVertical;this.bar.parentNode!=t&&(this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar),t.appendChild(this.bar),this.start());var e=new Date((new Date).valueOf()+this.offset),i=this.body.util.toScreen(e),s=this.options.locales[this.options.locale],o=s.current+" "+s.time+": "+r(e).format("dddd, MMMM Do YYYY, H:mm:ss");o=o.charAt(0).toUpperCase()+o.substring(1),this.bar.style.left=i+"px",this.bar.title=o}else this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar),this.stop();return!1},s.prototype.start=function(){function t(){e.stop();var i=e.body.range.conversion(e.body.domProps.center.width).scale,s=1/i/10;30>s&&(s=30),s>1e3&&(s=1e3),e.redraw(),e.currentTimeTimer=setTimeout(t,s)}var e=this;t()},s.prototype.stop=function(){void 0!==this.currentTimeTimer&&(clearTimeout(this.currentTimeTimer),delete this.currentTimeTimer)},s.prototype.setCurrentTime=function(t){var e=o.convert(t,"Date").valueOf(),i=(new Date).valueOf();this.offset=e-i,this.redraw()},s.prototype.getCurrentTime=function(){return new Date((new Date).valueOf()+this.offset)},t.exports=s},function(t,e,i){function s(t,e){this.body=t,this.defaultOptions={showCustomTime:!1,locales:h,locale:"en"},this.options=n.extend({},this.defaultOptions),this.customTime=new Date,this.eventParams={},this._create(),this.setOptions(e)}var o=i(45),n=i(1),r=i(20),a=i(44),h=i(48);s.prototype=new r,s.prototype.setOptions=function(t){t&&n.selectiveExtend(["showCustomTime","locale","locales"],this.options,t)},s.prototype._create=function(){var t=document.createElement("div");t.className="customtime",t.style.position="absolute",t.style.top="0px",t.style.height="100%",this.bar=t;var e=document.createElement("div");e.style.position="relative",e.style.top="0px",e.style.left="-10px",e.style.height="100%",e.style.width="20px",t.appendChild(e),this.hammer=new o(e),this.hammer.on("panstart",this._onDragStart.bind(this)),this.hammer.on("panmove",this._onDrag.bind(this)),this.hammer.on("panend",this._onDragEnd.bind(this)),this.hammer.on("pan",function(t){t.preventDefault()})},s.prototype.destroy=function(){this.options.showCustomTime=!1,this.redraw(),this.hammer.enable(!1),this.hammer=null,this.body=null},s.prototype.redraw=function(){if(this.options.showCustomTime){var t=this.body.dom.backgroundVertical;this.bar.parentNode!=t&&(this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar),t.appendChild(this.bar));var e=this.body.util.toScreen(this.customTime),i=this.options.locales[this.options.locale],s=i.time+": "+a(this.customTime).format("dddd, MMMM Do YYYY, H:mm:ss");s=s.charAt(0).toUpperCase()+s.substring(1),this.bar.style.left=e+"px",this.bar.title=s}else this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar);return!1},s.prototype.setCustomTime=function(t){this.customTime=n.convert(t,"Date"),this.redraw()},s.prototype.getCustomTime=function(){return new Date(this.customTime.valueOf())},s.prototype._onDragStart=function(t){this.eventParams.dragging=!0,this.eventParams.customTime=this.customTime,t.stopPropagation(),t.preventDefault()},s.prototype._onDrag=function(t){if(this.eventParams.dragging){var e=this.body.util.toScreen(this.eventParams.customTime)+t.deltaX,i=this.body.util.toTime(e);this.setCustomTime(i),this.body.emitter.emit("timechange",{time:new Date(this.customTime.valueOf())}),t.stopPropagation(),t.preventDefault()}},s.prototype._onDragEnd=function(t){this.eventParams.dragging&&(this.body.emitter.emit("timechanged",{time:new Date(this.customTime.valueOf())}),t.stopPropagation(),t.preventDefault())},t.exports=s},function(t,e,i){function s(t,e,i,s){this.id=o.randomUUID(),this.body=t,this.defaultOptions={orientation:"left",showMinorLabels:!0,showMajorLabels:!0,icons:!0,majorLinesOffset:7,minorLinesOffset:4,labelOffsetX:10,labelOffsetY:2,iconWidth:20,width:"40px",visible:!0,alignZeros:!0,customRange:{left:{min:void 0,max:void 0},right:{min:void 0,max:void 0}},title:{left:{text:void 0},right:{text:void 0}},format:{left:{decimals:void 0},right:{decimals:void 0}}},this.linegraphOptions=s,this.linegraphSVG=i,this.props={},this.DOMelements={lines:{},labels:{},title:{}},this.dom={},this.range={start:0,end:0},this.options=o.extend({},this.defaultOptions),this.conversionFactor=1,this.setOptions(e),this.width=Number((""+this.options.width).replace("px","")),this.minWidth=this.width,this.height=this.linegraphSVG.offsetHeight,this.hidden=!1,this.stepPixels=25,this.stepPixelsForced=25,this.zeroCrossing=-1,this.lineOffset=0,this.master=!0,this.svgElements={},this.iconsRemoved=!1,this.groups={},this.amountOfGroups=0,this._create();var n=this;this.body.emitter.on("verticalDrag",function(){n.dom.lineContainer.style.top=n.body.domProps.scrollTop+"px"})}var o=i(1),n=i(2),r=i(20),a=i(16);s.prototype=new r,s.prototype.addGroup=function(t,e){this.groups.hasOwnProperty(t)||(this.groups[t]=e),this.amountOfGroups+=1},s.prototype.updateGroup=function(t,e){this.groups[t]=e},s.prototype.removeGroup=function(t){this.groups.hasOwnProperty(t)&&(delete this.groups[t],this.amountOfGroups-=1)},s.prototype.setOptions=function(t){if(t){var e=!1;this.options.orientation!=t.orientation&&void 0!==t.orientation&&(e=!0);var i=["orientation","showMinorLabels","showMajorLabels","icons","majorLinesOffset","minorLinesOffset","labelOffsetX","labelOffsetY","iconWidth","width","visible","customRange","title","format","alignZeros"];o.selectiveExtend(i,this.options,t),this.minWidth=Number((""+this.options.width).replace("px","")),1==e&&this.dom.frame&&(this.hide(),this.show())}},s.prototype._create=function(){this.dom.frame=document.createElement("div"),this.dom.frame.style.width=this.options.width,this.dom.frame.style.height=this.height,this.dom.lineContainer=document.createElement("div"),this.dom.lineContainer.style.width="100%",this.dom.lineContainer.style.height=this.height,this.dom.lineContainer.style.position="relative",this.svg=document.createElementNS("http://www.w3.org/2000/svg","svg"),this.svg.style.position="absolute",this.svg.style.top="0px",this.svg.style.height="100%",this.svg.style.width="100%",this.svg.style.display="block",this.dom.frame.appendChild(this.svg)},s.prototype._redrawGroupIcons=function(){n.prepareElements(this.svgElements);var t,e=this.options.iconWidth,i=15,s=4,o=s+.5*i;t="left"==this.options.orientation?s:this.width-e-s;for(var r in this.groups)this.groups.hasOwnProperty(r)&&(1!=this.groups[r].visible||void 0!==this.linegraphOptions.visibility[r]&&1!=this.linegraphOptions.visibility[r]||(this.groups[r].drawIcon(t,o,this.svgElements,this.svg,e,i),o+=i+s));n.cleanupElements(this.svgElements),this.iconsRemoved=!1},s.prototype._cleanupIcons=function(){0==this.iconsRemoved&&(n.prepareElements(this.svgElements),n.cleanupElements(this.svgElements),this.iconsRemoved=!0)},s.prototype.show=function(){this.hidden=!1,this.dom.frame.parentNode||("left"==this.options.orientation?this.body.dom.left.appendChild(this.dom.frame):this.body.dom.right.appendChild(this.dom.frame)),this.dom.lineContainer.parentNode||this.body.dom.backgroundHorizontal.appendChild(this.dom.lineContainer)},s.prototype.hide=function(){this.hidden=!0,this.dom.frame.parentNode&&this.dom.frame.parentNode.removeChild(this.dom.frame),this.dom.lineContainer.parentNode&&this.dom.lineContainer.parentNode.removeChild(this.dom.lineContainer)},s.prototype.setRange=function(t,e){0==this.master&&1==this.options.alignZeros&&-1!=this.zeroCrossing&&t>0&&(t=0),this.range.start=t,this.range.end=e},s.prototype.redraw=function(){var t=!1,e=0;this.dom.lineContainer.style.top=this.body.domProps.scrollTop+"px";for(var i in this.groups)this.groups.hasOwnProperty(i)&&(1!=this.groups[i].visible||void 0!==this.linegraphOptions.visibility[i]&&1!=this.linegraphOptions.visibility[i]||e++);if(0==this.amountOfGroups||0==e)this.hide();else{this.show(),this.height=Number(this.linegraphSVG.style.height.replace("px","")),this.dom.lineContainer.style.height=this.height+"px",this.width=1==this.options.visible?Number((""+this.options.width).replace("px","")):0;var s=this.props,o=this.dom.frame;o.className="dataaxis",this._calculateCharSize();var n=this.options.orientation,r=this.options.showMinorLabels,a=this.options.showMajorLabels;s.minorLabelHeight=r?s.minorCharHeight:0,s.majorLabelHeight=a?s.majorCharHeight:0,s.minorLineWidth=this.body.dom.backgroundHorizontal.offsetWidth-this.lineOffset-this.width+2*this.options.minorLinesOffset,s.minorLineHeight=1,s.majorLineWidth=this.body.dom.backgroundHorizontal.offsetWidth-this.lineOffset-this.width+2*this.options.majorLinesOffset,s.majorLineHeight=1,"left"==n?(o.style.top="0",o.style.left="0",o.style.bottom="",o.style.width=this.width+"px",o.style.height=this.height+"px",this.props.width=this.body.domProps.left.width,this.props.height=this.body.domProps.left.height):(o.style.top="",o.style.bottom="0",o.style.left="0",o.style.width=this.width+"px",o.style.height=this.height+"px",this.props.width=this.body.domProps.right.width,this.props.height=this.body.domProps.right.height),t=this._redrawLabels(),t=this._isResized()||t,1==this.options.icons?this._redrawGroupIcons():this._cleanupIcons(),this._redrawTitle(n)}return t},s.prototype._redrawLabels=function(){var t=!1;n.prepareElements(this.DOMelements.lines),n.prepareElements(this.DOMelements.labels);var e=this.options.orientation,i=this.master?this.props.majorCharHeight||10:this.stepPixelsForced,s=new a(this.range.start,this.range.end,i,this.dom.frame.offsetHeight,this.options.customRange[this.options.orientation],0==this.master&&this.options.alignZeros);this.step=s;var o=(this.dom.frame.offsetHeight-s.deadSpace*(this.dom.frame.offsetHeight/s.marginRange))/((s.marginRange-s.deadSpace)/s.step);this.stepPixels=o;var r=this.height/o,h=0;if(0==this.master){o=this.stepPixelsForced,h=Math.round(this.dom.frame.offsetHeight/o-r);for(var d=0;.5*h>d;d++)s.previous();if(r=this.height/o,-1!=this.zeroCrossing&&1==this.options.alignZeros){var l=s.marginEnd/s.step-this.zeroCrossing;if(l>0)for(var d=0;l>d;d++)s.next();else if(0>l)for(var d=0;-l>d;d++)s.previous()}}else r+=.25;this.valueAtZero=s.marginEnd;var c,p=0,u=1;void 0!==this.options.format[e]&&(c=this.options.format[e].decimals),this.maxLabelSize=0;for(var m=0;u=0&&this._redrawLabel(m-2,s.getCurrent(c),e,"yAxis major",this.props.majorCharHeight),this._redrawLine(m,e,"grid horizontal major",this.options.majorLinesOffset,this.props.majorLineWidth)):this._redrawLine(m,e,"grid horizontal minor",this.options.minorLinesOffset,this.props.minorLineWidth),1==this.master&&0==s.current&&(this.zeroCrossing=u),u++}this.conversionFactor=0==this.master?m/(this.valueAtZero-s.current):this.dom.frame.offsetHeight/s.marginRange;var g=0;void 0!==this.options.title[e]&&void 0!==this.options.title[e].text&&(g=this.props.titleCharHeight);var v=1==this.options.icons?Math.max(this.options.iconWidth,g)+this.options.labelOffsetX+15:g+this.options.labelOffsetX+15;return this.maxLabelSize>this.width-v&&1==this.options.visible?(this.width=this.maxLabelSize+v,this.options.width=this.width+"px",n.cleanupElements(this.DOMelements.lines),n.cleanupElements(this.DOMelements.labels),this.redraw(),t=!0):this.maxLabelSizethis.minWidth?(this.width=Math.max(this.minWidth,this.maxLabelSize+v),this.options.width=this.width+"px",n.cleanupElements(this.DOMelements.lines),n.cleanupElements(this.DOMelements.labels),this.redraw(),t=!0):(n.cleanupElements(this.DOMelements.lines),n.cleanupElements(this.DOMelements.labels),t=!1),t},s.prototype.convertValue=function(t){var e=this.valueAtZero-t,i=e*this.conversionFactor;return i},s.prototype._redrawLabel=function(t,e,i,s,o){var r=n.getDOMElement("div",this.DOMelements.labels,this.dom.frame);r.className=s,r.innerHTML=e,"left"==i?(r.style.left="-"+this.options.labelOffsetX+"px",r.style.textAlign="right"):(r.style.right="-"+this.options.labelOffsetX+"px",r.style.textAlign="left"),r.style.top=t-.5*o+this.options.labelOffsetY+"px",e+="";var a=Math.max(this.props.majorCharWidth,this.props.minorCharWidth);this.maxLabelSized;d++){var c=this.visibleItems[d];c.repositionY(e)}return s},s.prototype._calculateHeight=function(t){var e,i=this.visibleItems;this.resetSubgroups();var s=this;if(i.length){var n=i[0].top,r=i[0].top+i[0].height;if(o.forEach(i,function(t){n=Math.min(n,t.top),r=Math.max(r,t.top+t.height),void 0!==t.data.subgroup&&(s.subgroups[t.data.subgroup].height=Math.max(s.subgroups[t.data.subgroup].height,t.height),s.subgroups[t.data.subgroup].visible=!0)}),n>t.axis){var a=n-t.axis;r-=a,o.forEach(i,function(t){t.top-=a})}e=r+t.item.vertical/2}else e=t.axis+t.item.vertical;return e=Math.max(e,this.props.label.height)},s.prototype.show=function(){this.dom.label.parentNode||this.itemSet.dom.labelSet.appendChild(this.dom.label),this.dom.foreground.parentNode||this.itemSet.dom.foreground.appendChild(this.dom.foreground),this.dom.background.parentNode||this.itemSet.dom.background.appendChild(this.dom.background),this.dom.axis.parentNode||this.itemSet.dom.axis.appendChild(this.dom.axis)},s.prototype.hide=function(){var t=this.dom.label;t.parentNode&&t.parentNode.removeChild(t);var e=this.dom.foreground;e.parentNode&&e.parentNode.removeChild(e);var i=this.dom.background;i.parentNode&&i.parentNode.removeChild(i);var s=this.dom.axis;s.parentNode&&s.parentNode.removeChild(s)},s.prototype.add=function(t){if(this.items[t.id]=t,t.setParent(this),void 0!==t.data.subgroup&&(void 0===this.subgroups[t.data.subgroup]&&(this.subgroups[t.data.subgroup]={height:0,visible:!1,index:this.subgroupIndex,items:[]},this.subgroupIndex++),this.subgroups[t.data.subgroup].items.push(t)),this.orderSubgroups(),-1==this.visibleItems.indexOf(t)){var e=this.itemSet.body.range;this._checkIfVisible(t,this.visibleItems,e)}},s.prototype.orderSubgroups=function(){if(void 0!==this.subgroupOrderer){var t=[];if("string"==typeof this.subgroupOrderer){for(var e in this.subgroups)t.push({subgroup:e,sortField:this.subgroups[e].items[0].data[this.subgroupOrderer]});t.sort(function(t,e){return t.sortField-e.sortField})}else if("function"==typeof this.subgroupOrderer){for(var e in this.subgroups)t.push(this.subgroups[e].items[0].data);t.sort(this.subgroupOrderer)}if(t.length>0)for(var i=0;it?-1:l>=t?0:1};if(e.length>0)for(n=0;nl}),1==this.checkRangedItems)for(this.checkRangedItems=!1,n=0;nl})}for(n=0;n=0&&(n=e[r],!o(n));r--)void 0===s[n.id]&&(s[n.id]=!0,i.push(n));for(r=t+1;rs;s++){var n=this.visibleItems[s];n.repositionY(e)}return i},s.prototype.show=function(){this.dom.background.parentNode||this.itemSet.dom.background.appendChild(this.dom.background)},t.exports=s},function(t,e,i){function s(t,e){this.body=t,this.defaultOptions={type:null,orientation:"bottom",align:"auto",stack:!0,groupOrder:null,selectable:!0,editable:{updateTime:!1,updateGroup:!1,add:!1,remove:!1},onAdd:function(t,e){e(t)},onUpdate:function(t,e){e(t)},onMove:function(t,e){e(t)},onRemove:function(t,e){e(t)},onMoving:function(t,e){e(t)},margin:{item:{horizontal:10,vertical:10},axis:20},padding:5},this.options=n.extend({},this.defaultOptions),this.itemOptions={type:{start:"Date",end:"Date"}},this.conversion={toScreen:t.util.toScreen,toTime:t.util.toTime},this.dom={},this.props={},this.hammer=null;var i=this;this.itemsData=null,this.groupsData=null,this.itemListeners={add:function(t,e){i._onAdd(e.items)},update:function(t,e){i._onUpdate(e.items)},remove:function(t,e){i._onRemove(e.items)}},this.groupListeners={add:function(t,e){i._onAddGroups(e.items)},update:function(t,e){i._onUpdateGroups(e.items)},remove:function(t,e){i._onRemoveGroups(e.items)}},this.items={},this.groups={},this.groupIds=[],this.selection=[],this.stackDirty=!0,this.touchParams={},this._create(),this.setOptions(e)}var o=i(45),n=i(1),r=i(3),a=i(4),h=i(20),d=i(25),l=i(26),c=i(33),p=i(34),u=i(35),m=i(32),f="__ungrouped__",g="__background__";s.prototype=new h,s.types={background:m,box:c,range:u,point:p},s.prototype._create=function(){var t=document.createElement("div");t.className="itemset",t["timeline-itemset"]=this,this.dom.frame=t;var e=document.createElement("div");e.className="background",t.appendChild(e),this.dom.background=e;var i=document.createElement("div");i.className="foreground",t.appendChild(i),this.dom.foreground=i;var s=document.createElement("div");s.className="axis",this.dom.axis=s;var n=document.createElement("div");n.className="labelset",this.dom.labelSet=n,this._updateUngrouped();var r=new l(g,null,this);r.show(),this.groups[g]=r,this.hammer=new o(this.body.dom.centerContainer),this.hammer.on("hammer.input",function(t){t.isFirst&&this._onTouch(t)}.bind(this)),this.hammer.on("panstart",this._onDragStart.bind(this)),this.hammer.on("panmove",this._onDrag.bind(this)),this.hammer.on("panend",this._onDragEnd.bind(this)),this.hammer.on("tap",this._onSelectItem.bind(this)),this.hammer.on("press",this._onMultiSelectItem.bind(this)),this.hammer.on("doubletap",this._onAddItem.bind(this)),this.show()},s.prototype.setOptions=function(t){if(t){var e=["type","align","orientation","padding","stack","selectable","groupOrder","dataAttributes","template","hide"];n.selectiveExtend(e,this.options,t),"margin"in t&&("number"==typeof t.margin?(this.options.margin.axis=t.margin,this.options.margin.item.horizontal=t.margin,this.options.margin.item.vertical=t.margin):"object"==typeof t.margin&&(n.selectiveExtend(["axis"],this.options.margin,t.margin),"item"in t.margin&&("number"==typeof t.margin.item?(this.options.margin.item.horizontal=t.margin.item,this.options.margin.item.vertical=t.margin.item):"object"==typeof t.margin.item&&n.selectiveExtend(["horizontal","vertical"],this.options.margin.item,t.margin.item)))),"editable"in t&&("boolean"==typeof t.editable?(this.options.editable.updateTime=t.editable,this.options.editable.updateGroup=t.editable,this.options.editable.add=t.editable,this.options.editable.remove=t.editable):"object"==typeof t.editable&&n.selectiveExtend(["updateTime","updateGroup","add","remove"],this.options.editable,t.editable));var i=function(e){var i=t[e];if(i){if(!(i instanceof Function))throw new Error("option "+e+" must be a function "+e+"(item, callback)");this.options[e]=i}}.bind(this);["onAdd","onUpdate","onRemove","onMove","onMoving"].forEach(i),this.markDirty()}},s.prototype.markDirty=function(){this.groupIds=[],this.stackDirty=!0},s.prototype.destroy=function(){this.hide(),this.setItems(null),this.setGroups(null),this.hammer=null,this.body=null,this.conversion=null},s.prototype.hide=function(){this.dom.frame.parentNode&&this.dom.frame.parentNode.removeChild(this.dom.frame),this.dom.axis.parentNode&&this.dom.axis.parentNode.removeChild(this.dom.axis),this.dom.labelSet.parentNode&&this.dom.labelSet.parentNode.removeChild(this.dom.labelSet)},s.prototype.show=function(){this.dom.frame.parentNode||this.body.dom.center.appendChild(this.dom.frame),this.dom.axis.parentNode||this.body.dom.backgroundVertical.appendChild(this.dom.axis),this.dom.labelSet.parentNode||this.body.dom.left.appendChild(this.dom.labelSet)},s.prototype.setSelection=function(t){var e,i,s,o;for(void 0==t&&(t=[]),Array.isArray(t)||(t=[t]),e=0,i=this.selection.length;i>e;e++)s=this.selection[e],o=this.items[s],o&&o.unselect();for(this.selection=[],e=0,i=t.length;i>e;e++)s=t[e],o=this.items[s],o&&(this.selection.push(s),o.select())},s.prototype.getSelection=function(){return this.selection.concat([])},s.prototype.getVisibleItems=function(){var t=this.body.range.getRange(),e=this.body.util.toScreen(t.start),i=this.body.util.toScreen(t.end),s=[];for(var o in this.groups)if(this.groups.hasOwnProperty(o))for(var n=this.groups[o],r=n.visibleItems,a=0;ae&&s.push(h.id)}return s},s.prototype._deselect=function(t){for(var e=this.selection,i=0,s=e.length;s>i;i++)if(e[i]==t){e.splice(i,1);break}},s.prototype.redraw=function(){var t=this.options.margin,e=this.body.range,i=n.option.asSize,s=this.options,o=s.orientation,r=!1,a=this.dom.frame,h=s.editable.updateTime||s.editable.updateGroup;this.props.top=this.body.domProps.top.height+this.body.domProps.border.top,this.props.left=this.body.domProps.left.width+this.body.domProps.border.left,a.className="itemset"+(h?" editable":""),r=this._orderGroups()||r;var d=e.end-e.start,l=d!=this.lastVisibleInterval||this.props.width!=this.props.lastWidth;l&&(this.stackDirty=!0),this.lastVisibleInterval=d,this.props.lastWidth=this.props.width;var c=this.stackDirty,p=this._firstGroup(),u={item:t.item,axis:t.axis},m={item:t.item,axis:t.item.vertical/2},f=0,v=t.axis+t.item.vertical;return this.groups[g].redraw(e,m,c),n.forEach(this.groups,function(t){var i=t==p?u:m,s=t.redraw(e,i,c);r=s||r,f+=t.height}),f=Math.max(f,v),this.stackDirty=!1,a.style.height=i(f),this.props.width=a.offsetWidth,this.props.height=f,this.dom.axis.style.top=i("top"==o?this.body.domProps.top.height+this.body.domProps.border.top:this.body.domProps.top.height+this.body.domProps.centerContainer.height),this.dom.axis.style.left="0",r=this._isResized()||r},s.prototype._firstGroup=function(){var t="top"==this.options.orientation?0:this.groupIds.length-1,e=this.groupIds[t],i=this.groups[e]||this.groups[f];return i||null},s.prototype._updateUngrouped=function(){{var t,e,i=this.groups[f];this.groups[g]}if(this.groupsData){if(i){i.hide(),delete this.groups[f];for(e in this.items)if(this.items.hasOwnProperty(e)){t=this.items[e],t.parent&&t.parent.remove(t);var s=this._getGroupId(t.data),o=this.groups[s];o&&o.add(t)||t.hide()}}}else if(!i){var n=null,r=null;i=new d(n,r,this),this.groups[f]=i;for(e in this.items)this.items.hasOwnProperty(e)&&(t=this.items[e],i.add(t));i.show()}},s.prototype.getLabelSet=function(){return this.dom.labelSet},s.prototype.setItems=function(t){var e,i=this,s=this.itemsData;if(t){if(!(t instanceof r||t instanceof a))throw new TypeError("Data must be an instance of DataSet or DataView");this.itemsData=t}else this.itemsData=null;if(s&&(n.forEach(this.itemListeners,function(t,e){s.off(e,t)}),e=s.getIds(),this._onRemove(e)),this.itemsData){var o=this.id;n.forEach(this.itemListeners,function(t,e){i.itemsData.on(e,t,o)}),e=this.itemsData.getIds(),this._onAdd(e),this._updateUngrouped()}},s.prototype.getItems=function(){return this.itemsData},s.prototype.setGroups=function(t){var e,i=this;if(this.groupsData&&(n.forEach(this.groupListeners,function(t,e){i.groupsData.unsubscribe(e,t)}),e=this.groupsData.getIds(),this.groupsData=null,this._onRemoveGroups(e)),t){if(!(t instanceof r||t instanceof a))throw new TypeError("Data must be an instance of DataSet or DataView");this.groupsData=t}else this.groupsData=null;if(this.groupsData){var s=this.id;n.forEach(this.groupListeners,function(t,e){i.groupsData.on(e,t,s)}),e=this.groupsData.getIds(),this._onAddGroups(e)}this._updateUngrouped(),this._order(),this.body.emitter.emit("change",{queue:!0})},s.prototype.getGroups=function(){return this.groupsData},s.prototype.removeItem=function(t){var e=this.itemsData.get(t),i=this.itemsData.getDataSet();e&&this.options.onRemove(e,function(e){e&&i.remove(t)})},s.prototype._getType=function(t){return t.type||this.options.type||(t.end?"range":"box")},s.prototype._getGroupId=function(t){var e=this._getType(t);return"background"==e&&void 0==t.group?g:this.groupsData?t.group:f},s.prototype._onUpdate=function(t){var e=this;t.forEach(function(t){var i=e.itemsData.get(t,e.itemOptions),o=e.items[t],n=e._getType(i),r=s.types[n];if(o&&(r&&o instanceof r?e._updateItem(o,i):(e._removeItem(o),o=null)),!o){if(!r)throw new TypeError("rangeoverflow"==n?'Item type "rangeoverflow" is deprecated. Use css styling instead: .vis.timeline .item.range .content {overflow: visible;}':'Unknown item type "'+n+'"');o=new r(i,e.conversion,e.options),o.id=t,e._addItem(o)}}),this._order(),this.stackDirty=!0,this.body.emitter.emit("change",{queue:!0})},s.prototype._onAdd=s.prototype._onUpdate,s.prototype._onRemove=function(t){var e=0,i=this;t.forEach(function(t){var s=i.items[t];s&&(e++,i._removeItem(s))}),e&&(this._order(),this.stackDirty=!0,this.body.emitter.emit("change",{queue:!0}))},s.prototype._order=function(){n.forEach(this.groups,function(t){t.order()})},s.prototype._onUpdateGroups=function(t){this._onAddGroups(t)},s.prototype._onAddGroups=function(t){var e=this;t.forEach(function(t){var i=e.groupsData.get(t),s=e.groups[t];if(s)s.setData(i);else{if(t==f||t==g)throw new Error("Illegal group id. "+t+" is a reserved id.");var o=Object.create(e.options);n.extend(o,{height:null}),s=new d(t,i,e),e.groups[t]=s;for(var r in e.items)if(e.items.hasOwnProperty(r)){var a=e.items[r];a.data.group==t&&s.add(a)}s.order(),s.show()}}),this.body.emitter.emit("change",{queue:!0})},s.prototype._onRemoveGroups=function(t){var e=this.groups;t.forEach(function(t){var i=e[t];i&&(i.hide(),delete e[t])}),this.markDirty(),this.body.emitter.emit("change",{queue:!0})},s.prototype._orderGroups=function(){if(this.groupsData){var t=this.groupsData.getIds({order:this.options.groupOrder}),e=!n.equalArray(t,this.groupIds);if(e){var i=this.groups;t.forEach(function(t){i[t].hide()}),t.forEach(function(t){i[t].show()}),this.groupIds=t}return e}return!1},s.prototype._addItem=function(t){this.items[t.id]=t;var e=this._getGroupId(t.data),i=this.groups[e];i&&i.add(t)},s.prototype._updateItem=function(t,e){var i=t.data.group;if(t.setData(e),i!=t.data.group){var s=this.groups[i];s&&s.remove(t);var o=this._getGroupId(t.data),n=this.groups[o];n&&n.add(t)}},s.prototype._removeItem=function(t){t.hide(),delete this.items[t.id];var e=this.selection.indexOf(t.id);-1!=e&&this.selection.splice(e,1),t.parent&&t.parent.remove(t)},s.prototype._constructByEndArray=function(t){for(var e=[],i=0;i0||o.length>0)&&this.body.emitter.emit("select",{items:a})}},s.prototype._onAddItem=function(t){if(this.options.selectable&&this.options.editable.add){var e=this,i=this.body.util.snap||null,o=s.itemFromTarget(t);if(o){var r=e.itemsData.get(o.id);this.options.onUpdate(r,function(t){t&&e.itemsData.getDataSet().update(t)})}else{var a=n.getAbsoluteLeft(this.dom.frame),h=t.center.x-a,d=this.body.util.toTime(h),l={start:i?i(d):d,content:"new item"};if("range"===this.options.type){var c=this.body.util.toTime(h+this.props.width/5);l.end=i?i(c):c}l[this.itemsData._fieldId]=n.randomUUID();var p=s.groupFromTarget(t);p&&(l.group=p.groupId),this.options.onAdd(l,function(t){t&&e.itemsData.getDataSet().add(t)})}}},s.prototype._onMultiSelectItem=function(t){if(this.options.selectable){var e,i=s.itemFromTarget(t);if(i){e=this.getSelection();var o=t.srcEvent&&t.srcEvent.shiftKey||!1;if(o){e.push(i.id);var n=s._getItemRange(this.itemsData.get(e,this.itemOptions));e=[];for(var r in this.items)if(this.items.hasOwnProperty(r)){var a=this.items[r],h=a.data.start,d=void 0!==a.data.end?a.data.end:h;h>=n.min&&d<=n.max&&e.push(a.id)}}else{var l=e.indexOf(i.id);-1==l?e.push(i.id):e.splice(l,1)}this.setSelection(e),this.body.emitter.emit("select",{items:this.getSelection()})}}},s._getItemRange=function(t){var e=null,i=null;return t.forEach(function(t){(null==i||t.starte)&&(e=t.end):(null==e||t.start>e)&&(e=t.start)}),{min:i,max:e}},s.itemFromTarget=function(t){for(var e=t.target;e;){if(e.hasOwnProperty("timeline-item"))return e["timeline-item"];e=e.parentNode}return null},s.groupFromTarget=function(t){for(var e=t.target;e;){if(e.hasOwnProperty("timeline-group"))return e["timeline-group"];e=e.parentNode}return null},s.itemSetFromTarget=function(t){for(var e=t.target;e;){if(e.hasOwnProperty("timeline-itemset"))return e["timeline-itemset"];e=e.parentNode}return null},t.exports=s},function(t,e,i){function s(t,e,i,s){this.body=t,this.defaultOptions={enabled:!0,icons:!0,iconSize:20,iconSpacing:6,left:{visible:!0,position:"top-left"},right:{visible:!0,position:"top-left"}},this.side=i,this.options=o.extend({},this.defaultOptions),this.linegraphOptions=s,this.svgElements={},this.dom={},this.groups={},this.amountOfGroups=0,this._create(),this.setOptions(e)}var o=i(1),n=i(2),r=i(20);s.prototype=new r,s.prototype.clear=function(){this.groups={},this.amountOfGroups=0},s.prototype.addGroup=function(t,e){this.groups.hasOwnProperty(t)||(this.groups[t]=e),this.amountOfGroups+=1},s.prototype.updateGroup=function(t,e){this.groups[t]=e},s.prototype.removeGroup=function(t){this.groups.hasOwnProperty(t)&&(delete this.groups[t],this.amountOfGroups-=1)},s.prototype._create=function(){this.dom.frame=document.createElement("div"),this.dom.frame.className="legend",this.dom.frame.style.position="absolute",this.dom.frame.style.top="10px",this.dom.frame.style.display="block",this.dom.textArea=document.createElement("div"),this.dom.textArea.className="legendText",this.dom.textArea.style.position="relative",this.dom.textArea.style.top="0px",this.svg=document.createElementNS("http://www.w3.org/2000/svg","svg"),this.svg.style.position="absolute",this.svg.style.top="0px",this.svg.style.width=this.options.iconSize+5+"px",this.svg.style.height="100%",this.dom.frame.appendChild(this.svg),this.dom.frame.appendChild(this.dom.textArea)},s.prototype.hide=function(){this.dom.frame.parentNode&&this.dom.frame.parentNode.removeChild(this.dom.frame)},s.prototype.show=function(){this.dom.frame.parentNode||this.body.dom.center.appendChild(this.dom.frame)},s.prototype.setOptions=function(t){var e=["enabled","orientation","icons","left","right"];o.selectiveDeepExtend(e,this.options,t)},s.prototype.redraw=function(){var t=0;for(var e in this.groups)this.groups.hasOwnProperty(e)&&(1!=this.groups[e].visible||void 0!==this.linegraphOptions.visibility[e]&&1!=this.linegraphOptions.visibility[e]||t++);if(0==this.options[this.side].visible||0==this.amountOfGroups||0==this.options.enabled||0==t)this.hide();else{if(this.show(),"top-left"==this.options[this.side].position||"bottom-left"==this.options[this.side].position?(this.dom.frame.style.left="4px",this.dom.frame.style.textAlign="left",this.dom.textArea.style.textAlign="left",this.dom.textArea.style.left=this.options.iconSize+15+"px",this.dom.textArea.style.right="",this.svg.style.left="0px",this.svg.style.right=""):(this.dom.frame.style.right="4px",this.dom.frame.style.textAlign="right",this.dom.textArea.style.textAlign="right",this.dom.textArea.style.right=this.options.iconSize+15+"px",this.dom.textArea.style.left="",this.svg.style.right="0px",this.svg.style.left=""),"top-left"==this.options[this.side].position||"top-right"==this.options[this.side].position)this.dom.frame.style.top=4-Number(this.body.dom.center.style.top.replace("px",""))+"px",this.dom.frame.style.bottom="";else{var i=this.body.domProps.center.height-this.body.domProps.centerContainer.height;this.dom.frame.style.bottom=4+i+Number(this.body.dom.center.style.top.replace("px",""))+"px",this.dom.frame.style.top=""}0==this.options.icons?(this.dom.frame.style.width=this.dom.textArea.offsetWidth+10+"px",this.dom.textArea.style.right="",this.dom.textArea.style.left="",this.svg.style.width="0px"):(this.dom.frame.style.width=this.options.iconSize+15+this.dom.textArea.offsetWidth+10+"px",this.drawLegendIcons());var s="";for(var e in this.groups)this.groups.hasOwnProperty(e)&&(1!=this.groups[e].visible||void 0!==this.linegraphOptions.visibility[e]&&1!=this.linegraphOptions.visibility[e]||(s+=this.groups[e].content+"
"));this.dom.textArea.innerHTML=s,this.dom.textArea.style.lineHeight=.75*this.options.iconSize+this.options.iconSpacing+"px"}},s.prototype.drawLegendIcons=function(){if(this.dom.frame.parentNode){n.prepareElements(this.svgElements);var t=window.getComputedStyle(this.dom.frame).paddingTop,e=Number(t.replace("px","")),i=e,s=this.options.iconSize,o=.75*this.options.iconSize,r=e+.5*o+3;this.svg.style.width=s+5+e+"px";for(var a in this.groups)this.groups.hasOwnProperty(a)&&(1!=this.groups[a].visible||void 0!==this.linegraphOptions.visibility[a]&&1!=this.linegraphOptions.visibility[a]||(this.groups[a].drawIcon(i,r,this.svgElements,this.svg,s,o),r+=o+this.options.iconSpacing));n.cleanupElements(this.svgElements)}},t.exports=s},function(t,e,i){function s(t,e){this.id=o.randomUUID(),this.body=t,this.defaultOptions={yAxisOrientation:"left",defaultGroup:"default",sort:!0,sampling:!0,graphHeight:"400px",shaded:{enabled:!1,orientation:"bottom"},style:"line",barChart:{width:50,handleOverlap:"overlap",align:"center"},catmullRom:{enabled:!0,parametrization:"centripetal",alpha:.5},drawPoints:{enabled:!0,size:6,style:"square"},dataAxis:{showMinorLabels:!0,showMajorLabels:!0,icons:!1,width:"40px",visible:!0,alignZeros:!0,customRange:{left:{min:void 0,max:void 0},right:{min:void 0,max:void 0}}},legend:{enabled:!1,icons:!0,left:{visible:!0,position:"top-left"},right:{visible:!0,position:"top-right"}},groups:{visibility:{}}},this.options=o.extend({},this.defaultOptions),this.dom={},this.props={},this.hammer=null,this.groups={},this.abortedGraphUpdate=!1,this.updateSVGheight=!1,this.updateSVGheightOnResize=!1;var i=this;this.itemsData=null,this.groupsData=null,this.itemListeners={add:function(t,e){i._onAdd(e.items)},update:function(t,e){i._onUpdate(e.items)},remove:function(t,e){i._onRemove(e.items)}},this.groupListeners={add:function(t,e){i._onAddGroups(e.items)},update:function(t,e){i._onUpdateGroups(e.items)},remove:function(t,e){i._onRemoveGroups(e.items)}},this.items={},this.selection=[],this.lastStart=this.body.range.start,this.touchParams={},this.svgElements={},this.setOptions(e),this.groupsUsingDefaultStyles=[0],this.COUNTER=0,this.body.emitter.on("rangechanged",function(){i.lastStart=i.body.range.start,i.svg.style.left=o.option.asSize(-i.props.width),i.redraw.call(i,!0)}),this._create(),this.framework={svg:this.svg,svgElements:this.svgElements,options:this.options,groups:this.groups},this.body.emitter.emit("change")}var o=i(1),n=i(2),r=i(3),a=i(4),h=i(20),d=i(23),l=i(24),c=i(28),p=i(52),u="__ungrouped__";s.prototype=new h,s.prototype._create=function(){var t=document.createElement("div");t.className="LineGraph",this.dom.frame=t,this.svg=document.createElementNS("http://www.w3.org/2000/svg","svg"),this.svg.style.position="relative",this.svg.style.height=(""+this.options.graphHeight).replace("px","")+"px",this.svg.style.display="block",t.appendChild(this.svg),this.options.dataAxis.orientation="left",this.yAxisLeft=new d(this.body,this.options.dataAxis,this.svg,this.options.groups),this.options.dataAxis.orientation="right",this.yAxisRight=new d(this.body,this.options.dataAxis,this.svg,this.options.groups),delete this.options.dataAxis.orientation,this.legendLeft=new c(this.body,this.options.legend,"left",this.options.groups),this.legendRight=new c(this.body,this.options.legend,"right",this.options.groups),this.show()},s.prototype.setOptions=function(t){if(t){var e=["sampling","defaultGroup","height","graphHeight","yAxisOrientation","style","barChart","dataAxis","sort","groups"];void 0===t.graphHeight&&void 0!==t.height&&void 0!==this.body.domProps.centerContainer.height?(this.updateSVGheight=!0,this.updateSVGheightOnResize=!0):void 0!==this.body.domProps.centerContainer.height&&void 0!==t.graphHeight&&parseInt((t.graphHeight+"").replace("px",""))0){var d=this.body.util.toGlobalTime(-this.body.domProps.root.width),l=this.body.util.toGlobalTime(2*this.body.domProps.root.width),c={};for(this._getRelevantData(a,c,d,l),this._applySampling(a,c),e=0;eu&&console.log("WARNING: there may be an infinite loop in the _updateGraph emitter cycle."),this.COUNTER=0,this.abortedGraphUpdate=!1,e=0;e0)for(r=0;rs){d.push(h);break}d.push(h)}}else for(a=0;ai&&h.x0)for(var s=0;s0){var n=1,r=o.length,a=this.body.util.toGlobalScreen(o[o.length-1].x)-this.body.util.toGlobalScreen(o[0].x),h=r/a;n=Math.min(Math.ceil(.2*r),Math.max(1,Math.round(h)));for(var d=[],l=0;r>l;l+=n)d.push(o[l]);e[t[s]]=d}}},s.prototype._getYRanges=function(t,e,i){var s,o,n,r,a=[],h=[];if(t.length>0){for(n=0;n0&&(o=this.groups[t[n]],"stack"==r.barChart.handleOverlap&&"bar"==r.style?"left"==r.yAxisOrientation?a=a.concat(o.getYRange(s)):h=h.concat(o.getYRange(s)):i[t[n]]=o.getYRange(s,t[n]));p.getStackedBarYRange(a,i,t,"__barchartLeft","left"),p.getStackedBarYRange(h,i,t,"__barchartRight","right")}},s.prototype._updateYAxis=function(t,e){var i,s,o=!1,n=!1,r=!1,a=1e9,h=1e9,d=-1e9,l=-1e9;if(t.length>0){for(var c=0;ci?i:a,d=s>d?s:d):(r=!0,h=h>i?i:h,l=s>l?s:l));1==n&&this.yAxisLeft.setRange(a,d),1==r&&this.yAxisRight.setRange(h,l)}return o=this._toggleAxisVisiblity(n,this.yAxisLeft)||o,o=this._toggleAxisVisiblity(r,this.yAxisRight)||o,1==r&&1==n?(this.yAxisLeft.drawIcons=!0,this.yAxisRight.drawIcons=!0):(this.yAxisLeft.drawIcons=!1,this.yAxisRight.drawIcons=!1),this.yAxisRight.master=!n,0==this.yAxisRight.master?(this.yAxisLeft.lineOffset=1==r?this.yAxisRight.width:0,o=this.yAxisLeft.redraw()||o,this.yAxisRight.stepPixelsForced=this.yAxisLeft.stepPixels,this.yAxisRight.zeroCrossing=this.yAxisLeft.zeroCrossing,o=this.yAxisRight.redraw()||o):o=this.yAxisRight.redraw()||o,-1!=t.indexOf("__barchartLeft")&&t.splice(t.indexOf("__barchartLeft"),1),-1!=t.indexOf("__barchartRight")&&t.splice(t.indexOf("__barchartRight"),1),o},s.prototype._toggleAxisVisiblity=function(t,e){var i=!1;return 0==t?e.dom.frame.parentNode&&0==e.hidden&&(e.hide(),i=!0):e.dom.frame.parentNode||1!=e.hidden||(e.show(),i=!0),i},s.prototype._convertXcoordinates=function(t){for(var e,i,s=[],o=this.body.util.toScreen,n=0;ny;)y++,l=h.getCurrent(),c=h.isMajor(),u=h.getClassName(),f=m,m=this.body.util.toScreen(l),g=m-f,p&&(p.style.width=g+"px"),this.options.showMinorLabels&&this._repaintMinorText(m,h.getLabelMinor(),t,u),c&&this.options.showMajorLabels?(m>0&&(void 0==v&&(v=m),this._repaintMajorText(m,h.getLabelMajor(),t,u)),p=this._repaintMajorLine(m,t,u)):p=this._repaintMinorLine(m,t,u),h.next();if(this.options.showMajorLabels){var b=this.body.util.toTime(0),_=h.getLabelMajor(b),x=_.length*(this.props.majorCharWidth||10)+10;(void 0==v||v>x)&&this._repaintMajorText(0,_,t,u)}o.forEach(this.dom.redundant,function(t){for(;t.length;){var e=t.pop();e&&e.parentNode&&e.parentNode.removeChild(e)}})},s.prototype._repaintMinorText=function(t,e,i,s){var o=this.dom.redundant.minorTexts.shift();if(!o){var n=document.createTextNode("");o=document.createElement("div"),o.appendChild(n),this.dom.foreground.appendChild(o)}this.dom.minorTexts.push(o),o.childNodes[0].nodeValue=e,o.style.top="top"==i?this.props.majorLabelHeight+"px":"0",o.style.left=t+"px",o.className="text minor "+s},s.prototype._repaintMajorText=function(t,e,i,s){var o=this.dom.redundant.majorTexts.shift();if(!o){var n=document.createTextNode(e);o=document.createElement("div"),o.appendChild(n),this.dom.foreground.appendChild(o)}this.dom.majorTexts.push(o),o.childNodes[0].nodeValue=e,o.className="text major "+s,o.style.top="top"==i?"0":this.props.minorLabelHeight+"px",o.style.left=t+"px"},s.prototype._repaintMinorLine=function(t,e,i){var s=this.dom.redundant.lines.shift();s||(s=document.createElement("div"),this.dom.background.appendChild(s)),this.dom.lines.push(s);var o=this.props;return s.style.top="top"==e?o.majorLabelHeight+"px":this.body.domProps.top.height+"px",s.style.height=o.minorLineHeight+"px",s.style.left=t-o.minorLineWidth/2+"px",s.className="grid vertical minor "+i,s},s.prototype._repaintMajorLine=function(t,e,i){var s=this.dom.redundant.lines.shift();s||(s=document.createElement("div"),this.dom.background.appendChild(s)),this.dom.lines.push(s);var o=this.props;return s.style.top="top"==e?"0":this.body.domProps.top.height+"px",s.style.left=t-o.majorLineWidth/2+"px",s.style.height=o.majorLineHeight+"px",s.className="grid vertical major "+i,s},s.prototype._calculateCharSize=function(){this.dom.measureCharMinor||(this.dom.measureCharMinor=document.createElement("DIV"),this.dom.measureCharMinor.className="text minor measure",this.dom.measureCharMinor.style.position="absolute",this.dom.measureCharMinor.appendChild(document.createTextNode("0")),this.dom.foreground.appendChild(this.dom.measureCharMinor)),this.props.minorCharHeight=this.dom.measureCharMinor.clientHeight,this.props.minorCharWidth=this.dom.measureCharMinor.clientWidth,this.dom.measureCharMajor||(this.dom.measureCharMajor=document.createElement("DIV"),this.dom.measureCharMajor.className="text major measure",this.dom.measureCharMajor.style.position="absolute",this.dom.measureCharMajor.appendChild(document.createTextNode("0")),this.dom.foreground.appendChild(this.dom.measureCharMajor)),this.props.majorCharHeight=this.dom.measureCharMajor.clientHeight,this.props.majorCharWidth=this.dom.measureCharMajor.clientWidth},s.prototype.snap=function(t){return this.step.snap(t)},t.exports=s},function(t,e,i){function s(t,e,i){this.id=null,this.parent=null,this.data=t,this.dom=null,this.conversion=e||{},this.options=i||{},this.selected=!1,this.displayed=!1,this.dirty=!0,this.top=null,this.left=null,this.width=null,this.height=null}var o=i(45),n=i(1);s.prototype.stack=!0,s.prototype.select=function(){this.selected=!0,this.dirty=!0,this.displayed&&this.redraw()},s.prototype.unselect=function(){this.selected=!1,this.dirty=!0,this.displayed&&this.redraw()},s.prototype.setData=function(t){this.data=t,this.dirty=!0,this.displayed&&this.redraw()},s.prototype.setParent=function(t){this.displayed?(this.hide(),this.parent=t,this.parent&&this.show()):this.parent=t},s.prototype.isVisible=function(){return!1},s.prototype.show=function(){return!1},s.prototype.hide=function(){return!1},s.prototype.redraw=function(){},s.prototype.repositionX=function(){},s.prototype.repositionY=function(){},s.prototype._repaintDeleteButton=function(t){if(this.selected&&this.options.editable.remove&&!this.dom.deleteButton){var e=this,i=document.createElement("div");i.className="delete",i.title="Delete this item",new o(i).on("tap",function(t){e.parent.removeFromDataSet(e),t.stopPropagation(),t.preventDefault()}),t.appendChild(i),this.dom.deleteButton=i}else!this.selected&&this.dom.deleteButton&&(this.dom.deleteButton.parentNode&&this.dom.deleteButton.parentNode.removeChild(this.dom.deleteButton),this.dom.deleteButton=null)},s.prototype._updateContents=function(t){var e;if(this.options.template){var i=this.parent.itemSet.itemsData.get(this.id);e=this.options.template(i)}else e=this.data.content;if(e!==this.content){if(e instanceof Element)t.innerHTML="",t.appendChild(e);else if(void 0!=e)t.innerHTML=e;else if("background"!=this.data.type||void 0!==this.data.content)throw new Error('Property "content" missing in item '+this.id);this.content=e}},s.prototype._updateTitle=function(t){null!=this.data.title?t.title=this.data.title||"":t.removeAttribute("title")},s.prototype._updateDataAttributes=function(t){if(this.options.dataAttributes&&this.options.dataAttributes.length>0){var e=[];if(Array.isArray(this.options.dataAttributes))e=this.options.dataAttributes;else{if("all"!=this.options.dataAttributes)return;e=Object.keys(this.data)}for(var i=0;it.start},s.prototype.redraw=function(){var t=this.dom;if(t||(this.dom={},t=this.dom,t.box=document.createElement("div"),t.content=document.createElement("div"),t.content.className="content",t.box.appendChild(t.content),this.dirty=!0),!this.parent)throw new Error("Cannot redraw item: no parent attached");if(!t.box.parentNode){var e=this.parent.dom.background;if(!e)throw new Error("Cannot redraw item: parent has no background container element");e.appendChild(t.box)}if(this.displayed=!0,this.dirty){this._updateContents(this.dom.content),this._updateTitle(this.dom.content),this._updateDataAttributes(this.dom.content),this._updateStyle(this.dom.box);var i=(this.data.className?" "+this.data.className:"")+(this.selected?" selected":"");t.box.className=this.baseClassName+i,this.overflow="hidden"!==window.getComputedStyle(t.content).overflow,this.props.content.width=this.dom.content.offsetWidth,this.height=0,this.dirty=!1}},s.prototype.show=r.prototype.show,s.prototype.hide=r.prototype.hide,s.prototype.repositionX=r.prototype.repositionX,s.prototype.repositionY=function(t){var e="top"===this.options.orientation;this.dom.content.style.top=e?"":"0",this.dom.content.style.bottom=e?"0":"";var i;if(void 0!==this.data.subgroup){var s=this.data.subgroup,o=this.parent.subgroups,r=o[s].index;if(1==e){i=this.parent.subgroups[s].height+t.item.vertical,i+=0==r?t.axis-.5*t.item.vertical:0;var a=this.parent.top;for(var h in o)o.hasOwnProperty(h)&&1==o[h].visible&&o[h].indexr&&(a+=o[h].height+t.item.vertical);i=this.parent.subgroups[s].height+t.item.vertical,this.dom.box.style.top=a+"px",this.dom.box.style.bottom=""}}else this.parent instanceof n?(i=Math.max(this.parent.height,this.parent.itemSet.body.domProps.center.height,this.parent.itemSet.body.domProps.centerContainer.height),this.dom.box.style.top=e?"0":"",this.dom.box.style.bottom=e?"":"0"):(i=this.parent.height,this.dom.box.style.top=this.parent.top+"px",this.dom.box.style.bottom="");this.dom.box.style.height=i+"px"},t.exports=s},function(t,e,i){function s(t,e,i){if(this.props={dot:{width:0,height:0},line:{width:0,height:0}},t&&void 0==t.start)throw new Error('Property "start" missing in item '+t);o.call(this,t,e,i)}{var o=i(31);i(1)}s.prototype=new o(null,null,null),s.prototype.isVisible=function(t){var e=(t.end-t.start)/4;return this.data.start>t.start-e&&this.data.startt.start-e&&this.data.startt.start},s.prototype.redraw=function(){var t=this.dom;if(t||(this.dom={},t=this.dom,t.box=document.createElement("div"),t.content=document.createElement("div"),t.content.className="content",t.box.appendChild(t.content),t.box["timeline-item"]=this,this.dirty=!0),!this.parent)throw new Error("Cannot redraw item: no parent attached");if(!t.box.parentNode){var e=this.parent.dom.foreground;if(!e)throw new Error("Cannot redraw item: parent has no foreground container element");e.appendChild(t.box)}if(this.displayed=!0,this.dirty){this._updateContents(this.dom.content),this._updateTitle(this.dom.box),this._updateDataAttributes(this.dom.box),this._updateStyle(this.dom.box);var i=(this.data.className?" "+this.data.className:"")+(this.selected?" selected":"");t.box.className=this.baseClassName+i,this.overflow="hidden"!==window.getComputedStyle(t.content).overflow,this.dom.content.style.maxWidth="none",this.props.content.width=this.dom.content.offsetWidth,this.height=this.dom.box.offsetHeight,this.dom.content.style.maxWidth="",this.dirty=!1}this._repaintDeleteButton(t.box),this._repaintDragLeft(),this._repaintDragRight()},s.prototype.show=function(){this.displayed||this.redraw()},s.prototype.hide=function(){if(this.displayed){var t=this.dom.box;t.parentNode&&t.parentNode.removeChild(t),this.top=null,this.left=null,this.displayed=!1}},s.prototype.repositionX=function(){var t,e,i=this.parent.width,s=this.conversion.toScreen(this.data.start),o=this.conversion.toScreen(this.data.end);-i>s&&(s=-i),o>2*i&&(o=2*i);var n=Math.max(o-s,1);switch(this.overflow?(this.left=s,this.width=n+this.props.content.width,e=this.props.content.width):(this.left=s,this.width=n,e=Math.min(o-s-2*this.options.padding,this.props.content.width)),this.dom.box.style.left=this.left+"px",this.dom.box.style.width=n+"px",this.options.align){case"left":this.dom.content.style.left="0";break;case"right":this.dom.content.style.left=Math.max(n-e-2*this.options.padding,0)+"px";break;case"center":this.dom.content.style.left=Math.max((n-e-2*this.options.padding)/2,0)+"px";break;default:t=this.overflow?o>0?Math.max(-s,0):-e:0>s?Math.min(-s,o-s-e-2*this.options.padding):0,this.dom.content.style.left=t+"px"}},s.prototype.repositionY=function(){var t=this.options.orientation,e=this.dom.box;e.style.top="top"==t?this.top+"px":this.parent.height-this.top-this.height+"px"},s.prototype._repaintDragLeft=function(){if(this.selected&&this.options.editable.updateTime&&!this.dom.dragLeft){var t=document.createElement("div");t.className="drag-left",t.dragLeftItem=this,this.dom.box.appendChild(t),this.dom.dragLeft=t}else!this.selected&&this.dom.dragLeft&&(this.dom.dragLeft.parentNode&&this.dom.dragLeft.parentNode.removeChild(this.dom.dragLeft),this.dom.dragLeft=null)},s.prototype._repaintDragRight=function(){if(this.selected&&this.options.editable.updateTime&&!this.dom.dragRight){var t=document.createElement("div");t.className="drag-right",t.dragRightItem=this,this.dom.box.appendChild(t),this.dom.dragRight=t}else!this.selected&&this.dom.dragRight&&(this.dom.dragRight.parentNode&&this.dom.dragRight.parentNode.removeChild(this.dom.dragRight),this.dom.dragRight=null)},t.exports=s},function(t,e,i){function s(t,e,i){if(!(this instanceof s))throw new SyntaxError("Constructor must be called with the new operator");this._determineBrowserMethod(),this._initializeMixinLoaders(),this.containerElement=t,this.renderRefreshRate=60,this.renderTimestep=1e3/this.renderRefreshRate,this.renderTime=0,this.physicsTime=0,this.runDoubleSpeed=!1,this.physicsDiscreteStepsize=.5,this.initializing=!0,this.triggerFunctions={add:null,edit:null,editEdge:null,connect:null,del:null},this.defaultOptions={nodes:{mass:1,radiusMin:10,radiusMax:30,radius:10,shape:"ellipse",image:void 0,widthMin:16,widthMax:64,fontColor:"black",fontSize:14,fontFace:"verdana",fontFill:void 0,fontStrokeWidth:0,fontStrokeColor:"white",level:-1,color:{border:"#2B7CE9",background:"#97C2FC",highlight:{border:"#2B7CE9",background:"#D2E5FF"},hover:{border:"#2B7CE9",background:"#D2E5FF"}},group:void 0,borderWidth:1,borderWidthSelected:void 0},edges:{widthMin:1,widthMax:15,width:1,widthSelectionMultiplier:2,hoverWidth:1.5,style:"line",color:{color:"#848484",highlight:"#848484",hover:"#848484"},fontColor:"#343434",fontSize:14,fontFace:"arial",fontFill:"white",fontStrokeWidth:0,fontStrokeColor:"white",labelAlignment:"horizontal",arrowScaleFactor:1,dash:{length:10,gap:5,altLength:void 0},inheritColor:"from"},configurePhysics:!1,physics:{barnesHut:{enabled:!0,thetaInverted:2,gravitationalConstant:-2e3,centralGravity:.3,springLength:95,springConstant:.04,damping:.09},repulsion:{centralGravity:0,springLength:200,springConstant:.05,nodeDistance:100,damping:.09},hierarchicalRepulsion:{enabled:!1,centralGravity:0,springLength:100,springConstant:.01,nodeDistance:150,damping:.09},damping:null,centralGravity:null,springLength:null,springConstant:null},clustering:{enabled:!1,initialMaxNodes:100,clusterThreshold:500,reduceToNodes:300,chainThreshold:.4,clusterEdgeThreshold:20,sectorThreshold:100,screenSizeThreshold:.2,fontSizeMultiplier:4,maxFontSize:1e3,forceAmplification:.1,distanceAmplification:.1,edgeGrowth:20,nodeScaling:{width:1,height:1,radius:1},maxNodeSizeIncrements:600,activeAreaBoxSize:80,clusterLevelDifference:2},navigation:{enabled:!1},keyboard:{enabled:!1,speed:{x:10,y:10,zoom:.02}},dataManipulation:{enabled:!1,initiallyVisible:!1},hierarchicalLayout:{enabled:!1,levelSeparation:150,nodeSpacing:100,direction:"UD",layout:"hubsize"},freezeForStabilization:!1,smoothCurves:{enabled:!0,dynamic:!0,type:"continuous",roundness:.5},maxVelocity:30,minVelocity:.1,stabilize:!0,stabilizationIterations:1e3,zoomExtentOnStabilize:!0,locale:"en",locales:_,tooltip:{delay:300,fontColor:"black",fontSize:14,fontFace:"verdana",color:{border:"#666",background:"#FFFFC6"}},dragNetwork:!0,dragNodes:!0,zoomable:!0,hover:!1,hideEdgesOnDrag:!1,hideNodesOnDrag:!1,width:"100%",height:"100%",selectable:!0},this.constants=a.extend({},this.defaultOptions),this.pixelRatio=1,this.hoverObj={nodes:{},edges:{}},this.controlNodesActive=!1,this.navigationHammers={existing:[],_new:[]},this.animationSpeed=1/this.renderRefreshRate,this.animationEasingFunction="easeInOutQuint",this.easingTime=0,this.sourceScale=0,this.targetScale=0,this.sourceTranslation=0,this.targetTranslation=0,this.lockedOnNodeId=null,this.lockedOnNodeOffset=null,this.touchTime=0;var o=this;this.groups=new u,this.images=new m,this.images.setOnloadCallback(function(){o._redraw()}),this.xIncrement=0,this.yIncrement=0,this.zoomIncrement=0,this._loadPhysicsSystem(),this._create(),this._loadSectorSystem(),this._loadClusterSystem(),this._loadSelectionSystem(),this._loadHierarchySystem(),this._setTranslation(this.frame.clientWidth/2,this.frame.clientHeight/2),this._setScale(1),this.setOptions(i),this.freezeSimulation=!1,this.cachedFunctions={},this.startedStabilization=!1,this.stabilized=!1,this.stabilizationIterations=null,this.draggingNodes=!1,this.calculationNodes={},this.calculationNodeIndices=[],this.nodeIndices=[],this.nodes={},this.edges={},this.canvasTopLeft={x:0,y:0},this.canvasBottomRight={x:0,y:0},this.pointerPosition={x:0,y:0},this.areaCenter={},this.scale=1,this.previousScale=this.scale,this.nodesData=null,this.edgesData=null,this.nodesListeners={add:function(t,e){o._addNodes(e.items),o.start()},update:function(t,e){o._updateNodes(e.items,e.data),o.start()},remove:function(t,e){o._removeNodes(e.items),o.start()}},this.edgesListeners={add:function(t,e){o._addEdges(e.items),o.start()},update:function(t,e){o._updateEdges(e.items),o.start()},remove:function(t,e){o._removeEdges(e.items),o.start()}},this.moving=!0,this.timer=void 0,this.setData(e,this.constants.clustering.enabled||this.constants.hierarchicalLayout.enabled),this.initializing=!1,1==this.constants.hierarchicalLayout.enabled?this._setupHierarchicalLayout():0==this.constants.stabilize&&this.zoomExtent(void 0,!0,this.constants.clustering.enabled),this.constants.clustering.enabled&&this.startWithClustering()}var o=i(56),n=i(45),r=i(58),a=i(1),h=i(47),d=i(3),l=i(4),c=i(42),p=i(43),u=i(38),m=i(39),f=i(40),g=i(37),v=i(41),y=i(54),b=i(55),_=i(49);i(50),o(s.prototype),s.prototype._determineBrowserMethod=function(){var t=navigator.userAgent.toLowerCase();this.requiresTimeout=!1,-1!=t.indexOf("msie 9.0")?this.requiresTimeout=!0:-1!=t.indexOf("safari")&&t.indexOf("chrome")<=-1&&(this.requiresTimeout=!0)},s.prototype._getScriptPath=function(){for(var t=document.getElementsByTagName("script"),e=0;et.boundingBox.left&&(s=t.boundingBox.left),ot.boundingBox.bottom&&(e=t.boundingBox.bottom),i=this.constants.clustering.initialMaxNodes?49.07548/(n+142.05338)+91444e-8:12.662/(n+7.4147)+.0964822:1==this.constants.clustering.enabled&&n>=this.constants.clustering.initialMaxNodes?77.5271985/(n+187.266146)+476710517e-13:30.5062972/(n+19.93597763)+.08413486;var r=Math.min(this.frame.canvas.clientWidth/600,this.frame.canvas.clientHeight/600);s*=r}else{var a=1.1*Math.abs(o.maxX-o.minX),h=1.1*Math.abs(o.maxY-o.minY),d=this.frame.canvas.clientWidth/a,l=this.frame.canvas.clientHeight/h;s=l>=d?d:l}s>1&&(s=1);var c=this._findCenter(o);if(0==i){var p={position:c,scale:s,animation:t};this.moveTo(p),this.moving=!0,this.start()}else c.x*=s,c.y*=s,c.x-=.5*this.frame.canvas.clientWidth,c.y-=.5*this.frame.canvas.clientHeight,this._setScale(s),this._setTranslation(-c.x,-c.y)},s.prototype._updateNodeIndexList=function(){this._clearNodeIndexList();for(var t in this.nodes)this.nodes.hasOwnProperty(t)&&this.nodeIndices.push(t)},s.prototype.setData=function(t,e){if(void 0===e&&(e=!1),this.initializing=!0,t&&t.dot&&(t.nodes||t.edges))throw new SyntaxError('Data must contain either parameter "dot" or parameter pair "nodes" and "edges", but not both.');if(1==this.constants.dataManipulation.enabled&&this._createManipulatorBar(),this.setOptions(t&&t.options),t&&t.dot){if(t&&t.dot){var i=c.DOTToGraph(t.dot);return void this.setData(i)}}else if(t&&t.gephi){if(t&&t.gephi){var s=p.parseGephi(t.gephi);return void this.setData(s)}}else this._setNodes(t&&t.nodes),this._setEdges(t&&t.edges);this._putDataInSector(),0==e&&(1==this.constants.hierarchicalLayout.enabled?(this._resetLevels(),this._setupHierarchicalLayout()):this.constants.stabilize&&this._stabilize(),this.start()),this.initializing=!1},s.prototype.setOptions=function(t){if(t){var e,i=["nodes","edges","smoothCurves","hierarchicalLayout","clustering","navigation","keyboard","dataManipulation","onAdd","onEdit","onEditEdge","onConnect","onDelete","clickToUse"];if(a.selectiveNotDeepExtend(i,this.constants,t),a.selectiveNotDeepExtend(["color"],this.constants.nodes,t.nodes),a.selectiveNotDeepExtend(["color","length"],this.constants.edges,t.edges),t.physics&&(a.mergeOptions(this.constants.physics,t.physics,"barnesHut"),a.mergeOptions(this.constants.physics,t.physics,"repulsion"),t.physics.hierarchicalRepulsion)){this.constants.hierarchicalLayout.enabled=!0,this.constants.physics.hierarchicalRepulsion.enabled=!0,this.constants.physics.barnesHut.enabled=!1;for(e in t.physics.hierarchicalRepulsion)t.physics.hierarchicalRepulsion.hasOwnProperty(e)&&(this.constants.physics.hierarchicalRepulsion[e]=t.physics.hierarchicalRepulsion[e]) +}if(t.onAdd&&(this.triggerFunctions.add=t.onAdd),t.onEdit&&(this.triggerFunctions.edit=t.onEdit),t.onEditEdge&&(this.triggerFunctions.editEdge=t.onEditEdge),t.onConnect&&(this.triggerFunctions.connect=t.onConnect),t.onDelete&&(this.triggerFunctions.del=t.onDelete),a.mergeOptions(this.constants,t,"smoothCurves"),a.mergeOptions(this.constants,t,"hierarchicalLayout"),a.mergeOptions(this.constants,t,"clustering"),a.mergeOptions(this.constants,t,"navigation"),a.mergeOptions(this.constants,t,"keyboard"),a.mergeOptions(this.constants,t,"dataManipulation"),t.dataManipulation&&(this.editMode=this.constants.dataManipulation.initiallyVisible),t.edges&&(void 0!==t.edges.color&&(a.isString(t.edges.color)?(this.constants.edges.color={},this.constants.edges.color.color=t.edges.color,this.constants.edges.color.highlight=t.edges.color,this.constants.edges.color.hover=t.edges.color):(void 0!==t.edges.color.color&&(this.constants.edges.color.color=t.edges.color.color),void 0!==t.edges.color.highlight&&(this.constants.edges.color.highlight=t.edges.color.highlight),void 0!==t.edges.color.hover&&(this.constants.edges.color.hover=t.edges.color.hover)),this.constants.edges.inheritColor=!1),t.edges.fontColor||void 0!==t.edges.color&&(a.isString(t.edges.color)?this.constants.edges.fontColor=t.edges.color:void 0!==t.edges.color.color&&(this.constants.edges.fontColor=t.edges.color.color))),t.nodes&&t.nodes.color){var s=a.parseColor(t.nodes.color);this.constants.nodes.color.background=s.background,this.constants.nodes.color.border=s.border,this.constants.nodes.color.highlight.background=s.highlight.background,this.constants.nodes.color.highlight.border=s.highlight.border,this.constants.nodes.color.hover.background=s.hover.background,this.constants.nodes.color.hover.border=s.hover.border}if(t.groups)for(var o in t.groups)if(t.groups.hasOwnProperty(o)){var n=t.groups[o];this.groups.add(o,n)}if(t.tooltip){for(e in t.tooltip)t.tooltip.hasOwnProperty(e)&&(this.constants.tooltip[e]=t.tooltip[e]);t.tooltip.color&&(this.constants.tooltip.color=a.parseColor(t.tooltip.color))}if("clickToUse"in t&&(t.clickToUse?this.activator||(this.activator=new b(this.frame),this.activator.on("change",this._createKeyBinds.bind(this))):this.activator&&(this.activator.destroy(),delete this.activator)),t.labels)throw new Error('Option "labels" is deprecated. Use options "locale" and "locales" instead.');this._loadPhysicsSystem(),this._loadNavigationControls(),this._loadManipulationSystem(),this._configureSmoothCurves(),this._createKeyBinds(),this.setSize(this.constants.width,this.constants.height),this.moving=!0,this.start()}},s.prototype._create=function(){for(;this.containerElement.hasChildNodes();)this.containerElement.removeChild(this.containerElement.firstChild);if(this.frame=document.createElement("div"),this.frame.className="vis network-frame",this.frame.style.position="relative",this.frame.style.overflow="hidden",this.frame.canvas=document.createElement("canvas"),this.frame.canvas.style.position="relative",this.frame.appendChild(this.frame.canvas),this.frame.canvas.getContext){var t=this.frame.canvas.getContext("2d");this.pixelRatio=(window.devicePixelRatio||1)/(t.webkitBackingStorePixelRatio||t.mozBackingStorePixelRatio||t.msBackingStorePixelRatio||t.oBackingStorePixelRatio||t.backingStorePixelRatio||1),this.frame.canvas.getContext("2d").setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0)}else{var e=document.createElement("DIV");e.style.color="red",e.style.fontWeight="bold",e.style.padding="10px",e.innerHTML="Error: your browser does not support HTML canvas",this.frame.canvas.appendChild(e)}var i=this;this.drag={},this.pinch={},this.hammer=new n(this.frame.canvas),this.hammer.get("pinch").set({enable:!0}),this.hammer.on("tap",i._onTap.bind(i)),this.hammer.on("doubletap",i._onDoubleTap.bind(i)),this.hammer.on("press",i._onHold.bind(i)),this.hammer.on("pinch",i._onPinch.bind(i)),h.onTouch(this.hammer,i._onTouch.bind(i)),this.hammer.on("panstart",i._onDragStart.bind(i)),this.hammer.on("panmove",i._onDrag.bind(i)),this.hammer.on("panend",i._onDragEnd.bind(i)),this.frame.canvas.addEventListener("mousemove",i._onMouseMoveTitle.bind(i)),this.frame.canvas.addEventListener("mousewheel",i._onMouseWheel.bind(i)),this.frame.canvas.addEventListener("DOMMouseScroll",i._onMouseWheel.bind(i)),this.containerElement.appendChild(this.frame)},s.prototype._createKeyBinds=function(){var t=this;void 0!==this.keycharm&&this.keycharm.destroy(),this.keycharm=r(),this.keycharm.reset(),this.constants.keyboard.enabled&&this.isActive()&&(this.keycharm.bind("up",this._moveUp.bind(t),"keydown"),this.keycharm.bind("up",this._yStopMoving.bind(t),"keyup"),this.keycharm.bind("down",this._moveDown.bind(t),"keydown"),this.keycharm.bind("down",this._yStopMoving.bind(t),"keyup"),this.keycharm.bind("left",this._moveLeft.bind(t),"keydown"),this.keycharm.bind("left",this._xStopMoving.bind(t),"keyup"),this.keycharm.bind("right",this._moveRight.bind(t),"keydown"),this.keycharm.bind("right",this._xStopMoving.bind(t),"keyup"),this.keycharm.bind("=",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("=",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("num+",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("num+",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("num-",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("num-",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("-",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("-",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("[",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("[",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("]",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("]",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("pageup",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("pageup",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("pagedown",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("pagedown",this._stopZoom.bind(t),"keyup")),1==this.constants.dataManipulation.enabled&&(this.keycharm.bind("esc",this._createManipulatorBar.bind(t)),this.keycharm.bind("delete",this._deleteSelected.bind(t)))},s.prototype.destroy=function(){this.start=function(){},this.redraw=function(){},this.timer=!1,this._cleanupPhysicsConfiguration(),this.keycharm.reset(),this.hammer.destroy(),this.off(),this._recursiveDOMDelete(this.containerElement)},s.prototype._recursiveDOMDelete=function(t){for(;1==t.hasChildNodes();)this._recursiveDOMDelete(t.firstChild),t.removeChild(t.firstChild)},s.prototype._getPointer=function(t){return{x:t.x-a.getAbsoluteLeft(this.frame.canvas),y:t.y-a.getAbsoluteTop(this.frame.canvas)}},s.prototype._onTouch=function(t){(new Date).valueOf()-this.touchTime>100&&(this.drag.pointer=this._getPointer(t.center),this.drag.pinched=!1,this.pinch.scale=this._getScale(),this.touchTime=(new Date).valueOf(),this._handleTouch(this.drag.pointer))},s.prototype._onDragStart=function(t){this._handleDragStart(t)},s.prototype._handleDragStart=function(t){void 0===this.drag.pointer&&this._onTouch(t);var e=this._getNodeAt(this.drag.pointer);if(this.drag.dragging=!0,this.drag.selection=[],this.drag.translation=this._getTranslation(),this.drag.nodeId=null,this.draggingNodes=!1,null!=e&&1==this.constants.dragNodes){this.draggingNodes=!0,this.drag.nodeId=e.id,e.isSelected()||this._selectObject(e,!1),this.emit("dragStart",{nodeIds:this.getSelection().nodes});for(var i in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(i)){var s=this.selectionObj.nodes[i],o={id:s.id,node:s,x:s.x,y:s.y,xFixed:s.xFixed,yFixed:s.yFixed};s.xFixed=!0,s.yFixed=!0,this.drag.selection.push(o)}}t.preventDefault()},s.prototype._onDrag=function(t){this._handleOnDrag(t)},s.prototype._handleOnDrag=function(t){if(!this.drag.pinched){this.releaseNode();var e=this._getPointer(t.center),i=this,s=this.drag,o=s.selection;if(o&&o.length&&1==this.constants.dragNodes){var n=e.x-s.pointer.x,r=e.y-s.pointer.y;o.forEach(function(t){var e=t.node;t.xFixed||(e.x=i._XconvertDOMtoCanvas(i._XconvertCanvasToDOM(t.x)+n)),t.yFixed||(e.y=i._YconvertDOMtoCanvas(i._YconvertCanvasToDOM(t.y)+r))}),this.moving||(this.moving=!0,this.start())}else if(1==this.constants.dragNetwork){if(void 0===this.drag.pointer)return void this._handleDragStart(t);var a=e.x-this.drag.pointer.x,h=e.y-this.drag.pointer.y;this._setTranslation(this.drag.translation.x+a,this.drag.translation.y+h),this._redraw()}t.preventDefault()}},s.prototype._onDragEnd=function(t){this._handleDragEnd(t)},s.prototype._handleDragEnd=function(t){this.drag.dragging=!1;var e=this.drag.selection;e&&e.length?(e.forEach(function(t){t.node.xFixed=t.xFixed,t.node.yFixed=t.yFixed}),this.moving=!0,this.start()):this._redraw(),0==this.draggingNodes?this.emit("dragEnd",{nodeIds:[]}):this.emit("dragEnd",{nodeIds:this.getSelection().nodes}),t.preventDefault()},s.prototype._onTap=function(t){var e=this._getPointer(t.center);this.pointerPosition=e,this._handleTap(e)},s.prototype._onDoubleTap=function(t){var e=this._getPointer(t.center);this._handleDoubleTap(e)},s.prototype._onHold=function(t){var e=this._getPointer(t.center);this.pointerPosition=e,this._handleOnHold(e)},s.prototype._onRelease=function(t){var e=this._getPointer(t.center);this._handleOnRelease(e)},s.prototype._onPinch=function(t){var e=this._getPointer(t.center);this.drag.pinched=!0,"scale"in this.pinch||(this.pinch.scale=1);var i=this.pinch.scale*t.scale;this._zoom(i,e)},s.prototype._zoom=function(t,e){if(1==this.constants.zoomable){var i=this._getScale();1e-5>t&&(t=1e-5),t>10&&(t=10);var s=null;void 0!==this.drag&&1==this.drag.dragging&&(s=this.DOMtoCanvas(this.drag.pointer));var o=this._getTranslation(),n=t/i,r=(1-n)*e.x+o.x*n,a=(1-n)*e.y+o.y*n;if(this.areaCenter={x:this._XconvertDOMtoCanvas(e.x),y:this._YconvertDOMtoCanvas(e.y)},this._setScale(t),this._setTranslation(r,a),this.updateClustersDefault(),null!=s){var h=this.canvasToDOM(s);this.drag.pointer.x=h.x,this.drag.pointer.y=h.y}return this._redraw(),t>i?this.emit("zoom",{direction:"+"}):this.emit("zoom",{direction:"-"}),t}},s.prototype._onMouseWheel=function(t){var e=0;if(t.wheelDelta?e=t.wheelDelta/120:t.detail&&(e=-t.detail/3),e){var i=this._getScale(),s=e/10;0>e&&(s/=1-s),i*=1+s;var o=this._getPointer({x:t.pageX,y:t.pageY});this._zoom(i,o)}t.preventDefault()},s.prototype._onMouseMoveTitle=function(t){var e=this._getPointer({x:t.pageX,y:t.pageY});this.popupObj&&this._checkHidePopup(e);var i=this,s=function(){i._checkShowPopup(e)};if(this.popupTimer&&clearInterval(this.popupTimer),this.drag.dragging||(this.popupTimer=setTimeout(s,this.constants.tooltip.delay)),1==this.constants.hover){for(var o in this.hoverObj.edges)this.hoverObj.edges.hasOwnProperty(o)&&(this.hoverObj.edges[o].hover=!1,delete this.hoverObj.edges[o]);var n=this._getNodeAt(e);null==n&&(n=this._getEdgeAt(e)),null!=n&&this._hoverObject(n);for(var r in this.hoverObj.nodes)this.hoverObj.nodes.hasOwnProperty(r)&&(n instanceof f&&n.id!=r||n instanceof g||null==n)&&(this._blurObject(this.hoverObj.nodes[r]),delete this.hoverObj.nodes[r]);this.redraw()}},s.prototype._checkShowPopup=function(t){var e,i={left:this._XconvertDOMtoCanvas(t.x),top:this._YconvertDOMtoCanvas(t.y),right:this._XconvertDOMtoCanvas(t.x),bottom:this._YconvertDOMtoCanvas(t.y)},s=this.popupObj,o=!1;if(void 0==this.popupObj){var n=this.nodes,r=[];for(e in n)if(n.hasOwnProperty(e)){var a=n[e];a.isOverlappingWith(i)&&void 0!==a.getTitle()&&r.push(e)}r.length>0&&(this.popupObj=this.nodes[r[r.length-1]],o=!0)}if(void 0===this.popupObj&&0==o){var h=this.edges,d=[];for(e in h)if(h.hasOwnProperty(e)){var l=h[e];l.connected&&void 0!==l.getTitle()&&l.isOverlappingWith(i)&&d.push(e)}d.length>0&&(this.popupObj=this.edges[d[d.length-1]])}if(this.popupObj){if(this.popupObj!=s){var c=this;c.popup||(c.popup=new v(c.frame,c.constants.tooltip)),c.popup.setPosition(t.x-3,t.y-3),c.popup.setText(c.popupObj.getTitle()),c.popup.show()}}else this.popup&&this.popup.hide()},s.prototype._checkHidePopup=function(t){this.popupObj&&this._getNodeAt(t)||(this.popupObj=void 0,this.popup&&this.popup.hide())},s.prototype.setSize=function(t,e){var i=!1,s=this.frame.canvas.width,o=this.frame.canvas.height;t!=this.constants.width||e!=this.constants.height||this.frame.style.width!=t||this.frame.style.height!=e?(this.frame.style.width=t,this.frame.style.height=e,this.frame.canvas.style.width="100%",this.frame.canvas.style.height="100%",this.frame.canvas.width=this.frame.canvas.clientWidth*this.pixelRatio,this.frame.canvas.height=this.frame.canvas.clientHeight*this.pixelRatio,this.constants.width=t,this.constants.height=e,i=!0):(this.frame.canvas.width!=this.frame.canvas.clientWidth*this.pixelRatio&&(this.frame.canvas.width=this.frame.canvas.clientWidth*this.pixelRatio,i=!0),this.frame.canvas.height!=this.frame.canvas.clientHeight*this.pixelRatio&&(this.frame.canvas.height=this.frame.canvas.clientHeight*this.pixelRatio,i=!0)),1==i&&this.emit("resize",{width:this.frame.canvas.width*this.pixelRatio,height:this.frame.canvas.height*this.pixelRatio,oldWidth:s*this.pixelRatio,oldHeight:o*this.pixelRatio})},s.prototype._setNodes=function(t){var e=this.nodesData;if(t instanceof d||t instanceof l)this.nodesData=t;else if(Array.isArray(t))this.nodesData=new d,this.nodesData.add(t);else{if(t)throw new TypeError("Array or DataSet expected");this.nodesData=new d}if(e&&a.forEach(this.nodesListeners,function(t,i){e.off(i,t)}),this.nodes={},this.nodesData){var i=this;a.forEach(this.nodesListeners,function(t,e){i.nodesData.on(e,t)});var s=this.nodesData.getIds();this._addNodes(s)}this._updateSelection()},s.prototype._addNodes=function(t){for(var e,i=0,s=t.length;s>i;i++){e=t[i];var o=this.nodesData.get(e),n=new f(o,this.images,this.groups,this.constants);if(this.nodes[e]=n,!(0!=n.xFixed&&0!=n.yFixed||null!==n.x&&null!==n.y)){var r=1*t.length+10,a=2*Math.PI*Math.random();0==n.xFixed&&(n.x=r*Math.cos(a)),0==n.yFixed&&(n.y=r*Math.sin(a))}this.moving=!0}this._updateNodeIndexList(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateCalculationNodes(),this._reconnectEdges(),this._updateValueRange(this.nodes),this.updateLabels()},s.prototype._updateNodes=function(t,e){for(var i=this.nodes,s=0,o=t.length;o>s;s++){var n=t[s],r=i[n],a=e[s];r?r.setProperties(a,this.constants):(r=new f(properties,this.images,this.groups,this.constants),i[n]=r)}this.moving=!0,1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateNodeIndexList(),this._updateValueRange(i)},s.prototype._removeNodes=function(t){for(var e=this.nodes,i=0,s=t.length;s>i;i++){var o=t[i];delete e[o]}this._updateNodeIndexList(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateCalculationNodes(),this._reconnectEdges(),this._updateSelection(),this._updateValueRange(e)},s.prototype._setEdges=function(t){var e=this.edgesData;if(t instanceof d||t instanceof l)this.edgesData=t;else if(Array.isArray(t))this.edgesData=new d,this.edgesData.add(t);else{if(t)throw new TypeError("Array or DataSet expected");this.edgesData=new d}if(e&&a.forEach(this.edgesListeners,function(t,i){e.off(i,t)}),this.edges={},this.edgesData){var i=this;a.forEach(this.edgesListeners,function(t,e){i.edgesData.on(e,t)});var s=this.edgesData.getIds();this._addEdges(s)}this._reconnectEdges()},s.prototype._addEdges=function(t){for(var e=this.edges,i=this.edgesData,s=0,o=t.length;o>s;s++){var n=t[s],r=e[n];r&&r.disconnect();var a=i.get(n,{showInternalIds:!0});e[n]=new g(a,this,this.constants)}this.moving=!0,this._updateValueRange(e),this._createBezierNodes(),this._updateCalculationNodes(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout())},s.prototype._updateEdges=function(t){for(var e=this.edges,i=this.edgesData,s=0,o=t.length;o>s;s++){var n=t[s],r=i.get(n),a=e[n];a?(a.disconnect(),a.setProperties(r,this.constants),a.connect()):(a=new g(r,this,this.constants),this.edges[n]=a)}this._createBezierNodes(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this.moving=!0,this._updateValueRange(e)},s.prototype._removeEdges=function(t){for(var e=this.edges,i=0,s=t.length;s>i;i++){var o=t[i],n=e[o];n&&(null!=n.via&&delete this.sectors.support.nodes[n.via.id],n.disconnect(),delete e[o])}this.moving=!0,this._updateValueRange(e),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateCalculationNodes()},s.prototype._reconnectEdges=function(){var t,e=this.nodes,i=this.edges;for(t in e)e.hasOwnProperty(t)&&(e[t].edges=[],e[t].dynamicEdges=[]);for(t in i)if(i.hasOwnProperty(t)){var s=i[t];s.from=null,s.to=null,s.connect()}},s.prototype._updateValueRange=function(t){var e,i=void 0,s=void 0;for(e in t)if(t.hasOwnProperty(e)){var o=t[e].getValue();void 0!==o&&(i=void 0===i?o:Math.min(o,i),s=void 0===s?o:Math.max(o,s))}if(void 0!==i&&void 0!==s)for(e in t)t.hasOwnProperty(e)&&t[e].setValueRange(i,s)},s.prototype.redraw=function(){this.setSize(this.constants.width,this.constants.height),this._redraw()},s.prototype._redraw=function(t){var e=this.frame.canvas.getContext("2d");e.setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0);var i=this.frame.canvas.width*this.pixelRatio,s=this.frame.canvas.height*this.pixelRatio;e.clearRect(0,0,i,s),e.save(),e.translate(this.translation.x,this.translation.y),e.scale(this.scale,this.scale),this.canvasTopLeft={x:this._XconvertDOMtoCanvas(0),y:this._YconvertDOMtoCanvas(0)},this.canvasBottomRight={x:this._XconvertDOMtoCanvas(this.frame.canvas.clientWidth*this.pixelRatio),y:this._YconvertDOMtoCanvas(this.frame.canvas.clientHeight*this.pixelRatio)},1!=t&&(this._doInAllSectors("_drawAllSectorNodes",e),(0==this.drag.dragging||void 0===this.drag.dragging||0==this.constants.hideEdgesOnDrag)&&this._doInAllSectors("_drawEdges",e)),(0==this.drag.dragging||void 0===this.drag.dragging||0==this.constants.hideNodesOnDrag)&&this._doInAllSectors("_drawNodes",e,!1),1!=t&&1==this.controlNodesActive&&this._doInAllSectors("_drawControlNodes",e),e.restore(),1==t&&e.clearRect(0,0,i,s)},s.prototype._setTranslation=function(t,e){void 0===this.translation&&(this.translation={x:0,y:0}),void 0!==t&&(this.translation.x=t),void 0!==e&&(this.translation.y=e),this.emit("viewChanged")},s.prototype._getTranslation=function(){return{x:this.translation.x,y:this.translation.y}},s.prototype._setScale=function(t){this.scale=t},s.prototype._getScale=function(){return this.scale},s.prototype._XconvertDOMtoCanvas=function(t){return(t-this.translation.x)/this.scale},s.prototype._XconvertCanvasToDOM=function(t){return t*this.scale+this.translation.x},s.prototype._YconvertDOMtoCanvas=function(t){return(t-this.translation.y)/this.scale},s.prototype._YconvertCanvasToDOM=function(t){return t*this.scale+this.translation.y},s.prototype.canvasToDOM=function(t){return{x:this._XconvertCanvasToDOM(t.x),y:this._YconvertCanvasToDOM(t.y)}},s.prototype.DOMtoCanvas=function(t){return{x:this._XconvertDOMtoCanvas(t.x),y:this._YconvertDOMtoCanvas(t.y)}},s.prototype._drawNodes=function(t,e){void 0===e&&(e=!1);var i=this.nodes,s=[];for(var o in i)i.hasOwnProperty(o)&&(i[o].setScaleAndPos(this.scale,this.canvasTopLeft,this.canvasBottomRight),i[o].isSelected()?s.push(o):(i[o].inArea()||e)&&i[o].draw(t));for(var n=0,r=s.length;r>n;n++)(i[s[n]].inArea()||e)&&i[s[n]].draw(t)},s.prototype._drawEdges=function(t){var e=this.edges;for(var i in e)if(e.hasOwnProperty(i)){var s=e[i];s.setScale(this.scale),s.connected&&e[i].draw(t)}},s.prototype._drawControlNodes=function(t){var e=this.edges;for(var i in e)e.hasOwnProperty(i)&&e[i]._drawControlNodes(t)},s.prototype._stabilize=function(){1==this.constants.freezeForStabilization&&this._freezeDefinedNodes();for(var t=0;this.moving&&t0)for(t in i)i.hasOwnProperty(t)&&(i[t].discreteStepLimited(e,this.constants.maxVelocity),s=!0);else for(t in i)i.hasOwnProperty(t)&&(i[t].discreteStep(e),s=!0);if(1==s){var o=this.constants.minVelocity/Math.max(this.scale,.05);return o>.5*this.constants.maxVelocity?!0:this._isMoving(o)}return!1},s.prototype._revertPhysicsState=function(){var t=this.nodes;for(var e in t)t.hasOwnProperty(e)&&t[e].revertPosition()},s.prototype._revertPhysicsTick=function(){this._doInAllActiveSectors("_revertPhysicsState"),1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic&&this._doInSupportSector("_revertPhysicsState")},s.prototype._physicsTick=function(){if(!this.freezeSimulation&&1==this.moving){var t=!1,e=!1;this._doInAllActiveSectors("_initializeForceCalculation");var i=this._doInAllActiveSectors("_discreteStepNodes");1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic&&(e=this._doInSupportSector("_discreteStepNodes"));for(var s=0;s2*e||1==this.runDoubleSpeed)&&1==this.moving&&(this._physicsTick(),0!=this.renderTime&&(this.runDoubleSpeed=!0));var i=Date.now();this._redraw(),this.renderTime=Date.now()-i,this.start()},"undefined"!=typeof window&&(window.requestAnimationFrame=window.requestAnimationFrame||window.mozRequestAnimationFrame||window.webkitRequestAnimationFrame||window.msRequestAnimationFrame),s.prototype.start=function(){if(1==this.moving||0!=this.xIncrement||0!=this.yIncrement||0!=this.zoomIncrement)this.timer||(this.timer=1==this.requiresTimeout?window.setTimeout(this._animationStep.bind(this),this.renderTimestep):window.requestAnimationFrame(this._animationStep.bind(this)));else if(this._redraw(),this.stabilizationIterations>1){var t=this,e={iterations:t.stabilizationIterations};this.stabilizationIterations=0,this.startedStabilization=!1,setTimeout(function(){t.emit("stabilized",e)},0)}else this.stabilizationIterations=0},s.prototype._handleNavigation=function(){if(0!=this.xIncrement||0!=this.yIncrement){var t=this._getTranslation();this._setTranslation(t.x+this.xIncrement,t.y+this.yIncrement)}if(0!=this.zoomIncrement){var e={x:this.frame.canvas.clientWidth/2,y:this.frame.canvas.clientHeight/2};this._zoom(this.scale*(1+this.zoomIncrement),e)}},s.prototype.toggleFreeze=function(){0==this.freezeSimulation?this.freezeSimulation=!0:(this.freezeSimulation=!1,this.start())},s.prototype._configureSmoothCurves=function(t){if(void 0===t&&(t=!0),1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic){this._createBezierNodes();for(var e in this.sectors.support.nodes)this.sectors.support.nodes.hasOwnProperty(e)&&void 0===this.edges[this.sectors.support.nodes[e].parentEdgeId]&&delete this.sectors.support.nodes[e]}else{this.sectors.support.nodes={};for(var i in this.edges)this.edges.hasOwnProperty(i)&&(this.edges[i].via=null)}this._updateCalculationNodes(),t||(this.moving=!0,this.start())},s.prototype._createBezierNodes=function(){if(1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic)for(var t in this.edges)if(this.edges.hasOwnProperty(t)){var e=this.edges[t];if(null==e.via){var i="edgeId:".concat(e.id);this.sectors.support.nodes[i]=new f({id:i,mass:1,shape:"circle",image:"",internalMultiplier:1},{},{},this.constants),e.via=this.sectors.support.nodes[i],e.via.parentEdgeId=e.id,e.positionBezierNode()}}},s.prototype._initializeMixinLoaders=function(){for(var t in y)y.hasOwnProperty(t)&&(s.prototype[t]=y[t])},s.prototype.storePosition=function(){console.log("storePosition is deprecated: use .storePositions() from now on."),this.storePositions()},s.prototype.storePositions=function(){var t=[];for(var e in this.nodes)if(this.nodes.hasOwnProperty(e)){var i=this.nodes[e],s=!this.nodes.xFixed,o=!this.nodes.yFixed;(this.nodesData._data[e].x!=Math.round(i.x)||this.nodesData._data[e].y!=Math.round(i.y))&&t.push({id:e,x:Math.round(i.x),y:Math.round(i.y),allowedToMoveX:s,allowedToMoveY:o})}this.nodesData.update(t)},s.prototype.getPositions=function(t){var e={};if(void 0!==t){if(1==Array.isArray(t)){for(var i=0;i=1&&(this.easingTime=0,this._redraw=null!=this.lockedOnNodeId?this._lockedRedraw:this._classicRedraw,this.emit("animationFinished"))},s.prototype._classicRedraw=function(){},s.prototype.isActive=function(){return!this.activator||this.activator.active},s.prototype.setScale=function(){return this._setScale()},s.prototype.getScale=function(){return this._getScale()},s.prototype.getCenterCoordinates=function(){return this.DOMtoCanvas({x:.5*this.frame.canvas.clientWidth,y:.5*this.frame.canvas.clientHeight})},s.prototype.getBoundingBox=function(t){return void 0!==this.nodes[t]?this.nodes[t].boundingBox:void 0},t.exports=s},function(t,e,i){function s(t,e,i){if(!e)throw"No network provided";var s=["edges","physics"],n=o.selectiveBridgeObject(s,i);this.options=n.edges,this.physics=n.physics,this.options.smoothCurves=i.smoothCurves,this.network=e,this.id=void 0,this.fromId=void 0,this.toId=void 0,this.title=void 0,this.widthSelected=this.options.width*this.options.widthSelectionMultiplier,this.value=void 0,this.selected=!1,this.hover=!1,this.labelDimensions={top:0,left:0,width:0,height:0,yLine:0},this.dirtyLabel=!0,this.from=null,this.to=null,this.via=null,this.fromBackup=null,this.toBackup=null,this.originalFromId=[],this.originalToId=[],this.connected=!1,this.widthFixed=!1,this.lengthFixed=!1,this.setProperties(t),this.controlNodesEnabled=!1,this.controlNodes={from:null,to:null,positions:{}},this.connectedNode=null}var o=i(1),n=i(40);s.prototype.setProperties=function(t){if(t){var e=["style","fontSize","fontFace","fontColor","fontFill","fontStrokeWidth","fontStrokeColor","width","widthSelectionMultiplier","hoverWidth","arrowScaleFactor","dash","inheritColor","labelAlignment"];switch(o.selectiveDeepExtend(e,this.options,t),void 0!==t.from&&(this.fromId=t.from),void 0!==t.to&&(this.toId=t.to),void 0!==t.id&&(this.id=t.id),void 0!==t.label&&(this.label=t.label,this.dirtyLabel=!0),void 0!==t.title&&(this.title=t.title),void 0!==t.value&&(this.value=t.value),void 0!==t.length&&(this.physics.springLength=t.length),void 0!==t.color&&(this.options.inheritColor=!1,o.isString(t.color)?(this.options.color.color=t.color,this.options.color.highlight=t.color):(void 0!==t.color.color&&(this.options.color.color=t.color.color),void 0!==t.color.highlight&&(this.options.color.highlight=t.color.highlight),void 0!==t.color.hover&&(this.options.color.hover=t.color.hover))),this.connect(),this.widthFixed=this.widthFixed||void 0!==t.width,this.lengthFixed=this.lengthFixed||void 0!==t.length,this.widthSelected=this.options.width*this.options.widthSelectionMultiplier,this.options.style){case"line":this.draw=this._drawLine;break;case"arrow":this.draw=this._drawArrow;break;case"arrow-center":this.draw=this._drawArrowCenter;break;case"dash-line":this.draw=this._drawDashLine;break;default:this.draw=this._drawLine}}},s.prototype.connect=function(){this.disconnect(),this.from=this.network.nodes[this.fromId]||null,this.to=this.network.nodes[this.toId]||null,this.connected=this.from&&this.to,this.connected?(this.from.attachEdge(this),this.to.attachEdge(this)):(this.from&&this.from.detachEdge(this),this.to&&this.to.detachEdge(this))},s.prototype.disconnect=function(){this.from&&(this.from.detachEdge(this),this.from=null),this.to&&(this.to.detachEdge(this),this.to=null),this.connected=!1},s.prototype.getTitle=function(){return"function"==typeof this.title?this.title():this.title +},s.prototype.getValue=function(){return this.value},s.prototype.setValueRange=function(t,e){if(!this.widthFixed&&void 0!==this.value){var i=(this.options.widthMax-this.options.widthMin)/(e-t);this.options.width=(this.value-t)*i+this.options.widthMin,this.widthSelected=this.options.width*this.options.widthSelectionMultiplier}},s.prototype.draw=function(){throw"Method draw not initialized in edge"},s.prototype.isOverlappingWith=function(t){if(this.connected){var e=10,i=this.from.x,s=this.from.y,o=this.to.x,n=this.to.y,r=t.left,a=t.top,h=this._getDistanceToEdge(i,s,o,n,r,a);return e>h}return!1},s.prototype._getColor=function(){var t=this.options.color;return"to"==this.options.inheritColor?t={highlight:this.to.options.color.highlight.border,hover:this.to.options.color.hover.border,color:this.to.options.color.border}:("from"==this.options.inheritColor||1==this.options.inheritColor)&&(t={highlight:this.from.options.color.highlight.border,hover:this.from.options.color.hover.border,color:this.from.options.color.border}),1==this.selected?t.highlight:1==this.hover?t.hover:t.color},s.prototype._drawLine=function(t){if(t.strokeStyle=this._getColor(),t.lineWidth=this._getLineWidth(),this.from!=this.to){var e,i=this._line(t);if(this.label){if(1==this.options.smoothCurves.enabled&&null!=i){var s=.5*(.5*(this.from.x+i.x)+.5*(this.to.x+i.x)),o=.5*(.5*(this.from.y+i.y)+.5*(this.to.y+i.y));e={x:s,y:o}}else e=this._pointOnLine(.5);this._label(t,this.label,e.x,e.y)}}else{var n,r,a=this.physics.springLength/4,h=this.from;h.width||h.resize(t),h.width>h.height?(n=h.x+h.width/2,r=h.y-a):(n=h.x+a,r=h.y-h.height/2),this._circle(t,n,r,a),e=this._pointOnCircle(n,r,a,.5),this._label(t,this.label,e.x,e.y)}},s.prototype._getLineWidth=function(){return 1==this.selected?Math.max(Math.min(this.widthSelected,this.options.widthMax),.3*this.networkScaleInv):1==this.hover?Math.max(Math.min(this.options.hoverWidth,this.options.widthMax),.3*this.networkScaleInv):Math.max(this.options.width,.3*this.networkScaleInv)},s.prototype._getViaCoordinates=function(){if(1==this.options.smoothCurves.dynamic&&1==this.options.smoothCurves.enabled)return this.via;if(0==this.options.smoothCurves.enabled)return{x:0,y:0};var t=null,e=null,i=this.options.smoothCurves.roundness,s=this.options.smoothCurves.type,o=Math.abs(this.from.x-this.to.x),n=Math.abs(this.from.y-this.to.y);return"discrete"==s||"diagonalCross"==s?Math.abs(this.from.x-this.to.x)this.to.y?this.from.xthis.to.x&&(t=this.from.x-i*n,e=this.from.y-i*n):this.from.ythis.to.x&&(t=this.from.x-i*n,e=this.from.y+i*n)),"discrete"==s&&(t=i*n>o?this.from.x:t)):Math.abs(this.from.x-this.to.x)>Math.abs(this.from.y-this.to.y)&&(this.from.y>this.to.y?this.from.xthis.to.x&&(t=this.from.x-i*o,e=this.from.y-i*o):this.from.ythis.to.x&&(t=this.from.x-i*o,e=this.from.y+i*o)),"discrete"==s&&(e=i*o>n?this.from.y:e)):"straightCross"==s?Math.abs(this.from.x-this.to.x)Math.abs(this.from.y-this.to.y)&&(t=this.from.xthis.to.y?this.from.xthis.to.x&&(t=this.from.x-i*n,e=this.from.y-i*n,t=this.to.x>t?this.to.x:t):this.from.ythis.to.x&&(t=this.from.x-i*n,e=this.from.y+i*n,t=this.to.x>t?this.to.x:t)):Math.abs(this.from.x-this.to.x)>Math.abs(this.from.y-this.to.y)&&(this.from.y>this.to.y?this.from.xe?this.to.y:e):this.from.x>this.to.x&&(t=this.from.x-i*o,e=this.from.y-i*o,e=this.to.y>e?this.to.y:e):this.from.ythis.to.x&&(t=this.from.x-i*o,e=this.from.y+i*o,e=this.to.yd;d++){var l=t.measureText(n[d]).width;h=l>h?l:h}var c=this.options.fontSize*r,p=i-h/2,u=s-c/2;this.labelDimensions={top:u,left:p,width:h,height:c,yLine:o}}var o=this.labelDimensions.yLine;t.save(),"horizontal"!=this.options.labelAlignment&&(t.translate(i,o),this._rotateForLabelAlignment(t),i=0,o=0),this._drawLabelRect(t),this._drawLabelText(t,i,o,n,r,a),t.restore()}},s.prototype._rotateForLabelAlignment=function(t){var e=this.from.y-this.to.y,i=this.from.x-this.to.x,s=Math.atan2(e,i);(-1>s&&0>i||s>0&&0>i)&&(s+=Math.PI),t.rotate(s)},s.prototype._drawLabelRect=function(t){if(void 0!==this.options.fontFill&&null!==this.options.fontFill&&"none"!==this.options.fontFill){t.fillStyle=this.options.fontFill;var e=2;"line-center"==this.options.labelAlignment?t.fillRect(.5*-this.labelDimensions.width,.5*-this.labelDimensions.height,this.labelDimensions.width,this.labelDimensions.height):"line-above"==this.options.labelAlignment?t.fillRect(.5*-this.labelDimensions.width,-(this.labelDimensions.height+e),this.labelDimensions.width,this.labelDimensions.height):"line-below"==this.options.labelAlignment?t.fillRect(.5*-this.labelDimensions.width,e,this.labelDimensions.width,this.labelDimensions.height):t.fillRect(this.labelDimensions.left,this.labelDimensions.top,this.labelDimensions.width,this.labelDimensions.height)}},s.prototype._drawLabelText=function(t,e,i,s,o,n){if(t.fillStyle=this.options.fontColor||"black",t.textAlign="center","horizontal"!=this.options.labelAlignment){var r=2;"line-above"==this.options.labelAlignment?(t.textBaseline="alphabetic",i-=2*r):"line-below"==this.options.labelAlignment?(t.textBaseline="hanging",i+=2*r):t.textBaseline="middle"}else t.textBaseline="middle";this.options.fontStrokeWidth>0&&(t.lineWidth=this.options.fontStrokeWidth,t.strokeStyle=this.options.fontStrokeColor,t.lineJoin="round");for(var a=0;o>a;a++)this.options.fontStrokeWidth>0&&t.strokeText(s[a],e,i),t.fillText(s[a],e,i),i+=n},s.prototype._drawDashLine=function(t){t.strokeStyle=this._getColor(),t.lineWidth=this._getLineWidth();var e=null;if(void 0!==t.setLineDash){t.save();var i=[0];i=void 0!==this.options.dash.length&&void 0!==this.options.dash.gap?[this.options.dash.length,this.options.dash.gap]:[5,5],t.setLineDash(i),t.lineDashOffset=0,e=this._line(t),t.setLineDash([0]),t.lineDashOffset=0,t.restore()}else t.beginPath(),t.lineCap="round",void 0!==this.options.dash.altLength?t.dashedLine(this.from.x,this.from.y,this.to.x,this.to.y,[this.options.dash.length,this.options.dash.gap,this.options.dash.altLength,this.options.dash.gap]):void 0!==this.options.dash.length&&void 0!==this.options.dash.gap?t.dashedLine(this.from.x,this.from.y,this.to.x,this.to.y,[this.options.dash.length,this.options.dash.gap]):(t.moveTo(this.from.x,this.from.y),t.lineTo(this.to.x,this.to.y)),t.stroke();if(this.label){var s;if(1==this.options.smoothCurves.enabled&&null!=e){var o=.5*(.5*(this.from.x+e.x)+.5*(this.to.x+e.x)),n=.5*(.5*(this.from.y+e.y)+.5*(this.to.y+e.y));s={x:o,y:n}}else s=this._pointOnLine(.5);this._label(t,this.label,s.x,s.y)}},s.prototype._pointOnLine=function(t){return{x:(1-t)*this.from.x+t*this.to.x,y:(1-t)*this.from.y+t*this.to.y}},s.prototype._pointOnCircle=function(t,e,i,s){var o=2*(s-3/8)*Math.PI;return{x:t+i*Math.cos(o),y:e-i*Math.sin(o)}},s.prototype._drawArrowCenter=function(t){var e;if(t.strokeStyle=this._getColor(),t.fillStyle=t.strokeStyle,t.lineWidth=this._getLineWidth(),this.from!=this.to){var i=this._line(t),s=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x),o=(10+5*this.options.width)*this.options.arrowScaleFactor;if(1==this.options.smoothCurves.enabled&&null!=i){var n=.5*(.5*(this.from.x+i.x)+.5*(this.to.x+i.x)),r=.5*(.5*(this.from.y+i.y)+.5*(this.to.y+i.y));e={x:n,y:r}}else e=this._pointOnLine(.5);t.arrow(e.x,e.y,s,o),t.fill(),t.stroke(),this.label&&this._label(t,this.label,e.x,e.y)}else{var a,h,d=.25*Math.max(100,this.physics.springLength),l=this.from;l.width||l.resize(t),l.width>l.height?(a=l.x+.5*l.width,h=l.y-d):(a=l.x+d,h=l.y-.5*l.height),this._circle(t,a,h,d);var s=.2*Math.PI,o=(10+5*this.options.width)*this.options.arrowScaleFactor;e=this._pointOnCircle(a,h,d,.5),t.arrow(e.x,e.y,s,o),t.fill(),t.stroke(),this.label&&(e=this._pointOnCircle(a,h,d,.5),this._label(t,this.label,e.x,e.y))}},s.prototype._pointOnBezier=function(t){var e=this._getViaCoordinates(),i=Math.pow(1-t,2)*this.from.x+2*t*(1-t)*e.x+Math.pow(t,2)*this.to.x,s=Math.pow(1-t,2)*this.from.y+2*t*(1-t)*e.y+Math.pow(t,2)*this.to.y;return{x:i,y:s}},s.prototype._findBorderPosition=function(t,e){var i,s,o,n,r,a=10,h=0,d=0,l=1,c=.2,p=this.to;for(1==t&&(p=this.from);l>=d&&a>h;){var u=.5*(d+l);if(i=this._pointOnBezier(u),s=Math.atan2(p.y-i.y,p.x-i.x),o=p.distanceToBorder(e,s),n=Math.sqrt(Math.pow(i.x-p.x,2)+Math.pow(i.y-p.y,2)),r=o-n,Math.abs(r)r?0==t?d=u:l=u:0==t?l=u:d=u,h++}return i.t=u,i},s.prototype._drawArrow=function(t){t.strokeStyle=this._getColor(),t.fillStyle=t.strokeStyle,t.lineWidth=this._getLineWidth();var e,i,s;if(this.from!=this.to){if(this._line(t),1==this.options.smoothCurves.enabled){var o=this._getViaCoordinates();s=this._findBorderPosition(!1,t);var n=this._pointOnBezier(Math.max(0,s.t-.1));e=Math.atan2(s.y-n.y,s.x-n.x)}else{e=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x);var r=this.to.x-this.from.x,a=this.to.y-this.from.y,h=Math.sqrt(r*r+a*a),d=this.to.distanceToBorder(t,e),l=(h-d)/h;s={},s.x=(1-l)*this.from.x+l*this.to.x,s.y=(1-l)*this.from.y+l*this.to.y}if(i=(10+5*this.options.width)*this.options.arrowScaleFactor,t.arrow(s.x,s.y,e,i),t.fill(),t.stroke(),this.label){var c;c=1==this.options.smoothCurves.enabled&&null!=o?this._pointOnBezier(.5):this._pointOnLine(.5),this._label(t,this.label,c.x,c.y)}}else{var p,u,m,f=this.from,g=.25*Math.max(100,this.physics.springLength);f.width||f.resize(t),f.width>f.height?(p=f.x+.5*f.width,u=f.y-g,m={x:p,y:f.y,angle:.9*Math.PI}):(p=f.x+g,u=f.y-.5*f.height,m={x:f.x,y:u,angle:.6*Math.PI}),t.beginPath(),t.arc(p,u,g,0,2*Math.PI,!1),t.stroke();var i=(10+5*this.options.width)*this.options.arrowScaleFactor;t.arrow(m.x,m.y,m.angle,i),t.fill(),t.stroke(),this.label&&(c=this._pointOnCircle(p,u,g,.5),this._label(t,this.label,c.x,c.y))}},s.prototype._getDistanceToEdge=function(t,e,i,s,o,n){var r=0;if(this.from!=this.to)if(1==this.options.smoothCurves.enabled){var a,h;if(1==this.options.smoothCurves.enabled&&1==this.options.smoothCurves.dynamic)a=this.via.x,h=this.via.y;else{var d=this._getViaCoordinates();a=d.x,h=d.y}var l,c,p,u,m,f,g,v=1e9;for(c=0;10>c;c++)p=.1*c,u=Math.pow(1-p,2)*t+2*p*(1-p)*a+Math.pow(p,2)*i,m=Math.pow(1-p,2)*e+2*p*(1-p)*h+Math.pow(p,2)*s,c>0&&(l=this._getDistanceToLine(f,g,u,m,o,n),v=v>l?l:v),f=u,g=m;r=v}else r=this._getDistanceToLine(t,e,i,s,o,n);else{var u,m,y,b,_=.25*this.physics.springLength,x=this.from;x.width>x.height?(u=x.x+.5*x.width,m=x.y-_):(u=x.x+_,m=x.y-.5*x.height),y=u-o,b=m-n,r=Math.abs(Math.sqrt(y*y+b*b)-_)}return this.labelDimensions.lefto&&this.labelDimensions.topn?0:r},s.prototype._getDistanceToLine=function(t,e,i,s,o,n){var r=i-t,a=s-e,h=r*r+a*a,d=((o-t)*r+(n-e)*a)/h;d>1?d=1:0>d&&(d=0);var l=t+d*r,c=e+d*a,p=l-o,u=c-n;return Math.sqrt(p*p+u*u)},s.prototype.setScale=function(t){this.networkScaleInv=1/t},s.prototype.select=function(){this.selected=!0},s.prototype.unselect=function(){this.selected=!1},s.prototype.positionBezierNode=function(){null!==this.via&&null!==this.from&&null!==this.to?(this.via.x=.5*(this.from.x+this.to.x),this.via.y=.5*(this.from.y+this.to.y)):(this.via.x=0,this.via.y=0)},s.prototype._drawControlNodes=function(t){if(1==this.controlNodesEnabled){if(null===this.controlNodes.from&&null===this.controlNodes.to){var e="edgeIdFrom:".concat(this.id),i="edgeIdTo:".concat(this.id),s={nodes:{group:"",radius:7,borderWidth:2,borderWidthSelected:2},physics:{damping:0},clustering:{maxNodeSizeIncrements:0,nodeScaling:{width:0,height:0,radius:0}}};this.controlNodes.from=new n({id:e,shape:"dot",color:{background:"#ff0000",border:"#3c3c3c",highlight:{background:"#07f968"}}},{},{},s),this.controlNodes.to=new n({id:i,shape:"dot",color:{background:"#ff0000",border:"#3c3c3c",highlight:{background:"#07f968"}}},{},{},s)}this.controlNodes.positions={},0==this.controlNodes.from.selected&&(this.controlNodes.positions.from=this.getControlNodeFromPosition(t),this.controlNodes.from.x=this.controlNodes.positions.from.x,this.controlNodes.from.y=this.controlNodes.positions.from.y),0==this.controlNodes.to.selected&&(this.controlNodes.positions.to=this.getControlNodeToPosition(t),this.controlNodes.to.x=this.controlNodes.positions.to.x,this.controlNodes.to.y=this.controlNodes.positions.to.y),this.controlNodes.from.draw(t),this.controlNodes.to.draw(t)}else this.controlNodes={from:null,to:null,positions:{}}},s.prototype._enableControlNodes=function(){this.fromBackup=this.from,this.toBackup=this.to,this.controlNodesEnabled=!0},s.prototype._disableControlNodes=function(){this.fromId=this.from.id,this.toId=this.to.id,this.fromId!=this.fromBackup.id?this.fromBackup.detachEdge(this):this.toId!=this.toBackup.id&&this.toBackup.detachEdge(this),this.fromBackup=null,this.toBackup=null,this.controlNodesEnabled=!1},s.prototype._getSelectedControlNode=function(t,e){var i=this.controlNodes.positions,s=Math.sqrt(Math.pow(t-i.from.x,2)+Math.pow(e-i.from.y,2)),o=Math.sqrt(Math.pow(t-i.to.x,2)+Math.pow(e-i.to.y,2));return 15>s?(this.connectedNode=this.from,this.from=this.controlNodes.from,this.controlNodes.from):15>o?(this.connectedNode=this.to,this.to=this.controlNodes.to,this.controlNodes.to):null},s.prototype._restoreControlNodes=function(){1==this.controlNodes.from.selected?(this.from=this.connectedNode,this.connectedNode=null,this.controlNodes.from.unselect()):1==this.controlNodes.to.selected&&(this.to=this.connectedNode,this.connectedNode=null,this.controlNodes.to.unselect())},s.prototype.getControlNodeFromPosition=function(t){var e;if(1==this.options.smoothCurves.enabled)e=this._findBorderPosition(!0,t);else{var i=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x),s=this.to.x-this.from.x,o=this.to.y-this.from.y,n=Math.sqrt(s*s+o*o),r=this.from.distanceToBorder(t,i+Math.PI),a=(n-r)/n;e={},e.x=a*this.from.x+(1-a)*this.to.x,e.y=a*this.from.y+(1-a)*this.to.y}return e},s.prototype.getControlNodeToPosition=function(t){var e;if(1==this.options.smoothCurves.enabled)e=this._findBorderPosition(!1,t);else{var i=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x),s=this.to.x-this.from.x,o=this.to.y-this.from.y,n=Math.sqrt(s*s+o*o),r=this.to.distanceToBorder(t,i),a=(n-r)/n;e={},e.x=(1-a)*this.from.x+a*this.to.x,e.y=(1-a)*this.from.y+a*this.to.y}return e},t.exports=s},function(t,e,i){function s(){this.clear(),this.defaultIndex=0}i(1);s.DEFAULT=[{border:"#2B7CE9",background:"#97C2FC",highlight:{border:"#2B7CE9",background:"#D2E5FF"},hover:{border:"#2B7CE9",background:"#D2E5FF"}},{border:"#FFA500",background:"#FFFF00",highlight:{border:"#FFA500",background:"#FFFFA3"},hover:{border:"#FFA500",background:"#FFFFA3"}},{border:"#FA0A10",background:"#FB7E81",highlight:{border:"#FA0A10",background:"#FFAFB1"},hover:{border:"#FA0A10",background:"#FFAFB1"}},{border:"#41A906",background:"#7BE141",highlight:{border:"#41A906",background:"#A1EC76"},hover:{border:"#41A906",background:"#A1EC76"}},{border:"#E129F0",background:"#EB7DF4",highlight:{border:"#E129F0",background:"#F0B3F5"},hover:{border:"#E129F0",background:"#F0B3F5"}},{border:"#7C29F0",background:"#AD85E4",highlight:{border:"#7C29F0",background:"#D3BDF0"},hover:{border:"#7C29F0",background:"#D3BDF0"}},{border:"#C37F00",background:"#FFA807",highlight:{border:"#C37F00",background:"#FFCA66"},hover:{border:"#C37F00",background:"#FFCA66"}},{border:"#4220FB",background:"#6E6EFD",highlight:{border:"#4220FB",background:"#9B9BFD"},hover:{border:"#4220FB",background:"#9B9BFD"}},{border:"#FD5A77",background:"#FFC0CB",highlight:{border:"#FD5A77",background:"#FFD1D9"},hover:{border:"#FD5A77",background:"#FFD1D9"}},{border:"#4AD63A",background:"#C2FABC",highlight:{border:"#4AD63A",background:"#E6FFE3"},hover:{border:"#4AD63A",background:"#E6FFE3"}}],s.prototype.clear=function(){this.groups={},this.groups.length=function(){var t=0;for(var e in this)this.hasOwnProperty(e)&&t++;return t}},s.prototype.get=function(t){var e=this.groups[t];if(void 0==e){var i=this.defaultIndex%s.DEFAULT.length;this.defaultIndex++,e={},e.color=s.DEFAULT[i],this.groups[t]=e}return e},s.prototype.add=function(t,e){return this.groups[t]=e,e},t.exports=s},function(t){function e(){this.images={},this.imageBroken={},this.callback=void 0}e.prototype.setOnloadCallback=function(t){this.callback=t},e.prototype.load=function(t,e){var i=this.images[t];if(void 0===i){var s=this;i=new Image,i.onload=function(){0==this.width&&(document.body.appendChild(this),this.width=this.offsetWidth,this.height=this.offsetHeight,document.body.removeChild(this)),s.callback&&(s.images[t]=i,s.callback(this))},i.onerror=function(){void 0===e?(console.error("Could not load image:",t),delete this.src,s.callback&&s.callback(this)):s.imageBroken[t]===!0?(console.error("Could not load brokenImage:",e),delete this.src,s.callback&&s.callback(this)):(this.src=e,s.imageBroken[t]=!0)},i.src=t}return i},t.exports=e},function(t,e,i){function s(t,e,i,s){var n=o.selectiveBridgeObject(["nodes"],s);this.options=n.nodes,this.selected=!1,this.hover=!1,this.edges=[],this.dynamicEdges=[],this.reroutedEdges={},this.fontDrawThreshold=3,this.id=void 0,this.allowedToMoveX=!1,this.allowedToMoveY=!1,this.xFixed=!1,this.yFixed=!1,this.horizontalAlignLeft=!0,this.verticalAlignTop=!0,this.baseRadiusValue=s.nodes.radius,this.radiusFixed=!1,this.level=-1,this.preassignedLevel=!1,this.hierarchyEnumerated=!1,this.labelDimensions={top:0,left:0,width:0,height:0,yLine:0},this.boundingBox={top:0,left:0,right:0,bottom:0},this.imagelist=e,this.grouplist=i,this.fx=0,this.fy=0,this.vx=0,this.vy=0,this.x=null,this.y=null,this.previousState={vx:0,vy:0,x:0,y:0},this.damping=s.physics.damping,this.fixedData={x:null,y:null},this.setProperties(t,n),this.resetCluster(),this.dynamicEdgesLength=0,this.clusterSession=0,this.clusterSizeWidthFactor=s.clustering.nodeScaling.width,this.clusterSizeHeightFactor=s.clustering.nodeScaling.height,this.clusterSizeRadiusFactor=s.clustering.nodeScaling.radius,this.maxNodeSizeIncrements=s.clustering.maxNodeSizeIncrements,this.growthIndicator=0,this.networkScaleInv=1,this.networkScale=1,this.canvasTopLeft={x:-300,y:-300},this.canvasBottomRight={x:300,y:300},this.parentEdgeId=null}var o=i(1);s.prototype.revertPosition=function(){this.x=this.previousState.x,this.y=this.previousState.y,this.vx=this.previousState.vx,this.vy=this.previousState.vy},s.prototype.resetCluster=function(){this.formationScale=void 0,this.clusterSize=1,this.containedNodes={},this.containedEdges={},this.clusterSessions=[]},s.prototype.attachEdge=function(t){-1==this.edges.indexOf(t)&&this.edges.push(t),-1==this.dynamicEdges.indexOf(t)&&this.dynamicEdges.push(t),this.dynamicEdgesLength=this.dynamicEdges.length},s.prototype.detachEdge=function(t){var e=this.edges.indexOf(t);-1!=e&&this.edges.splice(e,1),e=this.dynamicEdges.indexOf(t),-1!=e&&this.dynamicEdges.splice(e,1),this.dynamicEdgesLength=this.dynamicEdges.length},s.prototype.setProperties=function(t,e){if(t){var i=["borderWidth","borderWidthSelected","shape","image","brokenImage","radius","fontColor","fontSize","fontFace","fontFill","fontStrokeWidth","fontStrokeColor","group","mass"];if(o.selectiveDeepExtend(i,this.options,t),void 0!==t.id&&(this.id=t.id),void 0!==t.label&&(this.label=t.label,this.originalLabel=t.label),void 0!==t.title&&(this.title=t.title),void 0!==t.x&&(this.x=t.x),void 0!==t.y&&(this.y=t.y),void 0!==t.value&&(this.value=t.value),void 0!==t.level&&(this.level=t.level,this.preassignedLevel=!0),void 0!==t.horizontalAlignLeft&&(this.horizontalAlignLeft=t.horizontalAlignLeft),void 0!==t.verticalAlignTop&&(this.verticalAlignTop=t.verticalAlignTop),void 0!==t.triggerFunction&&(this.triggerFunction=t.triggerFunction),void 0===this.id)throw"Node must have an id";if("number"==typeof this.options.group||"string"==typeof this.options.group&&""!=this.options.group){var s=this.grouplist.get(this.options.group);o.deepExtend(this.options,s),this.options.color=o.parseColor(this.options.color)}if(void 0!==t.radius&&(this.baseRadiusValue=this.options.radius),void 0!==t.color&&(this.options.color=o.parseColor(t.color)),void 0!==this.options.image&&""!=this.options.image){if(!this.imagelist)throw"No imagelist provided";this.imageObj=this.imagelist.load(this.options.image,this.options.brokenImage)}switch(void 0!==t.allowedToMoveX?(this.xFixed=!t.allowedToMoveX,this.allowedToMoveX=t.allowedToMoveX):void 0!==t.x&&0==this.allowedToMoveX&&(this.xFixed=!0),void 0!==t.allowedToMoveY?(this.yFixed=!t.allowedToMoveY,this.allowedToMoveY=t.allowedToMoveY):void 0!==t.y&&0==this.allowedToMoveY&&(this.yFixed=!0),this.radiusFixed=this.radiusFixed||void 0!==t.radius,("image"===this.options.shape||"circularImage"===this.options.shape)&&(this.options.radiusMin=e.nodes.widthMin,this.options.radiusMax=e.nodes.widthMax),this.options.shape){case"database":this.draw=this._drawDatabase,this.resize=this._resizeDatabase;break;case"box":this.draw=this._drawBox,this.resize=this._resizeBox;break;case"circle":this.draw=this._drawCircle,this.resize=this._resizeCircle;break;case"ellipse":this.draw=this._drawEllipse,this.resize=this._resizeEllipse;break;case"image":this.draw=this._drawImage,this.resize=this._resizeImage;break;case"circularImage":this.draw=this._drawCircularImage,this.resize=this._resizeCircularImage;break;case"text":this.draw=this._drawText,this.resize=this._resizeText;break;case"dot":this.draw=this._drawDot,this.resize=this._resizeShape;break;case"square":this.draw=this._drawSquare,this.resize=this._resizeShape;break;case"triangle":this.draw=this._drawTriangle,this.resize=this._resizeShape;break;case"triangleDown":this.draw=this._drawTriangleDown,this.resize=this._resizeShape;break;case"star":this.draw=this._drawStar,this.resize=this._resizeShape;break;default:this.draw=this._drawEllipse,this.resize=this._resizeEllipse}this._reset()}},s.prototype.select=function(){this.selected=!0,this._reset()},s.prototype.unselect=function(){this.selected=!1,this._reset()},s.prototype.clearSizeCache=function(){this._reset()},s.prototype._reset=function(){this.width=void 0,this.height=void 0},s.prototype.getTitle=function(){return"function"==typeof this.title?this.title():this.title},s.prototype.distanceToBorder=function(t,e){var i=1;switch(this.width||this.resize(t),this.options.shape){case"circle":case"dot":return this.options.radius+i;case"ellipse":var s=this.width/2,o=this.height/2,n=Math.sin(e)*s,r=Math.cos(e)*o;return s*o/Math.sqrt(n*n+r*r);case"box":case"image":case"text":default:return this.width?Math.min(Math.abs(this.width/2/Math.cos(e)),Math.abs(this.height/2/Math.sin(e)))+i:0}},s.prototype._setForce=function(t,e){this.fx=t,this.fy=e},s.prototype._addForce=function(t,e){this.fx+=t,this.fy+=e},s.prototype.storeState=function(){this.previousState.x=this.x,this.previousState.y=this.y,this.previousState.vx=this.vx,this.previousState.vy=this.vy},s.prototype.discreteStep=function(t){if(this.storeState(),this.xFixed)this.fx=0,this.vx=0;else{var e=this.damping*this.vx,i=(this.fx-e)/this.options.mass;this.vx+=i*t,this.x+=this.vx*t}if(this.yFixed)this.fy=0,this.vy=0;else{var s=this.damping*this.vy,o=(this.fy-s)/this.options.mass;this.vy+=o*t,this.y+=this.vy*t}},s.prototype.discreteStepLimited=function(t,e){if(this.storeState(),this.xFixed)this.fx=0,this.vx=0;else{var i=this.damping*this.vx,s=(this.fx-i)/this.options.mass;this.vx+=s*t,this.vx=Math.abs(this.vx)>e?this.vx>0?e:-e:this.vx,this.x+=this.vx*t}if(this.yFixed)this.fy=0,this.vy=0;else{var o=this.damping*this.vy,n=(this.fy-o)/this.options.mass;this.vy+=n*t,this.vy=Math.abs(this.vy)>e?this.vy>0?e:-e:this.vy,this.y+=this.vy*t}},s.prototype.isFixed=function(){return this.xFixed&&this.yFixed},s.prototype.isMoving=function(t){var e=Math.sqrt(Math.pow(this.vx,2)+Math.pow(this.vy,2));return e>t},s.prototype.isSelected=function(){return this.selected},s.prototype.getValue=function(){return this.value},s.prototype.getDistance=function(t,e){var i=this.x-t,s=this.y-e;return Math.sqrt(i*i+s*s)},s.prototype.setValueRange=function(t,e){if(!this.radiusFixed&&void 0!==this.value)if(e==t)this.options.radius=(this.options.radiusMin+this.options.radiusMax)/2;else{var i=(this.options.radiusMax-this.options.radiusMin)/(e-t);this.options.radius=(this.value-t)*i+this.options.radiusMin}this.baseRadiusValue=this.options.radius},s.prototype.draw=function(){throw"Draw method not initialized for node"},s.prototype.resize=function(){throw"Resize method not initialized for node"},s.prototype.isOverlappingWith=function(t){return this.leftt.left&&this.topt.top},s.prototype._resizeImage=function(){if(!this.width||!this.height){var t,e;if(this.value){this.options.radius=this.baseRadiusValue;var i=this.imageObj.height/this.imageObj.width;void 0!==i?(t=this.options.radius||this.imageObj.width,e=this.options.radius*i||this.imageObj.height):(t=0,e=0)}else t=this.imageObj.width,e=this.imageObj.height;this.width=t,this.height=e,this.growthIndicator=0,this.width>0&&this.height>0&&(this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-t)}},s.prototype._drawImageAtPosition=function(t){if(0!=this.imageObj.width){if(this.clusterSize>1){var e=this.clusterSize>1?10:0;e*=this.networkScaleInv,e=Math.min(.2*this.width,e),t.globalAlpha=.5,t.drawImage(this.imageObj,this.left-e,this.top-e,this.width+2*e,this.height+2*e)}t.globalAlpha=1,t.drawImage(this.imageObj,this.left,this.top,this.width,this.height)}},s.prototype._drawImageLabel=function(t){var e,i=0;if(this.height){i=this.height/2;var s=this.getTextSize(t);s.lineCount>=1&&(i+=s.height/2,i+=3)}e=this.y+i,this._label(t,this.label,this.x,e,void 0)},s.prototype._drawImage=function(t){this._resizeImage(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2,this._drawImageAtPosition(t),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._drawImageLabel(t),this.boundingBox.left=Math.min(this.boundingBox.left,this.labelDimensions.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelDimensions.left+this.labelDimensions.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelDimensions.height)},s.prototype._resizeCircularImage=function(t){if(this.imageObj.src&&this.imageObj.width&&this.imageObj.height)this._swapToImageResizeWhenImageLoaded&&(this.width=0,this.height=0,delete this._swapToImageResizeWhenImageLoaded),this._resizeImage(t);else if(!this.width){var e=2*this.options.radius;this.width=e,this.height=e,this.options.radius+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.options.radius-.5*e,this._swapToImageResizeWhenImageLoaded=!0}},s.prototype._drawCircularImage=function(t){this._resizeCircularImage(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var e=this.left+this.width/2,i=this.top+this.height/2,s=Math.abs(this.height/2);this._drawRawCircle(t,e,i,s),t.save(),t.circle(this.x,this.y,s),t.stroke(),t.clip(),this._drawImageAtPosition(t),t.restore(),this.boundingBox.top=this.y-this.options.radius,this.boundingBox.left=this.x-this.options.radius,this.boundingBox.right=this.x+this.options.radius,this.boundingBox.bottom=this.y+this.options.radius,this._drawImageLabel(t),this.boundingBox.left=Math.min(this.boundingBox.left,this.labelDimensions.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelDimensions.left+this.labelDimensions.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelDimensions.height)},s.prototype._resizeBox=function(t){if(!this.width){var e=5,i=this.getTextSize(t);this.width=i.width+2*e,this.height=i.height+2*e,this.width+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.growthIndicator=this.width-(i.width+2*e)}},s.prototype._drawBox=function(t){this._resizeBox(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var e=2.5,i=this.options.borderWidth,s=this.options.borderWidthSelected||2*this.options.borderWidth;t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.roundRect(this.left-2*t.lineWidth,this.top-2*t.lineWidth,this.width+4*t.lineWidth,this.height+4*t.lineWidth,this.options.radius),t.stroke()),t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.roundRect(this.left,this.top,this.width,this.height,this.options.radius),t.fill(),t.stroke(),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._label(t,this.label,this.x,this.y)},s.prototype._resizeDatabase=function(t){if(!this.width){var e=5,i=this.getTextSize(t),s=i.width+2*e;this.width=s,this.height=s,this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-s}},s.prototype._drawDatabase=function(t){this._resizeDatabase(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var e=2.5,i=this.options.borderWidth,s=this.options.borderWidthSelected||2*this.options.borderWidth;t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.database(this.x-this.width/2-2*t.lineWidth,this.y-.5*this.height-2*t.lineWidth,this.width+4*t.lineWidth,this.height+4*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.database(this.x-this.width/2,this.y-.5*this.height,this.width,this.height),t.fill(),t.stroke(),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._label(t,this.label,this.x,this.y) +},s.prototype._resizeCircle=function(t){if(!this.width){var e=5,i=this.getTextSize(t),s=Math.max(i.width,i.height)+2*e;this.options.radius=s/2,this.width=s,this.height=s,this.options.radius+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.options.radius-.5*s}},s.prototype._drawRawCircle=function(t,e,i,s){var o=2.5,n=this.options.borderWidth,r=this.options.borderWidthSelected||2*this.options.borderWidth;t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?r:n)+(this.clusterSize>1?o:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.circle(e,i,s+2*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?r:n)+(this.clusterSize>1?o:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.circle(this.x,this.y,s),t.fill(),t.stroke()},s.prototype._drawCircle=function(t){this._resizeCircle(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2,this._drawRawCircle(t,this.x,this.y,this.options.radius),this.boundingBox.top=this.y-this.options.radius,this.boundingBox.left=this.x-this.options.radius,this.boundingBox.right=this.x+this.options.radius,this.boundingBox.bottom=this.y+this.options.radius,this._label(t,this.label,this.x,this.y)},s.prototype._resizeEllipse=function(t){if(!this.width){var e=this.getTextSize(t);this.width=1.5*e.width,this.height=2*e.height,this.width1&&(t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.ellipse(this.left-2*t.lineWidth,this.top-2*t.lineWidth,this.width+4*t.lineWidth,this.height+4*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.ellipse(this.left,this.top,this.width,this.height),t.fill(),t.stroke(),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._label(t,this.label,this.x,this.y)},s.prototype._drawDot=function(t){this._drawShape(t,"circle")},s.prototype._drawTriangle=function(t){this._drawShape(t,"triangle")},s.prototype._drawTriangleDown=function(t){this._drawShape(t,"triangleDown")},s.prototype._drawSquare=function(t){this._drawShape(t,"square")},s.prototype._drawStar=function(t){this._drawShape(t,"star")},s.prototype._resizeShape=function(){if(!this.width){this.options.radius=this.baseRadiusValue;var t=2*this.options.radius;this.width=t,this.height=t,this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-t}},s.prototype._drawShape=function(t,e){this._resizeShape(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var i=2.5,s=this.options.borderWidth,o=this.options.borderWidthSelected||2*this.options.borderWidth,n=2;switch(e){case"dot":n=2;break;case"square":n=2;break;case"triangle":n=3;break;case"triangleDown":n=3;break;case"star":n=4}t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?o:s)+(this.clusterSize>1?i:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t[e](this.x,this.y,this.options.radius+n*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?o:s)+(this.clusterSize>1?i:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t[e](this.x,this.y,this.options.radius),t.fill(),t.stroke(),this.boundingBox.top=this.y-this.options.radius,this.boundingBox.left=this.x-this.options.radius,this.boundingBox.right=this.x+this.options.radius,this.boundingBox.bottom=this.y+this.options.radius,this.label&&(this._label(t,this.label,this.x,this.y+this.height/2,void 0,"hanging",!0),this.boundingBox.left=Math.min(this.boundingBox.left,this.labelDimensions.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelDimensions.left+this.labelDimensions.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelDimensions.height))},s.prototype._resizeText=function(t){if(!this.width){var e=5,i=this.getTextSize(t);this.width=i.width+2*e,this.height=i.height+2*e,this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-(i.width+2*e)}},s.prototype._drawText=function(t){this._resizeText(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2,this._label(t,this.label,this.x,this.y),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height},s.prototype._label=function(t,e,i,s,o,n,r){if(e&&Number(this.options.fontSize)*this.networkScale>this.fontDrawThreshold){t.font=(this.selected?"bold ":"")+this.options.fontSize+"px "+this.options.fontFace;var a=e.split("\n"),h=a.length,d=Number(this.options.fontSize),l=s+(1-h)/2*d;1==r&&(l=s+(1-h)/(2*d));for(var c=t.measureText(a[0]).width,p=1;h>p;p++){var u=t.measureText(a[p]).width;c=u>c?u:c}var m=this.options.fontSize*h,f=i-c/2,g=s-m/2;"hanging"==n&&(g+=.5*d,g+=4,l+=4),this.labelDimensions={top:g,left:f,width:c,height:m,yLine:l},void 0!==this.options.fontFill&&null!==this.options.fontFill&&"none"!==this.options.fontFill&&(t.fillStyle=this.options.fontFill,t.fillRect(f,g,c,m)),t.fillStyle=this.options.fontColor||"black",t.textAlign=o||"center",t.textBaseline=n||"middle",this.options.fontStrokeWidth>0&&(t.lineWidth=this.options.fontStrokeWidth,t.strokeStyle=this.options.fontStrokeColor,t.lineJoin="round");for(var p=0;h>p;p++)this.options.fontStrokeWidth&&t.strokeText(a[p],i,l),t.fillText(a[p],i,l),l+=d}},s.prototype.getTextSize=function(t){if(void 0!==this.label){t.font=(this.selected?"bold ":"")+this.options.fontSize+"px "+this.options.fontFace;for(var e=this.label.split("\n"),i=(Number(this.options.fontSize)+4)*e.length,s=0,o=0,n=e.length;n>o;o++)s=Math.max(s,t.measureText(e[o]).width);return{width:s,height:i,lineCount:e.length}}return{width:0,height:0,lineCount:0}},s.prototype.inArea=function(){return void 0!==this.width?this.x+this.width*this.networkScaleInv>=this.canvasTopLeft.x&&this.x-this.width*this.networkScaleInv=this.canvasTopLeft.y&&this.y-this.height*this.networkScaleInv=this.canvasTopLeft.x&&this.x=this.canvasTopLeft.y&&this.ys&&(n=s-e-this.padding),no&&(r=o-i-this.padding),ri;i++)if(e.id===r.nodes[i].id){o=r.nodes[i];break}for(o||(o={id:e.id},t.node&&(o.attr=a(o.attr,t.node))),i=n.length-1;i>=0;i--){var h=n[i];h.nodes||(h.nodes=[]),-1==h.nodes.indexOf(o)&&h.nodes.push(o)}e.attr&&(o.attr=a(o.attr,e.attr))}function l(t,e){if(t.edges||(t.edges=[]),t.edges.push(e),t.edge){var i=a({},t.edge);e.attr=a(i,e.attr)}}function c(t,e,i,s,o){var n={from:e,to:i,type:s};return t.edge&&(n.attr=a({},t.edge)),n.attr=a(n.attr||{},o),n}function p(){for(N=D.NULL,k="";" "==E||" "==E||"\n"==E||"\r"==E;)o();do{var t=!1;if("#"==E){for(var e=O-1;" "==T.charAt(e)||" "==T.charAt(e);)e--;if("\n"==T.charAt(e)||""==T.charAt(e)){for(;""!=E&&"\n"!=E;)o();t=!0}}if("/"==E&&"/"==n()){for(;""!=E&&"\n"!=E;)o();t=!0}if("/"==E&&"*"==n()){for(;""!=E;){if("*"==E&&"/"==n()){o(),o();break}o()}t=!0}for(;" "==E||" "==E||"\n"==E||"\r"==E;)o()}while(t);if(""==E)return void(N=D.DELIMITER);var i=E+n();if(C[i])return N=D.DELIMITER,k=i,o(),void o();if(C[E])return N=D.DELIMITER,k=E,void o();if(r(E)||"-"==E){for(k+=E,o();r(E);)k+=E,o();return"false"==k?k=!1:"true"==k?k=!0:isNaN(Number(k))||(k=Number(k)),void(N=D.IDENTIFIER)}if('"'==E){for(o();""!=E&&('"'!=E||'"'==E&&'"'==n());)k+=E,'"'==E&&o(),o();if('"'!=E)throw x('End of string " expected');return o(),void(N=D.IDENTIFIER)}for(N=D.UNKNOWN;""!=E;)k+=E,o();throw new SyntaxError('Syntax error in part "'+w(k,30)+'"')}function u(){var t={};if(s(),p(),"strict"==k&&(t.strict=!0,p()),("graph"==k||"digraph"==k)&&(t.type=k,p()),N==D.IDENTIFIER&&(t.id=k,p()),"{"!=k)throw x("Angle bracket { expected");if(p(),m(t),"}"!=k)throw x("Angle bracket } expected");if(p(),""!==k)throw x("End of file expected");return p(),delete t.node,delete t.edge,delete t.graph,t}function m(t){for(;""!==k&&"}"!=k;)f(t),";"==k&&p()}function f(t){var e=g(t);if(e)return void b(t,e);var i=v(t);if(!i){if(N!=D.IDENTIFIER)throw x("Identifier expected");var s=k;if(p(),"="==k){if(p(),N!=D.IDENTIFIER)throw x("Identifier expected");t[s]=k,p()}else y(t,s)}}function g(t){var e=null;if("subgraph"==k&&(e={},e.type="subgraph",p(),N==D.IDENTIFIER&&(e.id=k,p())),"{"==k){if(p(),e||(e={}),e.parent=t,e.node=t.node,e.edge=t.edge,e.graph=t.graph,m(e),"}"!=k)throw x("Angle bracket } expected");p(),delete e.node,delete e.edge,delete e.graph,delete e.parent,t.subgraphs||(t.subgraphs=[]),t.subgraphs.push(e)}return e}function v(t){return"node"==k?(p(),t.node=_(),"node"):"edge"==k?(p(),t.edge=_(),"edge"):"graph"==k?(p(),t.graph=_(),"graph"):null}function y(t,e){var i={id:e},s=_();s&&(i.attr=s),d(t,i),b(t,e)}function b(t,e){for(;"->"==k||"--"==k;){var i,s=k;p();var o=g(t);if(o)i=o;else{if(N!=D.IDENTIFIER)throw x("Identifier or subgraph expected");i=k,d(t,{id:i}),p()}var n=_(),r=c(t,e,i,s,n);l(t,r),e=i}}function _(){for(var t=null;"["==k;){for(p(),t={};""!==k&&"]"!=k;){if(N!=D.IDENTIFIER)throw x("Attribute name expected");var e=k;if(p(),"="!=k)throw x("Equal sign = expected");if(p(),N!=D.IDENTIFIER)throw x("Attribute value expected");var i=k;h(t,e,i),p(),","==k&&p()}if("]"!=k)throw x("Bracket ] expected");p()}return t}function x(t){return new SyntaxError(t+', got "'+w(k,30)+'" (char '+O+")")}function w(t,e){return t.length<=e?t:t.substr(0,27)+"..."}function S(t,e,i){Array.isArray(t)?t.forEach(function(t){Array.isArray(e)?e.forEach(function(e){i(t,e)}):i(t,e)}):Array.isArray(e)?e.forEach(function(e){i(t,e)}):i(t,e)}function M(t){var e=i(t),s={nodes:[],edges:[],options:{}};if(e.nodes&&e.nodes.forEach(function(t){var e={id:t.id,label:String(t.label||t.id)};a(e,t.attr),e.image&&(e.shape="image"),s.nodes.push(e)}),e.edges){var o=function(t){var e={from:t.from,to:t.to};return a(e,t.attr),e.style="->"==t.type?"arrow":"line",e};e.edges.forEach(function(t){var e,i;e=t.from instanceof Object?t.from.nodes:{id:t.from},i=t.to instanceof Object?t.to.nodes:{id:t.to},t.from instanceof Object&&t.from.edges&&t.from.edges.forEach(function(t){var e=o(t);s.edges.push(e)}),S(e,i,function(e,i){var n=c(s,e.id,i.id,t.type,t.attr),r=o(n);s.edges.push(r)}),t.to instanceof Object&&t.to.edges&&t.to.edges.forEach(function(t){var e=o(t);s.edges.push(e)})})}return e.attr&&(s.options=e.attr),s}var D={NULL:0,DELIMITER:1,IDENTIFIER:2,UNKNOWN:3},C={"{":!0,"}":!0,"[":!0,"]":!0,";":!0,"=":!0,",":!0,"->":!0,"--":!0},T="",O=0,E="",k="",N=D.NULL,I=/[a-zA-Z_0-9.:#]/;e.parseDOT=i,e.DOTToGraph=M},function(t,e){function i(t,e){var i=[],s=[];this.options={edges:{inheritColor:!0},nodes:{allowedToMove:!1,parseColor:!1}},void 0!==e&&(this.options.nodes.allowedToMove=e.allowedToMove|!1,this.options.nodes.parseColor=e.parseColor|!1,this.options.edges.inheritColor=e.inheritColor|!0);for(var o=t.edges,n=t.nodes,r=0;r=s&&(s=864e5),e=new Date(e.valueOf()-.05*s),i=new Date(i.valueOf()+.05*s)}return{start:e,end:i}},s.prototype.setWindow=function(t,e,i){var s=i&&void 0!==i.animate?i.animate:!0;if(1==arguments.length){var o=arguments[0];this.range.setRange(o.start,o.end,s)}else this.range.setRange(t,e,s)},s.prototype.moveTo=function(t,e){var i=this.range.end-this.range.start,s=r.convert(t,"Date").valueOf(),o=s-i/2,n=s+i/2,a=e&&void 0!==e.animate?e.animate:!0;this.range.setRange(o,n,a)},s.prototype.getWindow=function(){var t=this.range.getRange();return{start:new Date(t.start),end:new Date(t.end)}},s.prototype.redraw=function(){var t=!1,e=this.options,i=this.props,s=this.dom;if(s){h.updateHiddenDates(this.body,this.options.hiddenDates),"top"==e.orientation?(r.addClassName(s.root,"top"),r.removeClassName(s.root,"bottom")):(r.removeClassName(s.root,"top"),r.addClassName(s.root,"bottom")),s.root.style.maxHeight=r.option.asSize(e.maxHeight,""),s.root.style.minHeight=r.option.asSize(e.minHeight,""),s.root.style.width=r.option.asSize(e.width,""),i.border.left=(s.centerContainer.offsetWidth-s.centerContainer.clientWidth)/2,i.border.right=i.border.left,i.border.top=(s.centerContainer.offsetHeight-s.centerContainer.clientHeight)/2,i.border.bottom=i.border.top;var o=s.root.offsetHeight-s.root.clientHeight,n=s.root.offsetWidth-s.root.clientWidth;0===s.centerContainer.clientHeight&&(i.border.left=i.border.top,i.border.right=i.border.left),0===s.root.clientHeight&&(n=o),i.center.height=s.center.offsetHeight,i.left.height=s.left.offsetHeight,i.right.height=s.right.offsetHeight,i.top.height=s.top.clientHeight||-i.border.top,i.bottom.height=s.bottom.clientHeight||-i.border.bottom;var a=Math.max(i.left.height,i.center.height,i.right.height),d=i.top.height+a+i.bottom.height+o+i.border.top+i.border.bottom;s.root.style.height=r.option.asSize(e.height,d+"px"),i.root.height=s.root.offsetHeight,i.background.height=i.root.height-o;var l=i.root.height-i.top.height-i.bottom.height-o;i.centerContainer.height=l,i.leftContainer.height=l,i.rightContainer.height=i.leftContainer.height,i.root.width=s.root.offsetWidth,i.background.width=i.root.width-n,i.left.width=s.leftContainer.clientWidth||-i.border.left,i.leftContainer.width=i.left.width,i.right.width=s.rightContainer.clientWidth||-i.border.right,i.rightContainer.width=i.right.width;var c=i.root.width-i.left.width-i.right.width-n;i.center.width=c,i.centerContainer.width=c,i.top.width=c,i.bottom.width=c,s.background.style.height=i.background.height+"px",s.backgroundVertical.style.height=i.background.height+"px",s.backgroundHorizontal.style.height=i.centerContainer.height+"px",s.centerContainer.style.height=i.centerContainer.height+"px",s.leftContainer.style.height=i.leftContainer.height+"px",s.rightContainer.style.height=i.rightContainer.height+"px",s.background.style.width=i.background.width+"px",s.backgroundVertical.style.width=i.centerContainer.width+"px",s.backgroundHorizontal.style.width=i.background.width+"px",s.centerContainer.style.width=i.center.width+"px",s.top.style.width=i.top.width+"px",s.bottom.style.width=i.bottom.width+"px",s.background.style.left="0",s.background.style.top="0",s.backgroundVertical.style.left=i.left.width+i.border.left+"px",s.backgroundVertical.style.top="0",s.backgroundHorizontal.style.left="0",s.backgroundHorizontal.style.top=i.top.height+"px",s.centerContainer.style.left=i.left.width+"px",s.centerContainer.style.top=i.top.height+"px",s.leftContainer.style.left="0",s.leftContainer.style.top=i.top.height+"px",s.rightContainer.style.left=i.left.width+i.center.width+"px",s.rightContainer.style.top=i.top.height+"px",s.top.style.left=i.left.width+"px",s.top.style.top="0",s.bottom.style.left=i.left.width+"px",s.bottom.style.top=i.top.height+i.centerContainer.height+"px",this._updateScrollTop();var p=this.props.scrollTop;"bottom"==e.orientation&&(p+=Math.max(this.props.centerContainer.height-this.props.center.height-this.props.border.top-this.props.border.bottom,0)),s.center.style.left="0",s.center.style.top=p+"px",s.left.style.left="0",s.left.style.top=p+"px",s.right.style.left="0",s.right.style.top=p+"px";var u=0==this.props.scrollTop?"hidden":"",m=this.props.scrollTop==this.props.scrollTopMin?"hidden":"";if(s.shadowTop.style.visibility=u,s.shadowBottom.style.visibility=m,s.shadowTopLeft.style.visibility=u,s.shadowBottomLeft.style.visibility=m,s.shadowTopRight.style.visibility=u,s.shadowBottomRight.style.visibility=m,this.components.forEach(function(e){t=e.redraw()||t}),t){var f=3;this.redrawCount0&&(this.props.scrollTop=0),this.props.scrollTops;s++){var o=s%2===0?1.3*i:.5*i;this.lineTo(t+o*Math.sin(2*s*Math.PI/10),e-o*Math.cos(2*s*Math.PI/10))}this.closePath()},CanvasRenderingContext2D.prototype.roundRect=function(t,e,i,s,o){var n=Math.PI/180;0>i-2*o&&(o=i/2),0>s-2*o&&(o=s/2),this.beginPath(),this.moveTo(t+o,e),this.lineTo(t+i-o,e),this.arc(t+i-o,e+o,o,270*n,360*n,!1),this.lineTo(t+i,e+s-o),this.arc(t+i-o,e+s-o,o,0,90*n,!1),this.lineTo(t+o,e+s),this.arc(t+o,e+s-o,o,90*n,180*n,!1),this.lineTo(t,e+o),this.arc(t+o,e+o,o,180*n,270*n,!1)},CanvasRenderingContext2D.prototype.ellipse=function(t,e,i,s){var o=.5522848,n=i/2*o,r=s/2*o,a=t+i,h=e+s,d=t+i/2,l=e+s/2; +this.beginPath(),this.moveTo(t,l),this.bezierCurveTo(t,l-r,d-n,e,d,e),this.bezierCurveTo(d+n,e,a,l-r,a,l),this.bezierCurveTo(a,l+r,d+n,h,d,h),this.bezierCurveTo(d-n,h,t,l+r,t,l)},CanvasRenderingContext2D.prototype.database=function(t,e,i,s){var o=1/3,n=i,r=s*o,a=.5522848,h=n/2*a,d=r/2*a,l=t+n,c=e+r,p=t+n/2,u=e+r/2,m=e+(s-r/2),f=e+s;this.beginPath(),this.moveTo(l,u),this.bezierCurveTo(l,u+d,p+h,c,p,c),this.bezierCurveTo(p-h,c,t,u+d,t,u),this.bezierCurveTo(t,u-d,p-h,e,p,e),this.bezierCurveTo(p+h,e,l,u-d,l,u),this.lineTo(l,m),this.bezierCurveTo(l,m+d,p+h,f,p,f),this.bezierCurveTo(p-h,f,t,m+d,t,m),this.lineTo(t,u)},CanvasRenderingContext2D.prototype.arrow=function(t,e,i,s){var o=t-s*Math.cos(i),n=e-s*Math.sin(i),r=t-.9*s*Math.cos(i),a=e-.9*s*Math.sin(i),h=o+s/3*Math.cos(i+.5*Math.PI),d=n+s/3*Math.sin(i+.5*Math.PI),l=o+s/3*Math.cos(i-.5*Math.PI),c=n+s/3*Math.sin(i-.5*Math.PI);this.beginPath(),this.moveTo(t,e),this.lineTo(h,d),this.lineTo(r,a),this.lineTo(l,c),this.closePath()},CanvasRenderingContext2D.prototype.dashedLine=function(t,e,i,s,o){o||(o=[10,5]),0==p&&(p=.001);var n=o.length;this.moveTo(t,e);for(var r=i-t,a=s-e,h=a/r,d=Math.sqrt(r*r+a*a),l=0,c=!0;d>=.1;){var p=o[l++%n];p>d&&(p=d);var u=Math.sqrt(p*p/(1+h*h));0>r&&(u=-u),t+=u,e+=h*u,this[c?"lineTo":"moveTo"](t,e),d-=p,c=!c}})},function(t,e,i){function s(t,e){this.groupId=t,this.options=e}var o=i(2),n=i(53);s.prototype.getYRange=function(t){for(var e=t[0].y,i=t[0].y,s=0;st[s].y?t[s].y:e,i=i0){var r,a,h=Number(i.svg.style.height.replace("px",""));if(r=o.getSVGElement("path",i.svgElements,i.svg),r.setAttributeNS(null,"class",e.className),void 0!==e.style&&r.setAttributeNS(null,"style",e.style),a=1==e.options.catmullRom.enabled?s._catmullRom(t,e):s._linear(t),1==e.options.shaded.enabled){var d,l=o.getSVGElement("path",i.svgElements,i.svg);d="top"==e.options.shaded.orientation?"M"+t[0].x+",0 "+a+"L"+t[t.length-1].x+",0":"M"+t[0].x+","+h+" "+a+"L"+t[t.length-1].x+","+h,l.setAttributeNS(null,"class",e.className+" fill"),void 0!==e.options.shaded.style&&l.setAttributeNS(null,"style",e.options.shaded.style),l.setAttributeNS(null,"d",d)}r.setAttributeNS(null,"d","M"+a),1==e.options.drawPoints.enabled&&n.draw(t,e,i)}},s._catmullRomUniform=function(t){for(var e,i,s,o,n,r,a=Math.round(t[0].x)+","+Math.round(t[0].y)+" ",h=1/6,d=t.length,l=0;d-1>l;l++)e=0==l?t[0]:t[l-1],i=t[l],s=t[l+1],o=d>l+2?t[l+2]:s,n={x:(-e.x+6*i.x+s.x)*h,y:(-e.y+6*i.y+s.y)*h},r={x:(i.x+6*s.x-o.x)*h,y:(i.y+6*s.y-o.y)*h},a+="C"+n.x+","+n.y+" "+r.x+","+r.y+" "+s.x+","+s.y+" ";return a},s._catmullRom=function(t,e){var i=e.options.catmullRom.alpha;if(0==i||void 0===i)return this._catmullRomUniform(t);for(var s,o,n,r,a,h,d,l,c,p,u,m,f,g,v,y,b,_,x,w=Math.round(t[0].x)+","+Math.round(t[0].y)+" ",S=t.length,M=0;S-1>M;M++)s=0==M?t[0]:t[M-1],o=t[M],n=t[M+1],r=S>M+2?t[M+2]:n,d=Math.sqrt(Math.pow(s.x-o.x,2)+Math.pow(s.y-o.y,2)),l=Math.sqrt(Math.pow(o.x-n.x,2)+Math.pow(o.y-n.y,2)),c=Math.sqrt(Math.pow(n.x-r.x,2)+Math.pow(n.y-r.y,2)),g=Math.pow(c,i),y=Math.pow(c,2*i),v=Math.pow(l,i),b=Math.pow(l,2*i),x=Math.pow(d,i),_=Math.pow(d,2*i),p=2*_+3*x*v+b,u=2*y+3*g*v+b,m=3*x*(x+v),m>0&&(m=1/m),f=3*g*(g+v),f>0&&(f=1/f),a={x:(-b*s.x+p*o.x+_*n.x)*m,y:(-b*s.y+p*o.y+_*n.y)*m},h={x:(y*o.x+u*n.x-b*r.x)*f,y:(y*o.y+u*n.y-b*r.y)*f},0==a.x&&0==a.y&&(a=o),0==h.x&&0==h.y&&(h=n),w+="C"+a.x+","+a.y+" "+h.x+","+h.y+" "+n.x+","+n.y+" ";return w},s._linear=function(t){for(var e="",i=0;it[s].y?t[s].y:e,i=i0&&(n=Math.min(n,Math.abs(c[d-1].x-r))),a=s._getSafeDrawData(n,h,m);else{var g=d+(p[r].amount-p[r].resolved),v=d-(p[r].resolved+1);g0&&(n=Math.min(n,Math.abs(c[v].x-r))),a=s._getSafeDrawData(n,h,m),p[r].resolved+=1,"stack"==h.options.barChart.handleOverlap?(f=p[r].accumulated,p[r].accumulated+=h.zeroPosition-c[d].y):"sideBySide"==h.options.barChart.handleOverlap&&(a.width=a.width/p[r].amount,a.offset+=p[r].resolved*a.width-.5*a.width*(p[r].amount+1),"left"==h.options.barChart.align?a.offset-=.5*a.width:"right"==h.options.barChart.align&&(a.offset+=.5*a.width))}o.drawBar(c[d].x+a.offset,c[d].y-f,a.width,h.zeroPosition-c[d].y,h.className+" bar",i.svgElements,i.svg),1==h.options.drawPoints.enabled&&o.drawPoint(c[d].x+a.offset,c[d].y,h,i.svgElements,i.svg)}},s._getDataIntersections=function(t,e){for(var i,s=0;s0&&(i=Math.min(i,Math.abs(e[s-1].x-e[s].x))),0==i&&(void 0===t[e[s].x]&&(t[e[s].x]={amount:0,resolved:0,accumulated:0}),t[e[s].x].amount+=1)},s._getSafeDrawData=function(t,e,i){var s,o;return t0?(s=i>t?i:t,o=0,"left"==e.options.barChart.align?o-=.5*t:"right"==e.options.barChart.align&&(o+=.5*t)):(s=e.options.barChart.width,o=0,"left"==e.options.barChart.align?o-=.5*e.options.barChart.width:"right"==e.options.barChart.align&&(o+=.5*e.options.barChart.width)),{width:s,offset:o}},s.getStackedBarYRange=function(t,e,i,o,n){if(t.length>0){t.sort(function(t,e){return t.x==e.x?t.groupId-e.groupId:t.x-e.x});var r={};s._getDataIntersections(r,t),e[o]=s._getStackedBarYRange(r,t),e[o].yAxisOrientation=n,i.push(o)}},s._getStackedBarYRange=function(t,e){for(var i,s=e[0].y,o=e[0].y,n=0;ne[n].y?e[n].y:s,o=ot[r].accumulated?t[r].accumulated:s,o=ot[s].y?t[s].y:e,i=is;++s)i[s].apply(this,e)}return this},e.prototype.listeners=function(t){return this._callbacks=this._callbacks||{},this._callbacks[t]||[]},e.prototype.hasListeners=function(t){return!!this.listeners(t).length}},function(t,e,i){var s;(function(t,o){(function(n){function r(t,e,i){switch(arguments.length){case 2:return null!=t?t:e;case 3:return null!=t?t:null!=e?e:i;default:throw new Error("Implement me")}}function a(t,e){return Ie.call(t,e)}function h(){return{empty:!1,unusedTokens:[],unusedInput:[],overflow:-2,charsLeftOver:0,nullInput:!1,invalidMonth:null,invalidFormat:!1,userInvalidated:!1,iso:!1}}function d(t){Ce.suppressDeprecationWarnings===!1&&"undefined"!=typeof console&&console.warn&&console.warn("Deprecation warning: "+t)}function l(t,e){var i=!0;return b(function(){return i&&(d(t),i=!1),e.apply(this,arguments)},e)}function c(t,e){Si[t]||(d(e),Si[t]=!0)}function p(t,e){return function(i){return w(t.call(this,i),e)}}function u(t,e){return function(i){return this.localeData().ordinal(t.call(this,i),e)}}function m(t,e){var i,s,o=12*(e.year()-t.year())+(e.month()-t.month()),n=t.clone().add(o,"months");return 0>e-n?(i=t.clone().add(o-1,"months"),s=(e-n)/(n-i)):(i=t.clone().add(o+1,"months"),s=(e-n)/(i-n)),-(o+s)}function f(t,e,i){var s;return null==i?e:null!=t.meridiemHour?t.meridiemHour(e,i):null!=t.isPM?(s=t.isPM(i),s&&12>e&&(e+=12),s||12!==e||(e=0),e):e}function g(){}function v(t,e){e!==!1&&F(t),_(this,t),this._d=new Date(+t._d),Di===!1&&(Di=!0,Ce.updateOffset(this),Di=!1)}function y(t){var e=N(t),i=e.year||0,s=e.quarter||0,o=e.month||0,n=e.week||0,r=e.day||0,a=e.hour||0,h=e.minute||0,d=e.second||0,l=e.millisecond||0;this._milliseconds=+l+1e3*d+6e4*h+36e5*a,this._days=+r+7*n,this._months=+o+3*s+12*i,this._data={},this._locale=Ce.localeData(),this._bubble()}function b(t,e){for(var i in e)a(e,i)&&(t[i]=e[i]);return a(e,"toString")&&(t.toString=e.toString),a(e,"valueOf")&&(t.valueOf=e.valueOf),t}function _(t,e){var i,s,o;if("undefined"!=typeof e._isAMomentObject&&(t._isAMomentObject=e._isAMomentObject),"undefined"!=typeof e._i&&(t._i=e._i),"undefined"!=typeof e._f&&(t._f=e._f),"undefined"!=typeof e._l&&(t._l=e._l),"undefined"!=typeof e._strict&&(t._strict=e._strict),"undefined"!=typeof e._tzm&&(t._tzm=e._tzm),"undefined"!=typeof e._isUTC&&(t._isUTC=e._isUTC),"undefined"!=typeof e._offset&&(t._offset=e._offset),"undefined"!=typeof e._pf&&(t._pf=e._pf),"undefined"!=typeof e._locale&&(t._locale=e._locale),Ye.length>0)for(i in Ye)s=Ye[i],o=e[s],"undefined"!=typeof o&&(t[s]=o);return t}function x(t){return 0>t?Math.ceil(t):Math.floor(t)}function w(t,e,i){for(var s=""+Math.abs(t),o=t>=0;s.lengths;s++)(i&&t[s]!==e[s]||!i&&L(t[s])!==L(e[s]))&&r++;return r+n}function k(t){if(t){var e=t.toLowerCase().replace(/(.)s$/,"$1");t=gi[t]||vi[e]||e}return t}function N(t){var e,i,s={};for(i in t)a(t,i)&&(e=k(i),e&&(s[e]=t[i]));return s}function I(t){var e,i;if(0===t.indexOf("week"))e=7,i="day";else{if(0!==t.indexOf("month"))return;e=12,i="month"}Ce[t]=function(s,o){var r,a,h=Ce._locale[t],d=[];if("number"==typeof s&&(o=s,s=n),a=function(t){var e=Ce().utc().set(i,t);return h.call(Ce._locale,e,s||"")},null!=o)return a(o);for(r=0;e>r;r++)d.push(a(r));return d}}function L(t){var e=+t,i=0;return 0!==e&&isFinite(e)&&(i=e>=0?Math.floor(e):Math.ceil(e)),i}function z(t,e){return new Date(Date.UTC(t,e+1,0)).getUTCDate()}function P(t,e,i){return me(Ce([t,11,31+e-i]),e,i).week}function A(t){return R(t)?366:365}function R(t){return t%4===0&&t%100!==0||t%400===0}function F(t){var e;t._a&&-2===t._pf.overflow&&(e=t._a[ze]<0||t._a[ze]>11?ze:t._a[Pe]<1||t._a[Pe]>z(t._a[Le],t._a[ze])?Pe:t._a[Ae]<0||t._a[Ae]>24||24===t._a[Ae]&&(0!==t._a[Re]||0!==t._a[Fe]||0!==t._a[He])?Ae:t._a[Re]<0||t._a[Re]>59?Re:t._a[Fe]<0||t._a[Fe]>59?Fe:t._a[He]<0||t._a[He]>999?He:-1,t._pf._overflowDayOfYear&&(Le>e||e>Pe)&&(e=Pe),t._pf.overflow=e)}function H(t){return null==t._isValid&&(t._isValid=!isNaN(t._d.getTime())&&t._pf.overflow<0&&!t._pf.empty&&!t._pf.invalidMonth&&!t._pf.nullInput&&!t._pf.invalidFormat&&!t._pf.userInvalidated,t._strict&&(t._isValid=t._isValid&&0===t._pf.charsLeftOver&&0===t._pf.unusedTokens.length&&t._pf.bigHour===n)),t._isValid}function B(t){return t?t.toLowerCase().replace("_","-"):t}function Y(t){for(var e,i,s,o,n=0;n0;){if(s=W(o.slice(0,e).join("-")))return s;if(i&&i.length>=e&&E(o,i,!0)>=e-1)break;e--}n++}return null}function W(t){var e=null;if(!Be[t]&&We)try{e=Ce.locale(),!function(){var t=new Error('Cannot find module "./locale"');throw t.code="MODULE_NOT_FOUND",t}(),Ce.locale(e)}catch(i){}return Be[t]}function G(t,e){var i,s;return e._isUTC?(i=e.clone(),s=(Ce.isMoment(t)||O(t)?+t:+Ce(t))-+i,i._d.setTime(+i._d+s),Ce.updateOffset(i,!1),i):Ce(t).local()}function j(t){return t.match(/\[[\s\S]/)?t.replace(/^\[|\]$/g,""):t.replace(/\\/g,"")}function U(t){var e,i,s=t.match(Ve);for(e=0,i=s.length;i>e;e++)s[e]=wi[s[e]]?wi[s[e]]:j(s[e]);return function(o){var n="";for(e=0;i>e;e++)n+=s[e]instanceof Function?s[e].call(o,t):s[e];return n}}function V(t,e){return t.isValid()?(e=X(e,t.localeData()),yi[e]||(yi[e]=U(e)),yi[e](t)):t.localeData().invalidDate()}function X(t,e){function i(t){return e.longDateFormat(t)||t}var s=5;for(Xe.lastIndex=0;s>=0&&Xe.test(t);)t=t.replace(Xe,i),Xe.lastIndex=0,s-=1;return t}function q(t,e){var i,s=e._strict;switch(t){case"Q":return oi;case"DDDD":return ri;case"YYYY":case"GGGG":case"gggg":return s?ai:Qe;case"Y":case"G":case"g":return di;case"YYYYYY":case"YYYYY":case"GGGGG":case"ggggg":return s?hi:Ke;case"S":if(s)return oi;case"SS":if(s)return ni;case"SSS":if(s)return ri;case"DDD":return Ze;case"MMM":case"MMMM":case"dd":case"ddd":case"dddd":return Je;case"a":case"A":return e._locale._meridiemParse;case"x":return ii;case"X":return si;case"Z":case"ZZ":return ti;case"T":return ei;case"SSSS":return $e;case"MM":case"DD":case"YY":case"GG":case"gg":case"HH":case"hh":case"mm":case"ss":case"ww":case"WW":return s?ni:qe;case"M":case"D":case"d":case"H":case"h":case"m":case"s":case"w":case"W":case"e":case"E":return qe;case"Do":return s?e._locale._ordinalParse:e._locale._ordinalParseLenient;default:return i=new RegExp(se(ie(t.replace("\\","")),"i"))}}function Z(t){t=t||"";var e=t.match(ti)||[],i=e[e.length-1]||[],s=(i+"").match(mi)||["-",0,0],o=+(60*s[1])+L(s[2]);return"+"===s[0]?o:-o}function Q(t,e,i){var s,o=i._a;switch(t){case"Q":null!=e&&(o[ze]=3*(L(e)-1));break;case"M":case"MM":null!=e&&(o[ze]=L(e)-1);break;case"MMM":case"MMMM":s=i._locale.monthsParse(e,t,i._strict),null!=s?o[ze]=s:i._pf.invalidMonth=e;break;case"D":case"DD":null!=e&&(o[Pe]=L(e));break;case"Do":null!=e&&(o[Pe]=L(parseInt(e.match(/\d{1,2}/)[0],10)));break;case"DDD":case"DDDD":null!=e&&(i._dayOfYear=L(e));break;case"YY":o[Le]=Ce.parseTwoDigitYear(e);break;case"YYYY":case"YYYYY":case"YYYYYY":o[Le]=L(e);break;case"a":case"A":i._meridiem=e;break;case"h":case"hh":i._pf.bigHour=!0;case"H":case"HH":o[Ae]=L(e);break;case"m":case"mm":o[Re]=L(e);break;case"s":case"ss":o[Fe]=L(e);break;case"S":case"SS":case"SSS":case"SSSS":o[He]=L(1e3*("0."+e));break;case"x":i._d=new Date(L(e));break;case"X":i._d=new Date(1e3*parseFloat(e));break;case"Z":case"ZZ":i._useUTC=!0,i._tzm=Z(e);break;case"dd":case"ddd":case"dddd":s=i._locale.weekdaysParse(e),null!=s?(i._w=i._w||{},i._w.d=s):i._pf.invalidWeekday=e;break;case"w":case"ww":case"W":case"WW":case"d":case"e":case"E":t=t.substr(0,1);case"gggg":case"GGGG":case"GGGGG":t=t.substr(0,2),e&&(i._w=i._w||{},i._w[t]=L(e));break;case"gg":case"GG":i._w=i._w||{},i._w[t]=Ce.parseTwoDigitYear(e)}}function K(t){var e,i,s,o,n,a,h;e=t._w,null!=e.GG||null!=e.W||null!=e.E?(n=1,a=4,i=r(e.GG,t._a[Le],me(Ce(),1,4).year),s=r(e.W,1),o=r(e.E,1)):(n=t._locale._week.dow,a=t._locale._week.doy,i=r(e.gg,t._a[Le],me(Ce(),n,a).year),s=r(e.w,1),null!=e.d?(o=e.d,n>o&&++s):o=null!=e.e?e.e+n:n),h=fe(i,s,o,a,n),t._a[Le]=h.year,t._dayOfYear=h.dayOfYear}function $(t){var e,i,s,o,n=[];if(!t._d){for(s=te(t),t._w&&null==t._a[Pe]&&null==t._a[ze]&&K(t),t._dayOfYear&&(o=r(t._a[Le],s[Le]),t._dayOfYear>A(o)&&(t._pf._overflowDayOfYear=!0),i=le(o,0,t._dayOfYear),t._a[ze]=i.getUTCMonth(),t._a[Pe]=i.getUTCDate()),e=0;3>e&&null==t._a[e];++e)t._a[e]=n[e]=s[e];for(;7>e;e++)t._a[e]=n[e]=null==t._a[e]?2===e?1:0:t._a[e];24===t._a[Ae]&&0===t._a[Re]&&0===t._a[Fe]&&0===t._a[He]&&(t._nextDay=!0,t._a[Ae]=0),t._d=(t._useUTC?le:de).apply(null,n),null!=t._tzm&&t._d.setUTCMinutes(t._d.getUTCMinutes()-t._tzm),t._nextDay&&(t._a[Ae]=24)}}function J(t){var e;t._d||(e=N(t._i),t._a=[e.year,e.month,e.day||e.date,e.hour,e.minute,e.second,e.millisecond],$(t))}function te(t){var e=new Date;return t._useUTC?[e.getUTCFullYear(),e.getUTCMonth(),e.getUTCDate()]:[e.getFullYear(),e.getMonth(),e.getDate()]}function ee(t){if(t._f===Ce.ISO_8601)return void ne(t);t._a=[],t._pf.empty=!0;var e,i,s,o,r,a=""+t._i,h=a.length,d=0;for(s=X(t._f,t._locale).match(Ve)||[],e=0;e0&&t._pf.unusedInput.push(r),a=a.slice(a.indexOf(i)+i.length),d+=i.length),wi[o]?(i?t._pf.empty=!1:t._pf.unusedTokens.push(o),Q(o,i,t)):t._strict&&!i&&t._pf.unusedTokens.push(o);t._pf.charsLeftOver=h-d,a.length>0&&t._pf.unusedInput.push(a),t._pf.bigHour===!0&&t._a[Ae]<=12&&(t._pf.bigHour=n),t._a[Ae]=f(t._locale,t._a[Ae],t._meridiem),$(t),F(t)}function ie(t){return t.replace(/\\(\[)|\\(\])|\[([^\]\[]*)\]|\\(.)/g,function(t,e,i,s,o){return e||i||s||o})}function se(t){return t.replace(/[-\/\\^$*+?.()|[\]{}]/g,"\\$&")}function oe(t){var e,i,s,o,n;if(0===t._f.length)return t._pf.invalidFormat=!0,void(t._d=new Date(0/0));for(o=0;on)&&(s=n,i=e));b(t,i||e)}function ne(t){var e,i,s=t._i,o=li.exec(s);if(o){for(t._pf.iso=!0,e=0,i=pi.length;i>e;e++)if(pi[e][1].exec(s)){t._f=pi[e][0]+(o[6]||" ");break}for(e=0,i=ui.length;i>e;e++)if(ui[e][1].exec(s)){t._f+=ui[e][0];break}s.match(ti)&&(t._f+="Z"),ee(t)}else t._isValid=!1}function re(t){ne(t),t._isValid===!1&&(delete t._isValid,Ce.createFromInputFallback(t))}function ae(t,e){var i,s=[];for(i=0;it&&a.setFullYear(t),a}function le(t){var e=new Date(Date.UTC.apply(null,arguments));return 1970>t&&e.setUTCFullYear(t),e}function ce(t,e){if("string"==typeof t)if(isNaN(t)){if(t=e.weekdaysParse(t),"number"!=typeof t)return null}else t=parseInt(t,10);return t}function pe(t,e,i,s,o){return o.relativeTime(e||1,!!i,t,s)}function ue(t,e,i){var s=Ce.duration(t).abs(),o=Ne(s.as("s")),n=Ne(s.as("m")),r=Ne(s.as("h")),a=Ne(s.as("d")),h=Ne(s.as("M")),d=Ne(s.as("y")),l=o0,l[4]=i,pe.apply({},l)}function me(t,e,i){var s,o=i-e,n=i-t.day();return n>o&&(n-=7),o-7>n&&(n+=7),s=Ce(t).add(n,"d"),{week:Math.ceil(s.dayOfYear()/7),year:s.year()}}function fe(t,e,i,s,o){var n,r,a=le(t,0,1).getUTCDay();return a=0===a?7:a,i=null!=i?i:o,n=o-a+(a>s?7:0)-(o>a?7:0),r=7*(e-1)+(i-o)+n+1,{year:r>0?t:t-1,dayOfYear:r>0?r:A(t-1)+r}}function ge(t){var e,i=t._i,s=t._f;return t._locale=t._locale||Ce.localeData(t._l),null===i||s===n&&""===i?Ce.invalid({nullInput:!0}):("string"==typeof i&&(t._i=i=t._locale.preparse(i)),Ce.isMoment(i)?new v(i,!0):(s?T(s)?oe(t):ee(t):he(t),e=new v(t),e._nextDay&&(e.add(1,"d"),e._nextDay=n),e))}function ve(t,e){var i,s;if(1===e.length&&T(e[0])&&(e=e[0]),!e.length)return Ce();for(i=e[0],s=1;s=0?"+":"-";return e+w(Math.abs(t),6)},gg:function(){return w(this.weekYear()%100,2)},gggg:function(){return w(this.weekYear(),4)},ggggg:function(){return w(this.weekYear(),5)},GG:function(){return w(this.isoWeekYear()%100,2)},GGGG:function(){return w(this.isoWeekYear(),4)},GGGGG:function(){return w(this.isoWeekYear(),5)},e:function(){return this.weekday()},E:function(){return this.isoWeekday()},a:function(){return this.localeData().meridiem(this.hours(),this.minutes(),!0)},A:function(){return this.localeData().meridiem(this.hours(),this.minutes(),!1)},H:function(){return this.hours()},h:function(){return this.hours()%12||12},m:function(){return this.minutes()},s:function(){return this.seconds()},S:function(){return L(this.milliseconds()/100)},SS:function(){return w(L(this.milliseconds()/10),2)},SSS:function(){return w(this.milliseconds(),3)},SSSS:function(){return w(this.milliseconds(),3)},Z:function(){var t=this.utcOffset(),e="+";return 0>t&&(t=-t,e="-"),e+w(L(t/60),2)+":"+w(L(t)%60,2)},ZZ:function(){var t=this.utcOffset(),e="+";return 0>t&&(t=-t,e="-"),e+w(L(t/60),2)+w(L(t)%60,2)},z:function(){return this.zoneAbbr()},zz:function(){return this.zoneName()},x:function(){return this.valueOf()},X:function(){return this.unix()},Q:function(){return this.quarter()}},Si={},Mi=["months","monthsShort","weekdays","weekdaysShort","weekdaysMin"],Di=!1;_i.length;)Oe=_i.pop(),wi[Oe+"o"]=u(wi[Oe],Oe);for(;xi.length;)Oe=xi.pop(),wi[Oe+Oe]=p(wi[Oe],2);wi.DDDD=p(wi.DDD,3),b(g.prototype,{set:function(t){var e,i;for(i in t)e=t[i],"function"==typeof e?this[i]=e:this["_"+i]=e;this._ordinalParseLenient=new RegExp(this._ordinalParse.source+"|"+/\d{1,2}/.source)},_months:"January_February_March_April_May_June_July_August_September_October_November_December".split("_"),months:function(t){return this._months[t.month()]},_monthsShort:"Jan_Feb_Mar_Apr_May_Jun_Jul_Aug_Sep_Oct_Nov_Dec".split("_"),monthsShort:function(t){return this._monthsShort[t.month()]},monthsParse:function(t,e,i){var s,o,n;for(this._monthsParse||(this._monthsParse=[],this._longMonthsParse=[],this._shortMonthsParse=[]),s=0;12>s;s++){if(o=Ce.utc([2e3,s]),i&&!this._longMonthsParse[s]&&(this._longMonthsParse[s]=new RegExp("^"+this.months(o,"").replace(".","")+"$","i"),this._shortMonthsParse[s]=new RegExp("^"+this.monthsShort(o,"").replace(".","")+"$","i")),i||this._monthsParse[s]||(n="^"+this.months(o,"")+"|^"+this.monthsShort(o,""),this._monthsParse[s]=new RegExp(n.replace(".",""),"i")),i&&"MMMM"===e&&this._longMonthsParse[s].test(t))return s;if(i&&"MMM"===e&&this._shortMonthsParse[s].test(t))return s;if(!i&&this._monthsParse[s].test(t))return s}},_weekdays:"Sunday_Monday_Tuesday_Wednesday_Thursday_Friday_Saturday".split("_"),weekdays:function(t){return this._weekdays[t.day()]},_weekdaysShort:"Sun_Mon_Tue_Wed_Thu_Fri_Sat".split("_"),weekdaysShort:function(t){return this._weekdaysShort[t.day()]},_weekdaysMin:"Su_Mo_Tu_We_Th_Fr_Sa".split("_"),weekdaysMin:function(t){return this._weekdaysMin[t.day()]},weekdaysParse:function(t){var e,i,s;for(this._weekdaysParse||(this._weekdaysParse=[]),e=0;7>e;e++)if(this._weekdaysParse[e]||(i=Ce([2e3,1]).day(e),s="^"+this.weekdays(i,"")+"|^"+this.weekdaysShort(i,"")+"|^"+this.weekdaysMin(i,""),this._weekdaysParse[e]=new RegExp(s.replace(".",""),"i")),this._weekdaysParse[e].test(t))return e},_longDateFormat:{LTS:"h:mm:ss A",LT:"h:mm A",L:"MM/DD/YYYY",LL:"MMMM D, YYYY",LLL:"MMMM D, YYYY LT",LLLL:"dddd, MMMM D, YYYY LT"},longDateFormat:function(t){var e=this._longDateFormat[t]; +return!e&&this._longDateFormat[t.toUpperCase()]&&(e=this._longDateFormat[t.toUpperCase()].replace(/MMMM|MM|DD|dddd/g,function(t){return t.slice(1)}),this._longDateFormat[t]=e),e},isPM:function(t){return"p"===(t+"").toLowerCase().charAt(0)},_meridiemParse:/[ap]\.?m?\.?/i,meridiem:function(t,e,i){return t>11?i?"pm":"PM":i?"am":"AM"},_calendar:{sameDay:"[Today at] LT",nextDay:"[Tomorrow at] LT",nextWeek:"dddd [at] LT",lastDay:"[Yesterday at] LT",lastWeek:"[Last] dddd [at] LT",sameElse:"L"},calendar:function(t,e,i){var s=this._calendar[t];return"function"==typeof s?s.apply(e,[i]):s},_relativeTime:{future:"in %s",past:"%s ago",s:"a few seconds",m:"a minute",mm:"%d minutes",h:"an hour",hh:"%d hours",d:"a day",dd:"%d days",M:"a month",MM:"%d months",y:"a year",yy:"%d years"},relativeTime:function(t,e,i,s){var o=this._relativeTime[i];return"function"==typeof o?o(t,e,i,s):o.replace(/%d/i,t)},pastFuture:function(t,e){var i=this._relativeTime[t>0?"future":"past"];return"function"==typeof i?i(e):i.replace(/%s/i,e)},ordinal:function(t){return this._ordinal.replace("%d",t)},_ordinal:"%d",_ordinalParse:/\d{1,2}/,preparse:function(t){return t},postformat:function(t){return t},week:function(t){return me(t,this._week.dow,this._week.doy).week},_week:{dow:0,doy:6},firstDayOfWeek:function(){return this._week.dow},firstDayOfYear:function(){return this._week.doy},_invalidDate:"Invalid date",invalidDate:function(){return this._invalidDate}}),Ce=function(t,e,i,s){var o;return"boolean"==typeof i&&(s=i,i=n),o={},o._isAMomentObject=!0,o._i=t,o._f=e,o._l=i,o._strict=s,o._isUTC=!1,o._pf=h(),ge(o)},Ce.suppressDeprecationWarnings=!1,Ce.createFromInputFallback=l("moment construction falls back to js Date. This is discouraged and will be removed in upcoming major release. Please refer to https://github.com/moment/moment/issues/1407 for more info.",function(t){t._d=new Date(t._i+(t._useUTC?" UTC":""))}),Ce.min=function(){var t=[].slice.call(arguments,0);return ve("isBefore",t)},Ce.max=function(){var t=[].slice.call(arguments,0);return ve("isAfter",t)},Ce.utc=function(t,e,i,s){var o;return"boolean"==typeof i&&(s=i,i=n),o={},o._isAMomentObject=!0,o._useUTC=!0,o._isUTC=!0,o._l=i,o._i=t,o._f=e,o._strict=s,o._pf=h(),ge(o).utc()},Ce.unix=function(t){return Ce(1e3*t)},Ce.duration=function(t,e){var i,s,o,n,r=t,h=null;return Ce.isDuration(t)?r={ms:t._milliseconds,d:t._days,M:t._months}:"number"==typeof t?(r={},e?r[e]=t:r.milliseconds=t):(h=je.exec(t))?(i="-"===h[1]?-1:1,r={y:0,d:L(h[Pe])*i,h:L(h[Ae])*i,m:L(h[Re])*i,s:L(h[Fe])*i,ms:L(h[He])*i}):(h=Ue.exec(t))?(i="-"===h[1]?-1:1,o=function(t){var e=t&&parseFloat(t.replace(",","."));return(isNaN(e)?0:e)*i},r={y:o(h[2]),M:o(h[3]),d:o(h[4]),h:o(h[5]),m:o(h[6]),s:o(h[7]),w:o(h[8])}):null==r?r={}:"object"==typeof r&&("from"in r||"to"in r)&&(n=M(Ce(r.from),Ce(r.to)),r={},r.ms=n.milliseconds,r.M=n.months),s=new y(r),Ce.isDuration(t)&&a(t,"_locale")&&(s._locale=t._locale),s},Ce.version=Ee,Ce.defaultFormat=ci,Ce.ISO_8601=function(){},Ce.momentProperties=Ye,Ce.updateOffset=function(){},Ce.relativeTimeThreshold=function(t,e){return bi[t]===n?!1:e===n?bi[t]:(bi[t]=e,!0)},Ce.lang=l("moment.lang is deprecated. Use moment.locale instead.",function(t,e){return Ce.locale(t,e)}),Ce.locale=function(t,e){var i;return t&&(i="undefined"!=typeof e?Ce.defineLocale(t,e):Ce.localeData(t),i&&(Ce.duration._locale=Ce._locale=i)),Ce._locale._abbr},Ce.defineLocale=function(t,e){return null!==e?(e.abbr=t,Be[t]||(Be[t]=new g),Be[t].set(e),Ce.locale(t),Be[t]):(delete Be[t],null)},Ce.langData=l("moment.langData is deprecated. Use moment.localeData instead.",function(t){return Ce.localeData(t)}),Ce.localeData=function(t){var e;if(t&&t._locale&&t._locale._abbr&&(t=t._locale._abbr),!t)return Ce._locale;if(!T(t)){if(e=W(t))return e;t=[t]}return Y(t)},Ce.isMoment=function(t){return t instanceof v||null!=t&&a(t,"_isAMomentObject")},Ce.isDuration=function(t){return t instanceof y};for(Oe=Mi.length-1;Oe>=0;--Oe)I(Mi[Oe]);Ce.normalizeUnits=function(t){return k(t)},Ce.invalid=function(t){var e=Ce.utc(0/0);return null!=t?b(e._pf,t):e._pf.userInvalidated=!0,e},Ce.parseZone=function(){return Ce.apply(null,arguments).parseZone()},Ce.parseTwoDigitYear=function(t){return L(t)+(L(t)>68?1900:2e3)},Ce.isDate=O,b(Ce.fn=v.prototype,{clone:function(){return Ce(this)},valueOf:function(){return+this._d-6e4*(this._offset||0)},unix:function(){return Math.floor(+this/1e3)},toString:function(){return this.clone().locale("en").format("ddd MMM DD YYYY HH:mm:ss [GMT]ZZ")},toDate:function(){return this._offset?new Date(+this):this._d},toISOString:function(){var t=Ce(this).utc();return 00:!1},parsingFlags:function(){return b({},this._pf)},invalidAt:function(){return this._pf.overflow},utc:function(t){return this.utcOffset(0,t)},local:function(t){return this._isUTC&&(this.utcOffset(0,t),this._isUTC=!1,t&&this.subtract(this._dateUtcOffset(),"m")),this},format:function(t){var e=V(this,t||Ce.defaultFormat);return this.localeData().postformat(e)},add:D(1,"add"),subtract:D(-1,"subtract"),diff:function(t,e,i){var s,o,n=G(t,this),r=6e4*(n.utcOffset()-this.utcOffset());return e=k(e),"year"===e||"month"===e||"quarter"===e?(o=m(this,n),"quarter"===e?o/=3:"year"===e&&(o/=12)):(s=this-n,o="second"===e?s/1e3:"minute"===e?s/6e4:"hour"===e?s/36e5:"day"===e?(s-r)/864e5:"week"===e?(s-r)/6048e5:s),i?o:x(o)},from:function(t,e){return Ce.duration({to:this,from:t}).locale(this.locale()).humanize(!e)},fromNow:function(t){return this.from(Ce(),t)},calendar:function(t){var e=t||Ce(),i=G(e,this).startOf("day"),s=this.diff(i,"days",!0),o=-6>s?"sameElse":-1>s?"lastWeek":0>s?"lastDay":1>s?"sameDay":2>s?"nextDay":7>s?"nextWeek":"sameElse";return this.format(this.localeData().calendar(o,this,Ce(e)))},isLeapYear:function(){return R(this.year())},isDST:function(){return this.utcOffset()>this.clone().month(0).utcOffset()||this.utcOffset()>this.clone().month(5).utcOffset()},day:function(t){var e=this._isUTC?this._d.getUTCDay():this._d.getDay();return null!=t?(t=ce(t,this.localeData()),this.add(t-e,"d")):e},month:xe("Month",!0),startOf:function(t){switch(t=k(t)){case"year":this.month(0);case"quarter":case"month":this.date(1);case"week":case"isoWeek":case"day":this.hours(0);case"hour":this.minutes(0);case"minute":this.seconds(0);case"second":this.milliseconds(0)}return"week"===t?this.weekday(0):"isoWeek"===t&&this.isoWeekday(1),"quarter"===t&&this.month(3*Math.floor(this.month()/3)),this},endOf:function(t){return t=k(t),t===n||"millisecond"===t?this:this.startOf(t).add(1,"isoWeek"===t?"week":t).subtract(1,"ms")},isAfter:function(t,e){var i;return e=k("undefined"!=typeof e?e:"millisecond"),"millisecond"===e?(t=Ce.isMoment(t)?t:Ce(t),+this>+t):(i=Ce.isMoment(t)?+t:+Ce(t),i<+this.clone().startOf(e))},isBefore:function(t,e){var i;return e=k("undefined"!=typeof e?e:"millisecond"),"millisecond"===e?(t=Ce.isMoment(t)?t:Ce(t),+t>+this):(i=Ce.isMoment(t)?+t:+Ce(t),+this.clone().endOf(e)t?this:t}),max:l("moment().max is deprecated, use moment.max instead. https://github.com/moment/moment/issues/1548",function(t){return t=Ce.apply(null,arguments),t>this?this:t}),zone:l("moment().zone is deprecated, use moment().utcOffset instead. https://github.com/moment/moment/issues/1779",function(t,e){return null!=t?("string"!=typeof t&&(t=-t),this.utcOffset(t,e),this):-this.utcOffset()}),utcOffset:function(t,e){var i,s=this._offset||0;return null!=t?("string"==typeof t&&(t=Z(t)),Math.abs(t)<16&&(t=60*t),!this._isUTC&&e&&(i=this._dateUtcOffset()),this._offset=t,this._isUTC=!0,null!=i&&this.add(i,"m"),s!==t&&(!e||this._changeInProgress?C(this,Ce.duration(t-s,"m"),1,!1):this._changeInProgress||(this._changeInProgress=!0,Ce.updateOffset(this,!0),this._changeInProgress=null)),this):this._isUTC?s:this._dateUtcOffset()},isLocal:function(){return!this._isUTC},isUtcOffset:function(){return this._isUTC},isUtc:function(){return this._isUTC&&0===this._offset},zoneAbbr:function(){return this._isUTC?"UTC":""},zoneName:function(){return this._isUTC?"Coordinated Universal Time":""},parseZone:function(){return this._tzm?this.utcOffset(this._tzm):"string"==typeof this._i&&this.utcOffset(Z(this._i)),this},hasAlignedHourOffset:function(t){return t=t?Ce(t).utcOffset():0,(this.utcOffset()-t)%60===0},daysInMonth:function(){return z(this.year(),this.month())},dayOfYear:function(t){var e=Ne((Ce(this).startOf("day")-Ce(this).startOf("year"))/864e5)+1;return null==t?e:this.add(t-e,"d")},quarter:function(t){return null==t?Math.ceil((this.month()+1)/3):this.month(3*(t-1)+this.month()%3)},weekYear:function(t){var e=me(this,this.localeData()._week.dow,this.localeData()._week.doy).year;return null==t?e:this.add(t-e,"y")},isoWeekYear:function(t){var e=me(this,1,4).year;return null==t?e:this.add(t-e,"y")},week:function(t){var e=this.localeData().week(this);return null==t?e:this.add(7*(t-e),"d")},isoWeek:function(t){var e=me(this,1,4).week;return null==t?e:this.add(7*(t-e),"d")},weekday:function(t){var e=(this.day()+7-this.localeData()._week.dow)%7;return null==t?e:this.add(t-e,"d")},isoWeekday:function(t){return null==t?this.day()||7:this.day(this.day()%7?t:t-7)},isoWeeksInYear:function(){return P(this.year(),1,4)},weeksInYear:function(){var t=this.localeData()._week;return P(this.year(),t.dow,t.doy)},get:function(t){return t=k(t),this[t]()},set:function(t,e){var i;if("object"==typeof t)for(i in t)this.set(i,t[i]);else t=k(t),"function"==typeof this[t]&&this[t](e);return this},locale:function(t){var e;return t===n?this._locale._abbr:(e=Ce.localeData(t),null!=e&&(this._locale=e),this)},lang:l("moment().lang() is deprecated. Instead, use moment().localeData() to get the language configuration. Use moment().locale() to change languages.",function(t){return t===n?this.localeData():this.locale(t)}),localeData:function(){return this._locale},_dateUtcOffset:function(){return 15*-Math.round(this._d.getTimezoneOffset()/15)}}),Ce.fn.millisecond=Ce.fn.milliseconds=xe("Milliseconds",!1),Ce.fn.second=Ce.fn.seconds=xe("Seconds",!1),Ce.fn.minute=Ce.fn.minutes=xe("Minutes",!1),Ce.fn.hour=Ce.fn.hours=xe("Hours",!0),Ce.fn.date=xe("Date",!0),Ce.fn.dates=l("dates accessor is deprecated. Use date instead.",xe("Date",!0)),Ce.fn.year=xe("FullYear",!0),Ce.fn.years=l("years accessor is deprecated. Use year instead.",xe("FullYear",!0)),Ce.fn.days=Ce.fn.day,Ce.fn.months=Ce.fn.month,Ce.fn.weeks=Ce.fn.week,Ce.fn.isoWeeks=Ce.fn.isoWeek,Ce.fn.quarters=Ce.fn.quarter,Ce.fn.toJSON=Ce.fn.toISOString,Ce.fn.isUTC=Ce.fn.isUtc,b(Ce.duration.fn=y.prototype,{_bubble:function(){var t,e,i,s=this._milliseconds,o=this._days,n=this._months,r=this._data,a=0;r.milliseconds=s%1e3,t=x(s/1e3),r.seconds=t%60,e=x(t/60),r.minutes=e%60,i=x(e/60),r.hours=i%24,o+=x(i/24),a=x(we(o)),o-=x(Se(a)),n+=x(o/30),o%=30,a+=x(n/12),n%=12,r.days=o,r.months=n,r.years=a},abs:function(){return this._milliseconds=Math.abs(this._milliseconds),this._days=Math.abs(this._days),this._months=Math.abs(this._months),this._data.milliseconds=Math.abs(this._data.milliseconds),this._data.seconds=Math.abs(this._data.seconds),this._data.minutes=Math.abs(this._data.minutes),this._data.hours=Math.abs(this._data.hours),this._data.months=Math.abs(this._data.months),this._data.years=Math.abs(this._data.years),this},weeks:function(){return x(this.days()/7)},valueOf:function(){return this._milliseconds+864e5*this._days+this._months%12*2592e6+31536e6*L(this._months/12)},humanize:function(t){var e=ue(this,!t,this.localeData());return t&&(e=this.localeData().pastFuture(+this,e)),this.localeData().postformat(e)},add:function(t,e){var i=Ce.duration(t,e);return this._milliseconds+=i._milliseconds,this._days+=i._days,this._months+=i._months,this._bubble(),this},subtract:function(t,e){var i=Ce.duration(t,e);return this._milliseconds-=i._milliseconds,this._days-=i._days,this._months-=i._months,this._bubble(),this},get:function(t){return t=k(t),this[t.toLowerCase()+"s"]()},as:function(t){var e,i;if(t=k(t),"month"===t||"year"===t)return e=this._days+this._milliseconds/864e5,i=this._months+12*we(e),"month"===t?i:i/12;switch(e=this._days+Math.round(Se(this._months/12)),t){case"week":return e/7+this._milliseconds/6048e5;case"day":return e+this._milliseconds/864e5;case"hour":return 24*e+this._milliseconds/36e5;case"minute":return 24*e*60+this._milliseconds/6e4;case"second":return 24*e*60*60+this._milliseconds/1e3;case"millisecond":return Math.floor(24*e*60*60*1e3)+this._milliseconds;default:throw new Error("Unknown unit "+t)}},lang:Ce.fn.lang,locale:Ce.fn.locale,toIsoString:l("toIsoString() is deprecated. Please use toISOString() instead (notice the capitals)",function(){return this.toISOString()}),toISOString:function(){var t=Math.abs(this.years()),e=Math.abs(this.months()),i=Math.abs(this.days()),s=Math.abs(this.hours()),o=Math.abs(this.minutes()),n=Math.abs(this.seconds()+this.milliseconds()/1e3);return this.asSeconds()?(this.asSeconds()<0?"-":"")+"P"+(t?t+"Y":"")+(e?e+"M":"")+(i?i+"D":"")+(s||o||n?"T":"")+(s?s+"H":"")+(o?o+"M":"")+(n?n+"S":""):"P0D"},localeData:function(){return this._locale},toJSON:function(){return this.toISOString()}}),Ce.duration.fn.toString=Ce.duration.fn.toISOString;for(Oe in fi)a(fi,Oe)&&Me(Oe.toLowerCase());Ce.duration.fn.asMilliseconds=function(){return this.as("ms")},Ce.duration.fn.asSeconds=function(){return this.as("s")},Ce.duration.fn.asMinutes=function(){return this.as("m")},Ce.duration.fn.asHours=function(){return this.as("h")},Ce.duration.fn.asDays=function(){return this.as("d")},Ce.duration.fn.asWeeks=function(){return this.as("weeks")},Ce.duration.fn.asMonths=function(){return this.as("M")},Ce.duration.fn.asYears=function(){return this.as("y")},Ce.locale("en",{ordinalParse:/\d{1,2}(th|st|nd|rd)/,ordinal:function(t){var e=t%10,i=1===L(t%100/10)?"th":1===e?"st":2===e?"nd":3===e?"rd":"th";return t+i}}),We?o.exports=Ce:(s=function(t,e,i){return i.config&&i.config()&&i.config().noGlobal===!0&&(ke.moment=Te),Ce}.call(e,i,e,o),!(s!==n&&(o.exports=s)),De(!0))}).call(this)}).call(e,function(){return this}(),i(72)(t))},function(t,e){var i,s,o;!function(n,r){s=[],i=r,o="function"==typeof i?i.apply(e,s):i,!(void 0!==o&&(t.exports=o))}(this,function(){function t(t){var e,i=t&&t.preventDefault||!1,s=t&&t.container||window,o={},n={keydown:{},keyup:{}},r={};for(e=97;122>=e;e++)r[String.fromCharCode(e)]={code:65+(e-97),shift:!1};for(e=65;90>=e;e++)r[String.fromCharCode(e)]={code:e,shift:!0};for(e=0;9>=e;e++)r[""+e]={code:48+e,shift:!1};for(e=1;12>=e;e++)r["F"+e]={code:111+e,shift:!1};for(e=0;9>=e;e++)r["num"+e]={code:96+e,shift:!1};r["num*"]={code:106,shift:!1},r["num+"]={code:107,shift:!1},r["num-"]={code:109,shift:!1},r["num/"]={code:111,shift:!1},r["num."]={code:110,shift:!1},r.left={code:37,shift:!1},r.up={code:38,shift:!1},r.right={code:39,shift:!1},r.down={code:40,shift:!1},r.space={code:32,shift:!1},r.enter={code:13,shift:!1},r.shift={code:16,shift:void 0},r.esc={code:27,shift:!1},r.backspace={code:8,shift:!1},r.tab={code:9,shift:!1},r.ctrl={code:17,shift:!1},r.alt={code:18,shift:!1},r["delete"]={code:46,shift:!1},r.pageup={code:33,shift:!1},r.pagedown={code:34,shift:!1},r["="]={code:187,shift:!1},r["-"]={code:189,shift:!1},r["]"]={code:221,shift:!1},r["["]={code:219,shift:!1};var a=function(t){d(t,"keydown")},h=function(t){d(t,"keyup")},d=function(t,e){if(void 0!==n[e][t.keyCode]){for(var s=n[e][t.keyCode],o=0;o0?i._handlers[t]=s:(i._off(t,o),delete i._handlers[t]))}),i},i.destroy=function(){var t=i.element;delete t.hammer,i._handlers={},i._destroy()},i}})},function(t,e,i){var s;!function(o,n,r,a){function h(t,e,i){return setTimeout(m(t,i),e)}function d(t,e,i){return Array.isArray(t)?(l(t,i[e],i),!0):!1}function l(t,e,i){var s;if(t)if(t.forEach)t.forEach(e,i);else if(t.length!==a)for(s=0;s-1}function x(t){return t.trim().split(/\s+/g)}function w(t,e,i){if(t.indexOf&&!i)return t.indexOf(e);for(var s=0;si[e]}):s.sort()),s}function D(t,e){for(var i,s,o=e[0].toUpperCase()+e.slice(1),n=0;n1&&!i.firstMultiple?i.firstMultiple=z(e):1===o&&(i.firstMultiple=!1);var n=i.firstInput,r=i.firstMultiple,a=r?r.center:n.center,h=e.center=P(s);e.timeStamp=ve(),e.deltaTime=e.timeStamp-n.timeStamp,e.angle=H(a,h),e.distance=F(a,h),I(i,e),e.offsetDirection=R(e.deltaX,e.deltaY),e.scale=r?Y(r.pointers,s):1,e.rotation=r?B(r.pointers,s):0,L(i,e);var d=t.element;b(e.srcEvent.target,d)&&(d=e.srcEvent.target),e.target=d}function I(t,e){var i=e.center,s=t.offsetDelta||{},o=t.prevDelta||{},n=t.prevInput||{};(e.eventType===Oe||n.eventType===ke)&&(o=t.prevDelta={x:n.deltaX||0,y:n.deltaY||0},s=t.offsetDelta={x:i.x,y:i.y}),e.deltaX=o.x+(i.x-s.x),e.deltaY=o.y+(i.y-s.y)}function L(t,e){var i,s,o,n,r=t.lastInterval||e,h=e.timeStamp-r.timeStamp;if(e.eventType!=Ne&&(h>Te||r.velocity===a)){var d=r.deltaX-e.deltaX,l=r.deltaY-e.deltaY,c=A(h,d,l);s=c.x,o=c.y,i=ge(c.x)>ge(c.y)?c.x:c.y,n=R(d,l),t.lastInterval=e}else i=r.velocity,s=r.velocityX,o=r.velocityY,n=r.direction;e.velocity=i,e.velocityX=s,e.velocityY=o,e.direction=n}function z(t){for(var e=[],i=0;io;)i+=t[o].clientX,s+=t[o].clientY,o++;return{x:fe(i/e),y:fe(s/e)}}function A(t,e,i){return{x:e/t||0,y:i/t||0}}function R(t,e){return t===e?Ie:ge(t)>=ge(e)?t>0?Le:ze:e>0?Pe:Ae}function F(t,e,i){i||(i=Be);var s=e[i[0]]-t[i[0]],o=e[i[1]]-t[i[1]];return Math.sqrt(s*s+o*o)}function H(t,e,i){i||(i=Be);var s=e[i[0]]-t[i[0]],o=e[i[1]]-t[i[1]];return 180*Math.atan2(o,s)/Math.PI}function B(t,e){return H(e[1],e[0],Ye)-H(t[1],t[0],Ye)}function Y(t,e){return F(e[0],e[1],Ye)/F(t[0],t[1],Ye)}function W(){this.evEl=Ge,this.evWin=je,this.allow=!0,this.pressed=!1,O.apply(this,arguments)}function G(){this.evEl=Xe,this.evWin=qe,O.apply(this,arguments),this.store=this.manager.session.pointerEvents=[]}function j(){this.evTarget=Qe,this.evWin=Ke,this.started=!1,O.apply(this,arguments)}function U(t,e){var i=S(t.touches),s=S(t.changedTouches);return e&(ke|Ne)&&(i=M(i.concat(s),"identifier",!0)),[i,s]}function V(){this.evTarget=Je,this.targetIds={},O.apply(this,arguments)}function X(t,e){var i=S(t.touches),s=this.targetIds;if(e&(Oe|Ee)&&1===i.length)return s[i[0].identifier]=!0,[i,i];var o,n,r=S(t.changedTouches),a=[],h=this.target;if(n=i.filter(function(t){return b(t.target,h)}),e===Oe)for(o=0;oa&&(e.push(t),a=e.length-1):o&(ke|Ne)&&(i=!0),0>a||(e[a]=t,this.callback(this.manager,o,{pointers:e,changedPointers:[t],pointerType:n,srcEvent:t}),i&&e.splice(a,1))}});var Ze={touchstart:Oe,touchmove:Ee,touchend:ke,touchcancel:Ne},Qe="touchstart",Ke="touchstart touchmove touchend touchcancel";u(j,O,{handler:function(t){var e=Ze[t.type];if(e===Oe&&(this.started=!0),this.started){var i=U.call(this,t,e);e&(ke|Ne)&&i[0].length-i[1].length===0&&(this.started=!1),this.callback(this.manager,e,{pointers:i[0],changedPointers:i[1],pointerType:Se,srcEvent:t})}}});var $e={touchstart:Oe,touchmove:Ee,touchend:ke,touchcancel:Ne},Je="touchstart touchmove touchend touchcancel";u(V,O,{handler:function(t){var e=$e[t.type],i=X.call(this,t,e);i&&this.callback(this.manager,e,{pointers:i[0],changedPointers:i[1],pointerType:Se,srcEvent:t})}}),u(q,O,{handler:function(t,e,i){var s=i.pointerType==Se,o=i.pointerType==De;if(s)this.mouse.allow=!1;else if(o&&!this.mouse.allow)return;e&(ke|Ne)&&(this.mouse.allow=!0),this.callback(t,e,i)},destroy:function(){this.touch.destroy(),this.mouse.destroy()}});var ti=D(ue.style,"touchAction"),ei=ti!==a,ii="compute",si="auto",oi="manipulation",ni="none",ri="pan-x",ai="pan-y";Z.prototype={set:function(t){t==ii&&(t=this.compute()),ei&&(this.manager.element.style[ti]=t),this.actions=t.toLowerCase().trim()},update:function(){this.set(this.manager.options.touchAction)},compute:function(){var t=[];return l(this.manager.recognizers,function(e){f(e.options.enable,[e])&&(t=t.concat(e.getTouchAction()))}),Q(t.join(" "))},preventDefaults:function(t){if(!ei){var e=t.srcEvent,i=t.offsetDirection;if(this.manager.session.prevented)return void e.preventDefault();var s=this.actions,o=_(s,ni),n=_(s,ai),r=_(s,ri);return o||n&&i&Re||r&&i&Fe?this.preventSrc(e):void 0}},preventSrc:function(t){this.manager.session.prevented=!0,t.preventDefault()}};var hi=1,di=2,li=4,ci=8,pi=ci,ui=16,mi=32;K.prototype={defaults:{},set:function(t){return c(this.options,t),this.manager&&this.manager.touchAction.update(),this},recognizeWith:function(t){if(d(t,"recognizeWith",this))return this;var e=this.simultaneous;return t=te(t,this),e[t.id]||(e[t.id]=t,t.recognizeWith(this)),this},dropRecognizeWith:function(t){return d(t,"dropRecognizeWith",this)?this:(t=te(t,this),delete this.simultaneous[t.id],this)},requireFailure:function(t){if(d(t,"requireFailure",this))return this;var e=this.requireFail;return t=te(t,this),-1===w(e,t)&&(e.push(t),t.requireFailure(this)),this},dropRequireFailure:function(t){if(d(t,"dropRequireFailure",this))return this;t=te(t,this);var e=w(this.requireFail,t);return e>-1&&this.requireFail.splice(e,1),this},hasRequireFailures:function(){return this.requireFail.length>0},canRecognizeWith:function(t){return!!this.simultaneous[t.id]},emit:function(t){function e(e){i.manager.emit(i.options.event+(e?$(s):""),t)}var i=this,s=this.state;ci>s&&e(!0),e(),s>=ci&&e(!0)},tryEmit:function(t){return this.canEmit()?this.emit(t):void(this.state=mi)},canEmit:function(){for(var t=0;tn?Le:ze,i=n!=this.pX,s=Math.abs(t.deltaX)):(o=0===r?Ie:0>r?Pe:Ae,i=r!=this.pY,s=Math.abs(t.deltaY))),t.direction=o,i&&s>e.threshold&&o&e.direction},attrTest:function(t){return ee.prototype.attrTest.call(this,t)&&(this.state&di||!(this.state&di)&&this.directionTest(t))},emit:function(t){this.pX=t.deltaX,this.pY=t.deltaY;var e=J(t.direction);e&&this.manager.emit(this.options.event+e,t),this._super.emit.call(this,t)}}),u(se,ee,{defaults:{event:"pinch",threshold:0,pointers:2},getTouchAction:function(){return[ni]},attrTest:function(t){return this._super.attrTest.call(this,t)&&(Math.abs(t.scale-1)>this.options.threshold||this.state&di)},emit:function(t){if(this._super.emit.call(this,t),1!==t.scale){var e=t.scale<1?"in":"out";this.manager.emit(this.options.event+e,t)}}}),u(oe,K,{defaults:{event:"press",pointers:1,time:500,threshold:5},getTouchAction:function(){return[si]},process:function(t){var e=this.options,i=t.pointers.length===e.pointers,s=t.distancee.time;if(this._input=t,!s||!i||t.eventType&(ke|Ne)&&!o)this.reset();else if(t.eventType&Oe)this.reset(),this._timer=h(function(){this.state=pi,this.tryEmit() +},e.time,this);else if(t.eventType&ke)return pi;return mi},reset:function(){clearTimeout(this._timer)},emit:function(t){this.state===pi&&(t&&t.eventType&ke?this.manager.emit(this.options.event+"up",t):(this._input.timeStamp=ve(),this.manager.emit(this.options.event,this._input)))}}),u(ne,ee,{defaults:{event:"rotate",threshold:0,pointers:2},getTouchAction:function(){return[ni]},attrTest:function(t){return this._super.attrTest.call(this,t)&&(Math.abs(t.rotation)>this.options.threshold||this.state&di)}}),u(re,ee,{defaults:{event:"swipe",threshold:10,velocity:.65,direction:Re|Fe,pointers:1},getTouchAction:function(){return ie.prototype.getTouchAction.call(this)},attrTest:function(t){var e,i=this.options.direction;return i&(Re|Fe)?e=t.velocity:i&Re?e=t.velocityX:i&Fe&&(e=t.velocityY),this._super.attrTest.call(this,t)&&i&t.direction&&t.distance>this.options.threshold&&ge(e)>this.options.velocity&&t.eventType&ke},emit:function(t){var e=J(t.direction);e&&this.manager.emit(this.options.event+e,t),this.manager.emit(this.options.event,t)}}),u(ae,K,{defaults:{event:"tap",pointers:1,taps:1,interval:300,time:250,threshold:2,posThreshold:10},getTouchAction:function(){return[oi]},process:function(t){var e=this.options,i=t.pointers.length===e.pointers,s=t.distancet&&s>o;)o%3==0?(this.forceAggregateHubs(!0),this.normalizeClusterLevels()):this.increaseClusterLevel(),i=this.nodeIndices.length,o+=1;o>0&&1==e&&this.repositionNodes(),this._updateCalculationNodes()},e.openCluster=function(t){var e=this.moving;if(t.clusterSize>this.constants.clustering.sectorThreshold&&this._nodeInActiveArea(t)&&("default"!=this._sector()||1!=this.nodeIndices.length)){this._addSector(t);for(var i=0;this.nodeIndices.lengthi;)this.decreaseClusterLevel(),i+=1}else this._expandClusterNode(t,!1,!0),this._updateNodeIndexList(),this._updateDynamicEdges(),this._updateCalculationNodes(),this.updateLabels();this.moving!=e&&this.start()},e.updateClustersDefault=function(){1==this.constants.clustering.enabled&&this.updateClusters(0,!1,!1)},e.increaseClusterLevel=function(){this.updateClusters(-1,!1,!0)},e.decreaseClusterLevel=function(){this.updateClusters(1,!1,!0)},e.updateClusters=function(t,e,i,s){var o=this.moving,n=this.nodeIndices.length;this.previousScale>this.scale&&0==t&&this._collapseSector(),this.previousScale>this.scale||-1==t?this._formClusters(i):(this.previousScalethis.scale||-1==t)&&(this._aggregateHubs(i),this._updateNodeIndexList()),(this.previousScale>this.scale||-1==t)&&(this.handleChains(),this._updateNodeIndexList()),this.previousScale=this.scale,this._updateDynamicEdges(),this.updateLabels(),this.nodeIndices.lengththis.constants.clustering.chainThreshold&&this._reduceAmountOfChains(1-this.constants.clustering.chainThreshold/t)},e._aggregateHubs=function(t){this._getHubSize(),this._formClustersByHub(t,!1)},e.forceAggregateHubs=function(t){var e=this.moving,i=this.nodeIndices.length;this._aggregateHubs(!0),this._updateNodeIndexList(),this._updateDynamicEdges(),this.updateLabels(),this.nodeIndices.length!=i&&(this.clusterSession+=1),(0==t||void 0===t)&&this.moving!=e&&this.start()},e._openClustersBySize=function(){for(var t in this.nodes)if(this.nodes.hasOwnProperty(t)){var e=this.nodes[t];1==e.inView()&&(e.width*this.scale>this.constants.clustering.screenSizeThreshold*this.frame.canvas.clientWidth||e.height*this.scale>this.constants.clustering.screenSizeThreshold*this.frame.canvas.clientHeight)&&this.openCluster(e)}},e._openClusters=function(t,e){for(var i=0;i1&&(t.clusterSizei)){var r=n.from,a=n.to;n.to.options.mass>n.from.options.mass&&(r=n.to,a=n.from),1==a.dynamicEdgesLength?this._addToCluster(r,a,!1):1==r.dynamicEdgesLength&&this._addToCluster(a,r,!1)}}},e._forceClustersByZoom=function(){for(var t in this.nodes)if(this.nodes.hasOwnProperty(t)){var e=this.nodes[t];if(1==e.dynamicEdgesLength&&0!=e.dynamicEdges.length){var i=e.dynamicEdges[0],s=i.toId==e.id?this.nodes[i.fromId]:this.nodes[i.toId];e.id!=s.id&&(s.options.mass>e.options.mass?this._addToCluster(s,e,!0):this._addToCluster(e,s,!0))}}},e._clusterToSmallestNeighbour=function(t){for(var e=-1,i=null,s=0;so.clusterSessions.length&&(e=o.clusterSessions.length,i=o)}null!=o&&void 0!==this.nodes[o.id]&&this._addToCluster(o,t,!0)},e._formClustersByHub=function(t,e){for(var i in this.nodes)this.nodes.hasOwnProperty(i)&&this._formClusterFromHub(this.nodes[i],t,e)},e._formClusterFromHub=function(t,e,i,s){if(void 0===s&&(s=0),t.dynamicEdgesLength>=this.hubThreshold&&0==i||t.dynamicEdgesLength==this.hubThreshold&&1==i){for(var o,n,r,a=this.constants.clustering.clusterEdgeThreshold/this.scale,h=!1,d=[],l=t.dynamicEdges.length,c=0;l>c;c++)d.push(t.dynamicEdges[c].id);if(0==e)for(h=!1,c=0;l>c;c++){var p=this.edges[d[c]];if(void 0!==p&&p.connected&&p.toId!=p.fromId&&(o=p.to.x-p.from.x,n=p.to.y-p.from.y,r=Math.sqrt(o*o+n*n),a>r)){h=!0;break}}if(!e&&h||e)for(c=0;l>c;c++)if(p=this.edges[d[c]],void 0!==p){var u=this.nodes[p.fromId==t.id?p.toId:p.fromId];u.dynamicEdges.length<=this.hubThreshold+s&&u.id!=t.id&&this._addToCluster(t,u,e)}}},e._addToCluster=function(t,e,i){t.containedNodes[e.id]=e;for(var s=0;s1)for(var s=0;s1&&(e.label="[".concat(String(e.clusterSize),"]"))}for(t in this.nodes)this.nodes.hasOwnProperty(t)&&(e=this.nodes[t],1==e.clusterSize&&(e.label=void 0!==e.originalLabel?e.originalLabel:String(e.id)))},e.normalizeClusterLevels=function(){var t,e=0,i=1e9,s=0;for(t in this.nodes)this.nodes.hasOwnProperty(t)&&(s=this.nodes[t].clusterSessions.length,s>e&&(e=s),i>s&&(i=s));if(e-i>this.constants.clustering.clusterLevelDifference){var o=this.nodeIndices.length,n=e-this.constants.clustering.clusterLevelDifference;for(t in this.nodes)this.nodes.hasOwnProperty(t)&&this.nodes[t].clusterSessions.lengths&&(s=n.dynamicEdgesLength),t+=n.dynamicEdgesLength,e+=Math.pow(n.dynamicEdgesLength,2),i+=1}t/=i,e/=i;var r=e-Math.pow(t,2),a=Math.sqrt(r);this.hubThreshold=Math.floor(t+2*a),this.hubThreshold>s&&(this.hubThreshold=s)},e._reduceAmountOfChains=function(t){this.hubThreshold=2;var e=Math.floor(this.nodeIndices.length*t);for(var i in this.nodes)this.nodes.hasOwnProperty(i)&&2==this.nodes[i].dynamicEdgesLength&&this.nodes[i].dynamicEdges.length>=2&&e>0&&(this._formClusterFromHub(this.nodes[i],!0,!0,1),e-=1)},e._getChainFraction=function(){var t=0,e=0;for(var i in this.nodes)this.nodes.hasOwnProperty(i)&&(2==this.nodes[i].dynamicEdgesLength&&this.nodes[i].dynamicEdges.length>=2&&(t+=1),e+=1);return t/e}},function(t,e,i){var s=i(1),o=i(40);e._putDataInSector=function(){this.sectors.active[this._sector()].nodes=this.nodes,this.sectors.active[this._sector()].edges=this.edges,this.sectors.active[this._sector()].nodeIndices=this.nodeIndices},e._switchToSector=function(t,e){void 0===e||"active"==e?this._switchToActiveSector(t):this._switchToFrozenSector(t)},e._switchToActiveSector=function(t){this.nodeIndices=this.sectors.active[t].nodeIndices,this.nodes=this.sectors.active[t].nodes,this.edges=this.sectors.active[t].edges},e._switchToSupportSector=function(){this.nodeIndices=this.sectors.support.nodeIndices,this.nodes=this.sectors.support.nodes,this.edges=this.sectors.support.edges},e._switchToFrozenSector=function(t){this.nodeIndices=this.sectors.frozen[t].nodeIndices,this.nodes=this.sectors.frozen[t].nodes,this.edges=this.sectors.frozen[t].edges},e._loadLatestSector=function(){this._switchToSector(this._sector())},e._sector=function(){return this.activeSector[this.activeSector.length-1]},e._previousSector=function(){if(this.activeSector.length>1)return this.activeSector[this.activeSector.length-2];throw new TypeError("there are not enough sectors in the this.activeSector array.")},e._setActiveSector=function(t){this.activeSector.push(t)},e._forgetLastSector=function(){this.activeSector.pop()},e._createNewSector=function(t){this.sectors.active[t]={nodes:{},edges:{},nodeIndices:[],formationScale:this.scale,drawingNode:void 0},this.sectors.active[t].drawingNode=new o({id:t,color:{background:"#eaefef",border:"495c5e"}},{},{},this.constants),this.sectors.active[t].drawingNode.clusterSize=2},e._deleteActiveSector=function(t){delete this.sectors.active[t]},e._deleteFrozenSector=function(t){delete this.sectors.frozen[t]},e._freezeSector=function(t){this.sectors.frozen[t]=this.sectors.active[t],this._deleteActiveSector(t)},e._activateSector=function(t){this.sectors.active[t]=this.sectors.frozen[t],this._deleteFrozenSector(t)},e._mergeThisWithFrozen=function(t){for(var e in this.nodes)this.nodes.hasOwnProperty(e)&&(this.sectors.frozen[t].nodes[e]=this.nodes[e]);for(var i in this.edges)this.edges.hasOwnProperty(i)&&(this.sectors.frozen[t].edges[i]=this.edges[i]);for(var s=0;s1?this[t](o[0],o[1]):this[t](e))}return this._loadLatestSector(),i},e._doInSupportSector=function(t,e){var i=!1;if(void 0===e)this._switchToSupportSector(),i=this[t]();else{this._switchToSupportSector();var s=Array.prototype.splice.call(arguments,1);i=s.length>1?this[t](s[0],s[1]):this[t](e)}return this._loadLatestSector(),i},e._doInAllFrozenSectors=function(t,e){if(void 0===e)for(var i in this.sectors.frozen)this.sectors.frozen.hasOwnProperty(i)&&(this._switchToFrozenSector(i),this[t]());else for(var i in this.sectors.frozen)if(this.sectors.frozen.hasOwnProperty(i)){this._switchToFrozenSector(i);var s=Array.prototype.splice.call(arguments,1);s.length>1?this[t](s[0],s[1]):this[t](e)}this._loadLatestSector()},e._doInAllSectors=function(t,e){var i=Array.prototype.splice.call(arguments,1);void 0===e?(this._doInAllActiveSectors(t),this._doInAllFrozenSectors(t)):i.length>1?(this._doInAllActiveSectors(t,i[0],i[1]),this._doInAllFrozenSectors(t,i[0],i[1])):(this._doInAllActiveSectors(t,e),this._doInAllFrozenSectors(t,e))},e._clearNodeIndexList=function(){var t=this._sector();this.sectors.active[t].nodeIndices=[],this.nodeIndices=this.sectors.active[t].nodeIndices},e._drawSectorNodes=function(t,e){var i,s=1e9,o=-1e9,n=1e9,r=-1e9;for(var a in this.sectors[e])if(this.sectors[e].hasOwnProperty(a)&&void 0!==this.sectors[e][a].drawingNode){this._switchToSector(a,e),s=1e9,o=-1e9,n=1e9,r=-1e9;for(var h in this.nodes)this.nodes.hasOwnProperty(h)&&(i=this.nodes[h],i.resize(t),n>i.x-.5*i.width&&(n=i.x-.5*i.width),ri.y-.5*i.height&&(s=i.y-.5*i.height),o0?this.nodes[i[i.length-1]]:null},e._getEdgesOverlappingWith=function(t,e){var i=this.edges;for(var s in i)i.hasOwnProperty(s)&&i[s].isOverlappingWith(t)&&e.push(s)},e._getAllEdgesOverlappingWith=function(t){var e=[];return this._doInAllActiveSectors("_getEdgesOverlappingWith",t,e),e},e._getEdgeAt=function(t){var e=this._pointerToPositionObject(t),i=this._getAllEdgesOverlappingWith(e);return i.length>0?this.edges[i[i.length-1]]:null},e._addToSelection=function(t){t instanceof s?this.selectionObj.nodes[t.id]=t:this.selectionObj.edges[t.id]=t},e._addToHover=function(t){t instanceof s?this.hoverObj.nodes[t.id]=t:this.hoverObj.edges[t.id]=t},e._removeFromSelection=function(t){t instanceof s?delete this.selectionObj.nodes[t.id]:delete this.selectionObj.edges[t.id]},e._unselectAll=function(t){void 0===t&&(t=!1);for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&this.selectionObj.nodes[e].unselect();for(var i in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(i)&&this.selectionObj.edges[i].unselect();this.selectionObj={nodes:{},edges:{}},0==t&&this.emit("select",this.getSelection())},e._unselectClusters=function(t){void 0===t&&(t=!1);for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&this.selectionObj.nodes[e].clusterSize>1&&(this.selectionObj.nodes[e].unselect(),this._removeFromSelection(this.selectionObj.nodes[e]));0==t&&this.emit("select",this.getSelection())},e._getSelectedNodeCount=function(){var t=0;for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&(t+=1);return t},e._getSelectedNode=function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t))return this.selectionObj.nodes[t];return null},e._getSelectedEdge=function(){for(var t in this.selectionObj.edges)if(this.selectionObj.edges.hasOwnProperty(t))return this.selectionObj.edges[t];return null},e._getSelectedEdgeCount=function(){var t=0;for(var e in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(e)&&(t+=1);return t},e._getSelectedObjectCount=function(){var t=0;for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&(t+=1);for(var i in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(i)&&(t+=1);return t},e._selectionIsEmpty=function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t))return!1;for(var e in this.selectionObj.edges)if(this.selectionObj.edges.hasOwnProperty(e))return!1;return!0},e._clusterInSelection=function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t)&&this.selectionObj.nodes[t].clusterSize>1)return!0;return!1},e._selectConnectedEdges=function(t){for(var e=0;ei;i++){o=t[i];var n=this.nodes[o];if(!n)throw new RangeError('Node with id "'+o+'" not found');this._selectObject(n,!0,!0,e,!0)}this.redraw()},e.selectEdges=function(t){var e,i,s;if(!t||void 0==t.length)throw"Selection must be an array with ids";for(this._unselectAll(!0),e=0,i=t.length;i>e;e++){s=t[e];var o=this.edges[s];if(!o)throw new RangeError('Edge with id "'+s+'" not found');this._selectObject(o,!0,!0,!1,!0)}this.redraw()},e._updateSelection=function(){for(var t in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(t)&&(this.nodes.hasOwnProperty(t)||delete this.selectionObj.nodes[t]);for(var e in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(e)&&(this.edges.hasOwnProperty(e)||delete this.selectionObj.edges[e])}},function(t,e,i){var s=i(1),o=i(40),n=i(37);e._clearManipulatorBar=function(){this._recursiveDOMDelete(this.manipulationDiv),this.manipulationDOM={},this._manipulationReleaseOverload=function(){},delete this.sectors.support.nodes.targetNode,delete this.sectors.support.nodes.targetViaNode,this.controlNodesActive=!1,this.freezeSimulation=!1},e._restoreOverloadedFunctions=function(){for(var t in this.cachedFunctions)this.cachedFunctions.hasOwnProperty(t)&&(this[t]=this.cachedFunctions[t],delete this.cachedFunctions[t])},e._toggleEditMode=function(){this.editMode=!this.editMode;var t=this.manipulationDiv,e=this.closeDiv,i=this.editModeDiv;1==this.editMode?(t.style.display="block",e.style.display="block",i.style.display="none",e.onclick=this._toggleEditMode.bind(this)):(t.style.display="none",e.style.display="none",i.style.display="block",e.onclick=null),this._createManipulatorBar()},e._createManipulatorBar=function(){this.boundFunction&&this.off("select",this.boundFunction);var t=this.constants.locales[this.constants.locale];if(void 0!==this.edgeBeingEdited&&(this.edgeBeingEdited._disableControlNodes(),this.edgeBeingEdited=void 0,this.selectedControlNode=null,this.controlNodesActive=!1,this._redraw()),this._restoreOverloadedFunctions(),this.freezeSimulation=!1,this.blockConnectingEdgeSelection=!1,this.forceAppendSelection=!1,this.manipulationDOM={},1==this.editMode){for(;this.manipulationDiv.hasChildNodes();)this.manipulationDiv.removeChild(this.manipulationDiv.firstChild);this.manipulationDOM.addNodeSpan=document.createElement("span"),this.manipulationDOM.addNodeSpan.className="network-manipulationUI add",this.manipulationDOM.addNodeLabelSpan=document.createElement("span"),this.manipulationDOM.addNodeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.addNodeLabelSpan.innerHTML=t.addNode,this.manipulationDOM.addNodeSpan.appendChild(this.manipulationDOM.addNodeLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.addEdgeSpan=document.createElement("span"),this.manipulationDOM.addEdgeSpan.className="network-manipulationUI connect",this.manipulationDOM.addEdgeLabelSpan=document.createElement("span"),this.manipulationDOM.addEdgeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.addEdgeLabelSpan.innerHTML=t.addEdge,this.manipulationDOM.addEdgeSpan.appendChild(this.manipulationDOM.addEdgeLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.addNodeSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.addEdgeSpan),1==this._getSelectedNodeCount()&&this.triggerFunctions.edit?(this.manipulationDOM.seperatorLineDiv2=document.createElement("div"),this.manipulationDOM.seperatorLineDiv2.className="network-seperatorLine",this.manipulationDOM.editNodeSpan=document.createElement("span"),this.manipulationDOM.editNodeSpan.className="network-manipulationUI edit",this.manipulationDOM.editNodeLabelSpan=document.createElement("span"),this.manipulationDOM.editNodeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.editNodeLabelSpan.innerHTML=t.editNode,this.manipulationDOM.editNodeSpan.appendChild(this.manipulationDOM.editNodeLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv2),this.manipulationDiv.appendChild(this.manipulationDOM.editNodeSpan)):1==this._getSelectedEdgeCount()&&0==this._getSelectedNodeCount()&&(this.manipulationDOM.seperatorLineDiv3=document.createElement("div"),this.manipulationDOM.seperatorLineDiv3.className="network-seperatorLine",this.manipulationDOM.editEdgeSpan=document.createElement("span"),this.manipulationDOM.editEdgeSpan.className="network-manipulationUI edit",this.manipulationDOM.editEdgeLabelSpan=document.createElement("span"),this.manipulationDOM.editEdgeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.editEdgeLabelSpan.innerHTML=t.editEdge,this.manipulationDOM.editEdgeSpan.appendChild(this.manipulationDOM.editEdgeLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv3),this.manipulationDiv.appendChild(this.manipulationDOM.editEdgeSpan)),0==this._selectionIsEmpty()&&(this.manipulationDOM.seperatorLineDiv4=document.createElement("div"),this.manipulationDOM.seperatorLineDiv4.className="network-seperatorLine",this.manipulationDOM.deleteSpan=document.createElement("span"),this.manipulationDOM.deleteSpan.className="network-manipulationUI delete",this.manipulationDOM.deleteLabelSpan=document.createElement("span"),this.manipulationDOM.deleteLabelSpan.className="network-manipulationLabel",this.manipulationDOM.deleteLabelSpan.innerHTML=t.del,this.manipulationDOM.deleteSpan.appendChild(this.manipulationDOM.deleteLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv4),this.manipulationDiv.appendChild(this.manipulationDOM.deleteSpan)),this.manipulationDOM.addNodeSpan.onclick=this._createAddNodeToolbar.bind(this),this.manipulationDOM.addEdgeSpan.onclick=this._createAddEdgeToolbar.bind(this),1==this._getSelectedNodeCount()&&this.triggerFunctions.edit?this.manipulationDOM.editNodeSpan.onclick=this._editNode.bind(this):1==this._getSelectedEdgeCount()&&0==this._getSelectedNodeCount()&&(this.manipulationDOM.editEdgeSpan.onclick=this._createEditEdgeToolbar.bind(this)),0==this._selectionIsEmpty()&&(this.manipulationDOM.deleteSpan.onclick=this._deleteSelected.bind(this)),this.closeDiv.onclick=this._toggleEditMode.bind(this); +var e=this;this.boundFunction=e._createManipulatorBar,this.on("select",this.boundFunction)}else{for(;this.editModeDiv.hasChildNodes();)this.editModeDiv.removeChild(this.editModeDiv.firstChild);this.manipulationDOM.editModeSpan=document.createElement("span"),this.manipulationDOM.editModeSpan.className="network-manipulationUI edit editmode",this.manipulationDOM.editModeLabelSpan=document.createElement("span"),this.manipulationDOM.editModeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.editModeLabelSpan.innerHTML=t.edit,this.manipulationDOM.editModeSpan.appendChild(this.manipulationDOM.editModeLabelSpan),this.editModeDiv.appendChild(this.manipulationDOM.editModeSpan),this.manipulationDOM.editModeSpan.onclick=this._toggleEditMode.bind(this)}},e._createAddNodeToolbar=function(){this._clearManipulatorBar(),this.boundFunction&&this.off("select",this.boundFunction);var t=this.constants.locales[this.constants.locale];this.manipulationDOM={},this.manipulationDOM.backSpan=document.createElement("span"),this.manipulationDOM.backSpan.className="network-manipulationUI back",this.manipulationDOM.backLabelSpan=document.createElement("span"),this.manipulationDOM.backLabelSpan.className="network-manipulationLabel",this.manipulationDOM.backLabelSpan.innerHTML=t.back,this.manipulationDOM.backSpan.appendChild(this.manipulationDOM.backLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.descriptionSpan=document.createElement("span"),this.manipulationDOM.descriptionSpan.className="network-manipulationUI none",this.manipulationDOM.descriptionLabelSpan=document.createElement("span"),this.manipulationDOM.descriptionLabelSpan.className="network-manipulationLabel",this.manipulationDOM.descriptionLabelSpan.innerHTML=t.addDescription,this.manipulationDOM.descriptionSpan.appendChild(this.manipulationDOM.descriptionLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.backSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.descriptionSpan),this.manipulationDOM.backSpan.onclick=this._createManipulatorBar.bind(this);var e=this;this.boundFunction=e._addNode,this.on("select",this.boundFunction)},e._createAddEdgeToolbar=function(){this._clearManipulatorBar(),this._unselectAll(!0),this.freezeSimulation=!0,this.boundFunction&&this.off("select",this.boundFunction);var t=this.constants.locales[this.constants.locale];this._unselectAll(),this.forceAppendSelection=!1,this.blockConnectingEdgeSelection=!0,this.manipulationDOM={},this.manipulationDOM.backSpan=document.createElement("span"),this.manipulationDOM.backSpan.className="network-manipulationUI back",this.manipulationDOM.backLabelSpan=document.createElement("span"),this.manipulationDOM.backLabelSpan.className="network-manipulationLabel",this.manipulationDOM.backLabelSpan.innerHTML=t.back,this.manipulationDOM.backSpan.appendChild(this.manipulationDOM.backLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.descriptionSpan=document.createElement("span"),this.manipulationDOM.descriptionSpan.className="network-manipulationUI none",this.manipulationDOM.descriptionLabelSpan=document.createElement("span"),this.manipulationDOM.descriptionLabelSpan.className="network-manipulationLabel",this.manipulationDOM.descriptionLabelSpan.innerHTML=t.edgeDescription,this.manipulationDOM.descriptionSpan.appendChild(this.manipulationDOM.descriptionLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.backSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.descriptionSpan),this.manipulationDOM.backSpan.onclick=this._createManipulatorBar.bind(this);var e=this;this.boundFunction=e._handleConnect,this.on("select",this.boundFunction),this.cachedFunctions._handleTouch=this._handleTouch,this.cachedFunctions._manipulationReleaseOverload=this._manipulationReleaseOverload,this.cachedFunctions._handleDragStart=this._handleDragStart,this.cachedFunctions._handleDragEnd=this._handleDragEnd,this._handleTouch=this._handleConnect,this._manipulationReleaseOverload=function(){},this._handleDragStart=function(){},this._handleDragEnd=this._finishConnect,this._redraw()},e._createEditEdgeToolbar=function(){this._clearManipulatorBar(),this.controlNodesActive=!0,this.boundFunction&&this.off("select",this.boundFunction),this.edgeBeingEdited=this._getSelectedEdge(),this.edgeBeingEdited._enableControlNodes();var t=this.constants.locales[this.constants.locale];this.manipulationDOM={},this.manipulationDOM.backSpan=document.createElement("span"),this.manipulationDOM.backSpan.className="network-manipulationUI back",this.manipulationDOM.backLabelSpan=document.createElement("span"),this.manipulationDOM.backLabelSpan.className="network-manipulationLabel",this.manipulationDOM.backLabelSpan.innerHTML=t.back,this.manipulationDOM.backSpan.appendChild(this.manipulationDOM.backLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.descriptionSpan=document.createElement("span"),this.manipulationDOM.descriptionSpan.className="network-manipulationUI none",this.manipulationDOM.descriptionLabelSpan=document.createElement("span"),this.manipulationDOM.descriptionLabelSpan.className="network-manipulationLabel",this.manipulationDOM.descriptionLabelSpan.innerHTML=t.editEdgeDescription,this.manipulationDOM.descriptionSpan.appendChild(this.manipulationDOM.descriptionLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.backSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.descriptionSpan),this.manipulationDOM.backSpan.onclick=this._createManipulatorBar.bind(this),this.cachedFunctions._handleTouch=this._handleTouch,this.cachedFunctions._manipulationReleaseOverload=this._manipulationReleaseOverload,this.cachedFunctions._handleTap=this._handleTap,this.cachedFunctions._handleDragStart=this._handleDragStart,this.cachedFunctions._handleOnDrag=this._handleOnDrag,this._handleTouch=this._selectControlNode,this._handleTap=function(){},this._handleOnDrag=this._controlNodeDrag,this._handleDragStart=function(){},this._manipulationReleaseOverload=this._releaseControlNode,this._redraw()},e._selectControlNode=function(t){this.edgeBeingEdited.controlNodes.from.unselect(),this.edgeBeingEdited.controlNodes.to.unselect(),this.selectedControlNode=this.edgeBeingEdited._getSelectedControlNode(this._XconvertDOMtoCanvas(t.x),this._YconvertDOMtoCanvas(t.y)),null!==this.selectedControlNode&&(this.selectedControlNode.select(),this.freezeSimulation=!0),this._redraw()},e._controlNodeDrag=function(t){var e=this._getPointer(t.center);null!==this.selectedControlNode&&void 0!==this.selectedControlNode&&(this.selectedControlNode.x=this._XconvertDOMtoCanvas(e.x),this.selectedControlNode.y=this._YconvertDOMtoCanvas(e.y)),this._redraw()},e._releaseControlNode=function(t){var e=this._getNodeAt(t);null!==e?(1==this.edgeBeingEdited.controlNodes.from.selected&&(this.edgeBeingEdited._restoreControlNodes(),this._editEdge(e.id,this.edgeBeingEdited.to.id),this.edgeBeingEdited.controlNodes.from.unselect()),1==this.edgeBeingEdited.controlNodes.to.selected&&(this.edgeBeingEdited._restoreControlNodes(),this._editEdge(this.edgeBeingEdited.from.id,e.id),this.edgeBeingEdited.controlNodes.to.unselect())):this.edgeBeingEdited._restoreControlNodes(),this.freezeSimulation=!1,this._redraw()},e._handleConnect=function(t){if(0==this._getSelectedNodeCount()){var e=this._getNodeAt(t);if(null!=e)if(e.clusterSize>1)alert(this.constants.locales[this.constants.locale].createEdgeError);else{this._selectObject(e,!1);var i=this.sectors.support.nodes;i.targetNode=new o({id:"targetNode"},{},{},this.constants);var s=i.targetNode;s.x=e.x,s.y=e.y,this.edges.connectionEdge=new n({id:"connectionEdge",from:e.id,to:s.id},this,this.constants);var r=this.edges.connectionEdge;r.from=e,r.connected=!0,r.options.smoothCurves={enabled:!0,dynamic:!1,type:"continuous",roundness:.5},r.selected=!0,r.to=s,this.cachedFunctions._handleOnDrag=this._handleOnDrag,this._handleOnDrag=function(t){var e=this._getPointer(t.center),i=this.edges.connectionEdge;i.to.x=this._XconvertDOMtoCanvas(e.x),i.to.y=this._YconvertDOMtoCanvas(e.y)},this.moving=!0,this.start()}}},e._finishConnect=function(t){if(1==this._getSelectedNodeCount()){var e=this._getPointer(t.center);this._handleOnDrag=this.cachedFunctions._handleOnDrag,delete this.cachedFunctions._handleOnDrag;var i=this.edges.connectionEdge.fromId;delete this.edges.connectionEdge,delete this.sectors.support.nodes.targetNode,delete this.sectors.support.nodes.targetViaNode;var s=this._getNodeAt(e);null!=s&&(s.clusterSize>1?alert(this.constants.locales[this.constants.locale].createEdgeError):(this._createEdge(i,s.id),this._createManipulatorBar())),this._unselectAll()}},e._addNode=function(){if(this._selectionIsEmpty()&&1==this.editMode){var t=this._pointerToPositionObject(this.pointerPosition),e={id:s.randomUUID(),x:t.left,y:t.top,label:"new",allowedToMoveX:!0,allowedToMoveY:!0};if(this.triggerFunctions.add){if(2!=this.triggerFunctions.add.length)throw new Error("The function for add does not support two arguments (data,callback)");var i=this;this.triggerFunctions.add(e,function(t){i.nodesData.add(t),i._createManipulatorBar(),i.moving=!0,i.start()})}else this.nodesData.add(e),this._createManipulatorBar(),this.moving=!0,this.start()}},e._createEdge=function(t,e){if(1==this.editMode){var i={from:t,to:e};if(this.triggerFunctions.connect){if(2!=this.triggerFunctions.connect.length)throw new Error("The function for connect does not support two arguments (data,callback)");var s=this;this.triggerFunctions.connect(i,function(t){s.edgesData.add(t),s.moving=!0,s.start()})}else this.edgesData.add(i),this.moving=!0,this.start()}},e._editEdge=function(t,e){if(1==this.editMode){var i={id:this.edgeBeingEdited.id,from:t,to:e};if(this.triggerFunctions.editEdge){if(2!=this.triggerFunctions.editEdge.length)throw new Error("The function for edit does not support two arguments (data, callback)");var s=this;this.triggerFunctions.editEdge(i,function(t){s.edgesData.update(t),s.moving=!0,s.start()})}else this.edgesData.update(i),this.moving=!0,this.start()}},e._editNode=function(){if(!this.triggerFunctions.edit||1!=this.editMode)throw new Error("No edit function has been bound to this button");var t=this._getSelectedNode(),e={id:t.id,label:t.label,group:t.options.group,shape:t.options.shape,color:{background:t.options.color.background,border:t.options.color.border,highlight:{background:t.options.color.highlight.background,border:t.options.color.highlight.border}}};if(2!=this.triggerFunctions.edit.length)throw new Error("The function for edit does not support two arguments (data, callback)");var i=this;this.triggerFunctions.edit(e,function(t){i.nodesData.update(t),i._createManipulatorBar(),i.moving=!0,i.start()})},e._deleteSelected=function(){if(!this._selectionIsEmpty()&&1==this.editMode)if(this._clusterInSelection())alert(this.constants.locales[this.constants.locale].deleteClusterError);else{var t=this.getSelectedNodes(),e=this.getSelectedEdges();if(this.triggerFunctions.del){var i=this,s={nodes:t,edges:e};if(2!=this.triggerFunctions.del.length)throw new Error("The function for delete does not support two arguments (data, callback)");this.triggerFunctions.del(s,function(t){i.edgesData.remove(t.edges),i.nodesData.remove(t.nodes),i._unselectAll(),i.moving=!0,i.start()})}else this.edgesData.remove(e),this.nodesData.remove(t),this._unselectAll(),this.moving=!0,this.start()}}},function(t,e,i){var s=(i(1),i(47)),o=i(45);e._cleanNavigation=function(){if(0!=this.navigationHammers.existing.length){for(var t=0;t0){var t,e,i=0,s=!1,o=!1;for(e in this.nodes)this.nodes.hasOwnProperty(e)&&(t=this.nodes[e],-1!=t.level?s=!0:o=!0,is&&(n.xFixed=!1,n.x=i[n.level].minPos,r=!0):n.yFixed&&n.level>s&&(n.yFixed=!1,n.y=i[n.level].minPos,r=!0),1==r&&(i[n.level].minPos+=i[n.level].nodeSpacing,n.edges.length>1&&this._placeBranchNodes(n.edges,n.id,i,n.level))}},e._setLevel=function(t,e,i){for(var s=0;st)&&(o.level=t,o.edges.length>1&&this._setLevel(t+1,o.edges,o.id))}},e._setLevelDirected=function(t,e,i){this.nodes[i].hierarchyEnumerated=!0;for(var s,o,n=0;n1&&s.hierarchyEnumerated===!1&&this._setLevelDirected(s.level,s.edges,s.id)},e._restoreNodes=function(){for(var t in this.nodes)this.nodes.hasOwnProperty(t)&&(this.nodes[t].xFixed=!1,this.nodes[t].yFixed=!1)}},function(t,e,i){function s(){this.constants.smoothCurves.enabled=!this.constants.smoothCurves.enabled;var t=document.getElementById("graph_toggleSmooth");t.style.background=1==this.constants.smoothCurves.enabled?"#A4FF56":"#FF8532",this._configureSmoothCurves(!1)}function o(){for(var t in this.calculationNodes)this.calculationNodes.hasOwnProperty(t)&&(this.calculationNodes[t].vx=0,this.calculationNodes[t].vy=0,this.calculationNodes[t].fx=0,this.calculationNodes[t].fy=0);1==this.constants.hierarchicalLayout.enabled?(this._setupHierarchicalLayout(),a.call(this,"graph_H_nd",1,"physics_hierarchicalRepulsion_nodeDistance"),a.call(this,"graph_H_cg",1,"physics_centralGravity"),a.call(this,"graph_H_sc",1,"physics_springConstant"),a.call(this,"graph_H_sl",1,"physics_springLength"),a.call(this,"graph_H_damp",1,"physics_damping")):this.repositionNodes(),this.moving=!0,this.start()}function n(){var t="No options are required, default values used.",e=[],i=document.getElementById("graph_physicsMethod1"),s=document.getElementById("graph_physicsMethod2");if(1==i.checked){if(this.constants.physics.barnesHut.gravitationalConstant!=this.backupConstants.physics.barnesHut.gravitationalConstant&&e.push("gravitationalConstant: "+this.constants.physics.barnesHut.gravitationalConstant),this.constants.physics.centralGravity!=this.backupConstants.physics.barnesHut.centralGravity&&e.push("centralGravity: "+this.constants.physics.centralGravity),this.constants.physics.springLength!=this.backupConstants.physics.barnesHut.springLength&&e.push("springLength: "+this.constants.physics.springLength),this.constants.physics.springConstant!=this.backupConstants.physics.barnesHut.springConstant&&e.push("springConstant: "+this.constants.physics.springConstant),this.constants.physics.damping!=this.backupConstants.physics.barnesHut.damping&&e.push("damping: "+this.constants.physics.damping),0!=e.length){t="var options = {",t+="physics: {barnesHut: {";for(var o=0;othis.constants.clustering.clusterThreshold&&1==this.constants.clustering.enabled&&this.clusterToFit(this.constants.clustering.reduceToNodes,!1),this._calculateForces())},e._calculateForces=function(){this._calculateGravitationalForces(),this._calculateNodeForces(),this.constants.physics.springConstant>0&&(1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic?this._calculateSpringForcesWithSupport():1==this.constants.physics.hierarchicalRepulsion.enabled?this._calculateHierarchicalSpringForces():this._calculateSpringForces())},e._updateCalculationNodes=function(){if(1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic){this.calculationNodes={},this.calculationNodeIndices=[];for(var t in this.nodes)this.nodes.hasOwnProperty(t)&&(this.calculationNodes[t]=this.nodes[t]);var e=this.sectors.support.nodes;for(var i in e)e.hasOwnProperty(i)&&(this.edges.hasOwnProperty(e[i].parentEdgeId)?this.calculationNodes[i]=e[i]:e[i]._setForce(0,0));for(var s in this.calculationNodes)this.calculationNodes.hasOwnProperty(s)&&this.calculationNodeIndices.push(s)}else this.calculationNodes=this.nodes,this.calculationNodeIndices=this.nodeIndices},e._calculateGravitationalForces=function(){var t,e,i,s,o,n=this.calculationNodes,r=this.constants.physics.centralGravity,a=0;for(o=0;oSimulation Mode:Barnes HutRepulsionHierarchical
Options:
',this.containerElement.parentElement.insertBefore(this.physicsConfiguration,this.containerElement),this.optionsDiv=document.createElement("div"),this.optionsDiv.style.fontSize="14px",this.optionsDiv.style.fontFamily="verdana",this.containerElement.parentElement.insertBefore(this.optionsDiv,this.containerElement); +var e;e=document.getElementById("graph_BH_gc"),e.onchange=a.bind(this,"graph_BH_gc",-1,"physics_barnesHut_gravitationalConstant"),e=document.getElementById("graph_BH_cg"),e.onchange=a.bind(this,"graph_BH_cg",1,"physics_centralGravity"),e=document.getElementById("graph_BH_sc"),e.onchange=a.bind(this,"graph_BH_sc",1,"physics_springConstant"),e=document.getElementById("graph_BH_sl"),e.onchange=a.bind(this,"graph_BH_sl",1,"physics_springLength"),e=document.getElementById("graph_BH_damp"),e.onchange=a.bind(this,"graph_BH_damp",1,"physics_damping"),e=document.getElementById("graph_R_nd"),e.onchange=a.bind(this,"graph_R_nd",1,"physics_repulsion_nodeDistance"),e=document.getElementById("graph_R_cg"),e.onchange=a.bind(this,"graph_R_cg",1,"physics_centralGravity"),e=document.getElementById("graph_R_sc"),e.onchange=a.bind(this,"graph_R_sc",1,"physics_springConstant"),e=document.getElementById("graph_R_sl"),e.onchange=a.bind(this,"graph_R_sl",1,"physics_springLength"),e=document.getElementById("graph_R_damp"),e.onchange=a.bind(this,"graph_R_damp",1,"physics_damping"),e=document.getElementById("graph_H_nd"),e.onchange=a.bind(this,"graph_H_nd",1,"physics_hierarchicalRepulsion_nodeDistance"),e=document.getElementById("graph_H_cg"),e.onchange=a.bind(this,"graph_H_cg",1,"physics_centralGravity"),e=document.getElementById("graph_H_sc"),e.onchange=a.bind(this,"graph_H_sc",1,"physics_springConstant"),e=document.getElementById("graph_H_sl"),e.onchange=a.bind(this,"graph_H_sl",1,"physics_springLength"),e=document.getElementById("graph_H_damp"),e.onchange=a.bind(this,"graph_H_damp",1,"physics_damping"),e=document.getElementById("graph_H_direction"),e.onchange=a.bind(this,"graph_H_direction",t,"hierarchicalLayout_direction"),e=document.getElementById("graph_H_levsep"),e.onchange=a.bind(this,"graph_H_levsep",1,"hierarchicalLayout_levelSeparation"),e=document.getElementById("graph_H_nspac"),e.onchange=a.bind(this,"graph_H_nspac",1,"hierarchicalLayout_nodeSpacing");var i=document.getElementById("graph_physicsMethod1"),d=document.getElementById("graph_physicsMethod2"),l=document.getElementById("graph_physicsMethod3");d.checked=!0,this.constants.physics.barnesHut.enabled&&(i.checked=!0),this.constants.hierarchicalLayout.enabled&&(l.checked=!0);var c=document.getElementById("graph_toggleSmooth"),p=document.getElementById("graph_repositionNodes"),u=document.getElementById("graph_generateOptions");c.onclick=s.bind(this),p.onclick=o.bind(this),u.onclick=n.bind(this),c.style.background=1==this.constants.smoothCurves&&0==this.constants.dynamicSmoothCurves?"#A4FF56":"#FF8532",r.apply(this),i.onchange=r.bind(this),d.onchange=r.bind(this),l.onchange=r.bind(this)}},e._overWriteGraphConstants=function(t,e){var i=t.split("_");1==i.length?this.constants[i[0]]=e:2==i.length?this.constants[i[0]][i[1]]=e:3==i.length&&(this.constants[i[0]][i[1]][i[2]]=e)}},function(t){function e(t){throw new Error("Cannot find module '"+t+"'.")}e.keys=function(){return[]},e.resolve=e,t.exports=e,e.id=68},function(t,e){e._calculateNodeForces=function(){var t,e,i,s,o,n,r,a,h,d,l,c=this.calculationNodes,p=this.calculationNodeIndices,u=-2/3,m=4/3,f=this.constants.physics.repulsion.nodeDistance,g=f;for(d=0;di&&(r=.5*g>i?1:v*i+m,r*=0==n?1:1+n*this.constants.clustering.forceAmplification,r/=Math.max(i,.01*g),s=t*r,o=e*r,a.fx-=s,a.fy-=o,h.fx+=s,h.fy+=o)}}},function(t,e){e._calculateNodeForces=function(){var t,e,i,s,o,n,r,a,h,d,l=this.calculationNodes,c=this.calculationNodeIndices,p=this.constants.physics.hierarchicalRepulsion.nodeDistance;for(h=0;hi?-Math.pow(u*i,2)+Math.pow(u*p,2):0,0==i?i=.01:n/=i,s=t*n,o=e*n,r.fx-=s,r.fy-=o,a.fx+=s,a.fy+=o}},e._calculateHierarchicalSpringForces=function(){for(var t,e,i,s,o,n,r,a,h,d=this.edges,l=this.calculationNodes,c=this.calculationNodeIndices,p=0;pn;n++)t=e[i[n]],t.options.mass>0&&(this._getForceContribution(o.root.children.NW,t),this._getForceContribution(o.root.children.NE,t),this._getForceContribution(o.root.children.SW,t),this._getForceContribution(o.root.children.SE,t))}},e._getForceContribution=function(t,e){if(t.childrenCount>0){var i,s,o;if(i=t.centerOfMass.x-e.x,s=t.centerOfMass.y-e.y,o=Math.sqrt(i*i+s*s),o*t.calcSize>this.constants.physics.barnesHut.thetaInverted){0==o&&(o=.1*Math.random(),i=o);var n=this.constants.physics.barnesHut.gravitationalConstant*t.mass*e.options.mass/(o*o*o),r=i*n,a=s*n;e.fx+=r,e.fy+=a}else if(4==t.childrenCount)this._getForceContribution(t.children.NW,e),this._getForceContribution(t.children.NE,e),this._getForceContribution(t.children.SW,e),this._getForceContribution(t.children.SE,e);else if(t.children.data.id!=e.id){0==o&&(o=.5*Math.random(),i=o);var n=this.constants.physics.barnesHut.gravitationalConstant*t.mass*e.options.mass/(o*o*o),r=i*n,a=s*n;e.fx+=r,e.fy+=a}}},e._formBarnesHutTree=function(t,e){for(var i,s=e.length,o=Number.MAX_VALUE,n=Number.MAX_VALUE,r=-Number.MAX_VALUE,a=-Number.MAX_VALUE,h=0;s>h;h++){var d=t[e[h]].x,l=t[e[h]].y;t[e[h]].options.mass>0&&(o>d&&(o=d),d>r&&(r=d),n>l&&(n=l),l>a&&(a=l))}var c=Math.abs(r-o)-Math.abs(a-n);c>0?(n-=.5*c,a+=.5*c):(o+=.5*c,r-=.5*c);var p=1e-5,u=Math.max(p,Math.abs(r-o)),m=.5*u,f=.5*(o+r),g=.5*(n+a),v={root:{centerOfMass:{x:0,y:0},mass:0,range:{minX:f-m,maxX:f+m,minY:g-m,maxY:g+m},size:u,calcSize:1/u,children:{data:null},maxWidth:0,level:0,childrenCount:4}};for(this._splitBranch(v.root),h=0;s>h;h++)i=t[e[h]],i.options.mass>0&&this._placeInTree(v.root,i);this.barnesHutTree=v},e._updateBranchMass=function(t,e){var i=t.mass+e.options.mass,s=1/i;t.centerOfMass.x=t.centerOfMass.x*t.mass+e.x*e.options.mass,t.centerOfMass.x*=s,t.centerOfMass.y=t.centerOfMass.y*t.mass+e.y*e.options.mass,t.centerOfMass.y*=s,t.mass=i;var o=Math.max(Math.max(e.height,e.radius),e.width);t.maxWidth=t.maxWidthe.x?t.children.NW.range.maxY>e.y?this._placeInRegion(t,e,"NW"):this._placeInRegion(t,e,"SW"):t.children.NW.range.maxY>e.y?this._placeInRegion(t,e,"NE"):this._placeInRegion(t,e,"SE")},e._placeInRegion=function(t,e,i){switch(t.children[i].childrenCount){case 0:t.children[i].children.data=e,t.children[i].childrenCount=1,this._updateBranchMass(t.children[i],e);break;case 1:t.children[i].children.data.x==e.x&&t.children[i].children.data.y==e.y?(e.x+=Math.random(),e.y+=Math.random()):(this._splitBranch(t.children[i]),this._placeInTree(t.children[i],e));break;case 4:this._placeInTree(t.children[i],e)}},e._splitBranch=function(t){var e=null;1==t.childrenCount&&(e=t.children.data,t.mass=0,t.centerOfMass.x=0,t.centerOfMass.y=0),t.childrenCount=4,t.children.data=null,this._insertRegion(t,"NW"),this._insertRegion(t,"NE"),this._insertRegion(t,"SW"),this._insertRegion(t,"SE"),null!=e&&this._placeInTree(t,e)},e._insertRegion=function(t,e){var i,s,o,n,r=.5*t.size;switch(e){case"NW":i=t.range.minX,s=t.range.minX+r,o=t.range.minY,n=t.range.minY+r;break;case"NE":i=t.range.minX+r,s=t.range.maxX,o=t.range.minY,n=t.range.minY+r;break;case"SW":i=t.range.minX,s=t.range.minX+r,o=t.range.minY+r,n=t.range.maxY;break;case"SE":i=t.range.minX+r,s=t.range.maxX,o=t.range.minY+r,n=t.range.maxY}t.children[e]={centerOfMass:{x:0,y:0},mass:0,range:{minX:i,maxX:s,minY:o,maxY:n},size:.5*t.size,calcSize:2*t.calcSize,children:{data:null},maxWidth:0,level:t.level+1,childrenCount:0}},e._drawTree=function(t,e){void 0!==this.barnesHutTree&&(t.lineWidth=1,this._drawBranch(this.barnesHutTree.root,t,e))},e._drawBranch=function(t,e,i){void 0===i&&(i="#FF0000"),4==t.childrenCount&&(this._drawBranch(t.children.NW,e),this._drawBranch(t.children.NE,e),this._drawBranch(t.children.SE,e),this._drawBranch(t.children.SW,e)),e.strokeStyle=i,e.beginPath(),e.moveTo(t.range.minX,t.range.minY),e.lineTo(t.range.maxX,t.range.minY),e.stroke(),e.beginPath(),e.moveTo(t.range.maxX,t.range.minY),e.lineTo(t.range.maxX,t.range.maxY),e.stroke(),e.beginPath(),e.moveTo(t.range.maxX,t.range.maxY),e.lineTo(t.range.minX,t.range.maxY),e.stroke(),e.beginPath(),e.moveTo(t.range.minX,t.range.maxY),e.lineTo(t.range.minX,t.range.minY),e.stroke()}},function(t){t.exports=function(t){return t.webpackPolyfill||(t.deprecate=function(){},t.paths=[],t.children=[],t.webpackPolyfill=1),t}},function(t,e){(function(e){t.exports=e}).call(e,{})}])}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f23ba9dba167..04f3070d25b4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -85,17 +85,13 @@ table.sortable td { filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0); } -span.kill-link { +a.kill-link { margin-right: 2px; margin-left: 20px; color: gray; float: right; } -span.kill-link a { - color: gray; -} - span.expand-details { font-size: 10pt; cursor: pointer; @@ -103,15 +99,25 @@ span.expand-details { float: right; } +span.rest-uri { + font-size: 10pt; + font-style: italic; + color: gray; +} + pre { - font-size: 0.8em; + font-size: 12px; + line-height: 18px; + padding: 6px; + margin: 0; + border-radius: 3px; } .stage-details { max-height: 100px; overflow-y: auto; margin: 0; - transition: max-height 0.5s ease-out, padding 0.5s ease-out; + transition: max-height 0.25s ease-out, padding 0.25s ease-out; } .stage-details.collapsed { @@ -129,11 +135,19 @@ pre { display: block; } +.description-input-full { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: normal; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; margin: 0; - transition: max-height 0.5s ease-out, padding 0.5s ease-out; + transition: max-height 0.25s ease-out, padding 0.25s ease-out; } .stacktrace-details.collapsed { @@ -143,7 +157,7 @@ pre { border: none; } -span.expand-additional-metrics { +span.expand-additional-metrics, span.expand-dag-viz { cursor: pointer; } @@ -156,7 +170,7 @@ span.additional-metric-title { } .tooltip { - font-weight: normal; + font-weight: normal; } .arrow-open { @@ -164,9 +178,9 @@ span.additional-metric-title { height: 0; border-left: 5px solid transparent; border-right: 5px solid transparent; - border-top: 5px solid black; - float: left; - margin-top: 6px; + border-top: 5px solid #08c; + display: inline-block; + margin-bottom: 2px; } .arrow-closed { @@ -174,8 +188,10 @@ span.additional-metric-title { height: 0; border-top: 5px solid transparent; border-bottom: 5px solid transparent; - border-left: 5px solid black; + border-left: 5px solid #08c; display: inline-block; + margin-left: 2px; + margin-right: 3px; } .version { @@ -190,7 +206,29 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ -.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, -.getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, +.serialization_time, .getting_result_time, .peak_execution_memory { display: none; } + +.accordion-inner { + background: #f5f5f5; +} + +.accordion-inner pre { + border: 0; + padding: 0; + background: none; +} + +a.expandbutton { + cursor: pointer; +} + +.executor-thread { + background: #E6E6E6; +} + +.non-executor-thread { + background: #FAFAFA; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f31bfba3f8d..c39c8667d013 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -18,11 +18,11 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} -import java.util.concurrent.atomic.AtomicLong -import java.lang.ThreadLocal import scala.collection.generic.Growable -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable +import scala.ref.WeakReference import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -40,25 +40,44 @@ import org.apache.spark.util.Utils * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` * @param name human-readable name for use in Spark's web UI + * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported + * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be + * thread safe so that they can be reported correctly. * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] ( +class Accumulable[R, T] private[spark] ( @transient initialValue: R, param: AccumulableParam[R, T], - val name: Option[String]) + val name: Option[String], + internal: Boolean) extends Serializable { + private[spark] def this( + @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { + this(initialValue, param, None, internal) + } + + def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false) + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) val id: Long = Accumulators.newId - @transient private var value_ = initialValue // Current value on master + @volatile @transient private var value_ : R = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false - Accumulators.register(this, true) + Accumulators.register(this) + + /** + * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver + * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be + * reported correctly. + */ + private[spark] def isInternal: Boolean = internal /** * Add more data to this accumulator / accumulable @@ -108,7 +127,7 @@ class Accumulable[R, T] ( * The typical use of this method is to directly mutate the local value, eg., to add * an element to a Set. */ - def localValue = value_ + def localValue: R = value_ /** * Set the accumulator's value; only allowed on master. @@ -133,10 +152,18 @@ class Accumulable[R, T] ( in.defaultReadObject() value_ = zero deserialized = true - Accumulators.register(this, false) + // Automatically register the accumulator when it is deserialized with the task closure. + // + // Note internal accumulators sent with task are deserialized before the TaskContext is created + // and are registered in the TaskContext constructor. Other internal accumulators, such SQL + // metrics, still need to register here. + val taskContext = TaskContext.get() + if (taskContext != null) { + taskContext.registerAccumulator(this) + } } - override def toString = if (value_ == null) "null" else value_.toString + override def toString: String = if (value_ == null) "null" else value_.toString } /** @@ -228,10 +255,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T,T](initialValue, param, name) { +class Accumulator[T] private[spark] ( + @transient private[spark] val initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { + + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } - def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } } /** @@ -256,22 +293,22 @@ object AccumulatorParam { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } implicit object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // TODO: Add AccumulatorParams for other types, e.g. lists and strings @@ -279,52 +316,84 @@ object AccumulatorParam { // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right -private[spark] object Accumulators { - // TODO: Use soft references? => need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() - } - var lastId: Long = 0 +private[spark] object Accumulators extends Logging { + /** + * This global map holds the original accumulator objects that are created on the driver. + * It keeps weak references to these objects so that accumulators can be garbage-collected + * once the RDDs and user-code that reference them are cleaned up. + */ + val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() + + private var lastId: Long = 0 def newId(): Long = synchronized { lastId += 1 lastId } - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = a - } else { - localAccums.get()(a.id) = a - } + def register(a: Accumulable[_, _]): Unit = synchronized { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) } - // Clear the local (non-original) accumulators for the current thread - def clear() { + def remove(accId: Long) { synchronized { - localAccums.get.clear + originals.remove(accId) } } - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue - } - return ret - } - // Add values to the original accumulators with some given IDs def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value + case None => + throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") + } + } else { + logWarning(s"Ignoring accumulator update for unknown accumulator id $id") } } } - def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) - def stringifyValue(value: Any) = "%s".format(value) +} + +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. + private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } + + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. We do this to report the distribution of accumulator + * values across all tasks within each stage. + */ + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators + } } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 3b684bbeceaf..289aab9bd9e5 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") @@ -45,7 +45,7 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kv: Product2[K, V] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) @@ -58,12 +58,7 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } @@ -76,7 +71,7 @@ case class Aggregator[K, V, C] ( : Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kc: Product2[K, C] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 @@ -88,17 +83,19 @@ case class Aggregator[K, V, C] ( combiners.iterator } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - while (iter.hasNext) { - val pair = iter.next() - combiners.insert(pair._1, pair._2) - } - // Update task metrics if context is not null - // TODO: Make context non-optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + combiners.insertAll(iter) + updateMetrics(context, combiners) combiners.iterator } } + + /** Update task metrics after populating the external map. */ + private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + } + } } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index a0c0372b7f0e..4d20c7369376 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,13 +44,17 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - val inputMetrics = blockResult.inputMetrics val existingMetrics = context.taskMetrics - .getInputMetricsForReadMethod(inputMetrics.readMethod) - existingMetrics.addBytesRead(inputMetrics.bytesRead) - - new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) - + .getInputMetricsForReadMethod(blockResult.readMethod) + existingMetrics.incBytesRead(blockResult.bytes) + + val iter = blockResult.data.asInstanceOf[Iterator[T]] + new InterruptibleIterator[T](context, iter) { + override def next(): T = { + existingMetrics.incRecordsRead(1) + delegate.next() + } + } case None => // Acquire a lock for loading this partition // If another thread already holds the lock, wait for it to finish return its results diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4fcc..d23c1533db75 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.util.Utils /** @@ -32,6 +32,8 @@ private sealed trait CleanupTask private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private case class CleanAccum(accId: Long) extends CleanupTask +private case class CleanCheckpoint(rddId: Int) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -93,68 +95,95 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { @volatile private var stopped = false /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { + def attachListener(listener: CleanerListener): Unit = { listeners += listener } /** Start the cleaner. */ - def start() { + def start(): Unit = { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() } - /** Stop the cleaner. */ - def stop() { + /** + * Stop the cleaning thread and wait until the thread has finished running its current task. + */ + def stop(): Unit = { stopped = true + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkContext, + // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). + synchronized { + cleaningThread.interrupt() + } + cleaningThread.join() } /** Register a RDD for cleanup when it is garbage collected. */ - def registerRDDForCleanup(rdd: RDD[_]) { + def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } + def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + registerForCleanup(a, CleanAccum(a.id)) + } + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]): Unit = { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } /** Register a Broadcast for cleanup when it is garbage collected. */ - def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = { registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } + /** Register a RDDCheckpointData for cleanup when it is garbage collected. */ + def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = { + registerForCleanup(rdd, CleanCheckpoint(parentId)) + } + /** Register an object for cleanup. */ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } /** Keep cleaning RDD, shuffle, and broadcast state. */ - private def keepCleaning(): Unit = Utils.logUncaughtExceptions { + private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) { while (!stopped) { try { val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) .map(_.asInstanceOf[CleanupTaskWeakReference]) - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get - task match { - case CleanRDD(rddId) => - doCleanupRDD(rddId, blocking = blockOnCleanupTasks) - case CleanShuffle(shuffleId) => - doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) - case CleanBroadcast(broadcastId) => - doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) + case CleanCheckpoint(rddId) => + doCleanCheckpoint(rddId) + } } } } catch { + case ie: InterruptedException if stopped => // ignore case e: Exception => logError("Error in cleaning thread", e) } } } /** Perform RDD cleanup. */ - def doCleanupRDD(rddId: Int, blocking: Boolean) { + def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) @@ -166,7 +195,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform shuffle cleanup, asynchronously. */ - def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) @@ -179,17 +208,45 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform broadcast cleanup. */ - def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean): Unit = { try { - logDebug("Cleaning broadcast " + broadcastId) + logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) listeners.foreach(_.broadcastCleaned(broadcastId)) - logInfo("Cleaned broadcast " + broadcastId) + logDebug(s"Cleaned broadcast $broadcastId") } catch { case e: Exception => logError("Error cleaning broadcast " + broadcastId, e) } } + /** Perform accumulator cleanup. */ + def doCleanupAccum(accId: Long, blocking: Boolean): Unit = { + try { + logDebug("Cleaning accumulator " + accId) + Accumulators.remove(accId) + listeners.foreach(_.accumCleaned(accId)) + logInfo("Cleaned accumulator " + accId) + } catch { + case e: Exception => logError("Error cleaning accumulator " + accId, e) + } + } + + /** + * Clean up checkpoint files written to a reliable storage. + * Locally checkpointed files are cleaned up separately through RDD cleanups. + */ + def doCleanCheckpoint(rddId: Int): Unit = { + try { + logDebug("Cleaning rdd checkpoint data " + rddId) + ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) + listeners.foreach(_.checkpointCleaned(rddId)) + logInfo("Cleaned rdd checkpoint data " + rddId) + } + catch { + case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -206,4 +263,6 @@ private[spark] trait CleanerListener { def rddCleaned(rddId: Int) def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) + def accumCleaned(accId: Long) + def checkpointCleaned(rddId: Long) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 9a7cd4523e5a..fc8cdde9348e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -74,7 +74,7 @@ class ShuffleDependency[K, V, C]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { - override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] + override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] val shuffleId: Int = _rdd.context.newShuffleId() @@ -91,7 +91,7 @@ class ShuffleDependency[K, V, C]( */ @DeveloperApi class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) + override def getParents(partitionId: Int): List[Int] = List(partitionId) } @@ -107,7 +107,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) } else { diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index a46a81eabd96..842bfdbadc94 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -19,24 +19,44 @@ package org.apache.spark /** * A client that communicates with the cluster manager to request or kill executors. + * This is currently supported only in YARN mode. */ private[spark] trait ExecutorAllocationClient { + /** + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean + /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def requestExecutors(numAdditionalExecutors: Int): Boolean /** * Request that the cluster manager kill the specified executors. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutors(executorIds: Seq[String]): Boolean /** * Request that the cluster manager kill the specified executor. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 5d5288bb6e60..b93536e6536e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,20 +17,34 @@ package org.apache.spark +import java.util.concurrent.TimeUnit + import scala.collection.mutable +import scala.util.control.ControlThrowable + +import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.scheduler._ +import org.apache.spark.metrics.source.Source +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. * - * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If - * the scheduler queue is not drained in N seconds, then new executors are added. If the queue - * persists for another M seconds, then more executors are added and so on. The number added - * in each round increases exponentially from the previous round until an upper bound on the - * number of executors has been reached. The upper bound is based both on a configured property - * and on the number of tasks pending: the policy will never increase the number of executor - * requests past the number needed to handle all pending tasks. + * The ExecutorAllocationManager maintains a moving target number of executors which is periodically + * synced to the cluster manager. The target starts at a configured initial value and changes with + * the number of pending and running tasks. + * + * Decreasing the target number of executors happens when the current target is more than needed to + * handle the current load. The target number of executors is always truncated to the number of + * executors that could run all current running and pending tasks at once. + * + * Increasing the target number of executors happens in response to backlogged tasks waiting to be + * scheduled. If the scheduler queue is not drained in N seconds, then new executors are added. If + * the queue persists for another M seconds, then more executors are added and so on. The number + * added in each round increases exponentially from the previous round until an upper bound has been + * reached. The upper bound is based both on a configured property and on the current number of + * running and pending tasks, as described above. * * The rationale for the exponential increase is twofold: (1) Executors should be added slowly * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, @@ -76,17 +90,20 @@ private[spark] class ExecutorAllocationManager( private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Integer.MAX_VALUE) - // How long there must be backlogged tasks for before an addition is triggered - private val schedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.schedulerBacklogTimeout", 60) + // How long there must be backlogged tasks for before an addition is triggered (seconds) + private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.schedulerBacklogTimeout", "1s") + + // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded + private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", s"${schedulerBacklogTimeoutS}s") - // Same as above, but used only after `schedulerBacklogTimeout` is exceeded - private val sustainedSchedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) + // How long an executor must be idle for before it is removed (seconds) + private val executorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.executorIdleTimeout", "60s") - // How long an executor must be idle for before it is removed - private val executorIdleTimeout = conf.getLong( - "spark.dynamicAllocation.executorIdleTimeout", 600) + private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -102,8 +119,10 @@ private[spark] class ExecutorAllocationManager( // Number of executors to add in the next round private var numExecutorsToAdd = 1 - // Number of executors that have been requested but have not registered yet - private var numExecutorsPending = 0 + // The desired number of executors at this moment in time. If all our executors were to die, this + // is the number of executors we would immediately want from the cluster manager. + private var numExecutorsTarget = + conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -123,11 +142,31 @@ private[spark] class ExecutorAllocationManager( private val intervalMillis: Long = 100 // Clock used to schedule when executors should be added and removed - private var clock: Clock = new RealClock + private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-dynamic-executor-allocation") + + // Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem. + val executorAllocationManagerSource = new ExecutorAllocationManagerSource + + // Whether we are still waiting for the initial set of executors to be allocated. + // While this is true, we will not cancel outstanding executor requests. This is + // set to false when: + // (1) a stage is submitted, or + // (2) an executor idle timeout has elapsed. + @volatile private var initializing: Boolean = true + + // Number of locality aware tasks, used for executor placement. + private var localityAwareTasks = 0 + + // Host to possible task running on it, used for executor placement. + private var hostToLocalTaskCount: Map[String, Int] = Map.empty + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -143,14 +182,14 @@ private[spark] class ExecutorAllocationManager( throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } - if (schedulerBacklogTimeout <= 0) { + if (schedulerBacklogTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") } - if (sustainedSchedulerBacklogTimeout <= 0) { + if (sustainedSchedulerBacklogTimeoutS <= 0) { throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeout <= 0) { + if (executorIdleTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") } // Require external shuffle service for dynamic allocation @@ -172,106 +211,165 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() - } - /** - * Start the main polling thread that keeps track of when to add and remove executors. - */ - private def startPolling(): Unit = { - val t = new Thread { + val scheduleTask = new Runnable() { override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) } } } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) + } + + /** + * Stop the allocation manager. + */ + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) } /** - * If the add time has expired, request new executors and refresh the add time. - * If the remove time for an existing executor has expired, kill the executor. + * The maximum number of executors we would need under the current load to satisfy all running + * and pending tasks, rounded up. + */ + private def maxNumExecutorsNeeded(): Int = { + val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks + (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + } + + /** + * This is called at a fixed interval to regulate the number of pending executor requests + * and number of executors running. + * + * First, adjust our requested executors based on the add time and our current needs. + * Then, if the remove time for an existing executor has expired, kill the executor. + * * This is factored out into its own method for testing. */ private def schedule(): Unit = synchronized { val now = clock.getTimeMillis - if (addTime != NOT_SET && now >= addTime) { - addExecutors() - logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeout seconds)") - addTime += sustainedSchedulerBacklogTimeout * 1000 - } + + updateAndSyncNumExecutorsTarget(now) removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { + initializing = false removeExecutor(executorId) } !expired } } + /** + * Updates our target number of executors and syncs the result with the cluster manager. + * + * Check to see whether our existing allocation and the requests we've made previously exceed our + * current needs. If so, truncate our target and let the cluster manager know so that it can + * cancel pending requests that are unneeded. + * + * If not, and the add time has expired, see if we can request new executors and refresh the add + * time. + * + * @return the delta in the target number of executors. + */ + private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { + val maxNeeded = maxNumExecutorsNeeded + + if (initializing) { + // Do not change our target while we are still initializing, + // Otherwise the first job may have to ramp up unnecessarily + 0 + } else if (maxNeeded < numExecutorsTarget) { + // The target number exceeds the number we actually need, so stop adding new + // executors and inform the cluster manager to cancel the extra pending requests + val oldNumExecutorsTarget = numExecutorsTarget + numExecutorsTarget = math.max(maxNeeded, minNumExecutors) + numExecutorsToAdd = 1 + + // If the new target has not changed, avoid sending a message to the cluster manager + if (numExecutorsTarget < oldNumExecutorsTarget) { + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") + } + numExecutorsTarget - oldNumExecutorsTarget + } else if (addTime != NOT_SET && now >= addTime) { + val delta = addExecutors(maxNeeded) + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") + addTime += sustainedSchedulerBacklogTimeoutS * 1000 + delta + } else { + 0 + } + } + /** * Request a number of executors from the cluster manager. * If the cap on the number of executors is reached, give up and reset the * number of executors to add next round instead of continuing to double it. - * Return the number actually requested. + * + * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending + * tasks could fill + * @return the number of additional executors actually requested. */ - private def addExecutors(): Int = synchronized { - // Do not request more executors if we have already reached the upper bound - val numExistingExecutors = executorIds.size + numExecutorsPending - if (numExistingExecutors >= maxNumExecutors) { - logDebug(s"Not adding executors because there are already ${executorIds.size} " + - s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") + private def addExecutors(maxNumExecutorsNeeded: Int): Int = { + // Do not request more executors if it would put our target over the upper bound + if (numExecutorsTarget >= maxNumExecutors) { + logDebug(s"Not adding executors because our current target total " + + s"is already $numExecutorsTarget (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } - // The number of executors needed to satisfy all pending tasks is the number of tasks pending - // divided by the number of tasks each executor can fit, rounded up. - val maxNumExecutorsPending = - (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor - if (numExecutorsPending >= maxNumExecutorsPending) { - logDebug(s"Not adding executors because there are already $numExecutorsPending " + - s"pending and pending tasks could only fill $maxNumExecutorsPending") + val oldNumExecutorsTarget = numExecutorsTarget + // There's no point in wasting time ramping up to the number of executors we already have, so + // make sure our target is at least as much as our current allocation: + numExecutorsTarget = math.max(numExecutorsTarget, executorIds.size) + // Boost our target with the number to add for this round: + numExecutorsTarget += numExecutorsToAdd + // Ensure that our target doesn't exceed what we need at the present moment: + numExecutorsTarget = math.min(numExecutorsTarget, maxNumExecutorsNeeded) + // Ensure that our target fits within configured bounds: + numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + + val delta = numExecutorsTarget - oldNumExecutorsTarget + + // If our target has not changed, do not send a message + // to the cluster manager and reset our exponential growth + if (delta == 0) { numExecutorsToAdd = 1 return 0 } - // It's never useful to request more executors than could satisfy all the pending tasks, so - // cap request at that amount. - // Also cap request with respect to the configured upper bound. - val maxNumExecutorsToAdd = math.min( - maxNumExecutorsPending - numExecutorsPending, - maxNumExecutors - numExistingExecutors) - assert(maxNumExecutorsToAdd > 0) - - val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) - - val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd - val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd) + val addRequestAcknowledged = testing || + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) if (addRequestAcknowledged) { - logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " + - s"tasks are backlogged (new desired total will be $newTotalExecutors)") - numExecutorsToAdd = - if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1 - numExecutorsPending += actualNumExecutorsToAdd - actualNumExecutorsToAdd + val executorsString = "executor" + { if (delta > 1) "s" else "" } + logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + + s" (new desired total will be $numExecutorsTarget)") + numExecutorsToAdd = if (delta == numExecutorsToAdd) { + numExecutorsToAdd * 2 + } else { + 1 + } + delta } else { - logWarning(s"Unable to reach the cluster manager " + - s"to request $actualNumExecutorsToAdd executors!") + logWarning( + s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") 0 } } @@ -306,7 +404,7 @@ private[spark] class ExecutorAllocationManager( val removeRequestAcknowledged = testing || client.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeoutS seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -327,10 +425,6 @@ private[spark] class ExecutorAllocationManager( // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951) executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle) logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") - if (numExecutorsPending > 0) { - numExecutorsPending -= 1 - logDebug(s"Decremented number of pending executors ($numExecutorsPending left)") - } } else { logWarning(s"Duplicate executor $executorId has registered") } @@ -362,8 +456,8 @@ private[spark] class ExecutorAllocationManager( private def onSchedulerBacklogged(): Unit = synchronized { if (addTime == NOT_SET) { logDebug(s"Starting timer to add executors because pending tasks " + - s"are building up (to expire in $schedulerBacklogTimeout seconds)") - addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000 + s"are building up (to expire in $schedulerBacklogTimeoutS seconds)") + addTime = clock.getTimeMillis + schedulerBacklogTimeoutS * 1000 } } @@ -372,7 +466,7 @@ private[spark] class ExecutorAllocationManager( * This resets all variables used for adding executors. */ private def onSchedulerQueueEmpty(): Unit = synchronized { - logDebug(s"Clearing timer to add executors because there are no more pending tasks") + logDebug("Clearing timer to add executors because there are no more pending tasks") addTime = NOT_SET numExecutorsToAdd = 1 } @@ -385,9 +479,23 @@ private[spark] class ExecutorAllocationManager( private def onExecutorIdle(executorId: String): Unit = synchronized { if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + // Note that it is not necessary to query the executors since all the cached + // blocks we are concerned with are reported to the driver. Note that this + // does not include broadcast blocks. + val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val now = clock.getTimeMillis() + val timeout = { + if (hasCachedBlocks) { + // Use a different timeout if the executor has cached blocks. + now + cachedExecutorIdleTimeoutS * 1000 + } else { + now + executorIdleTimeoutS * 1000 + } + } + val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow + removeTimes(executorId) = realTimeout logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)") } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") @@ -415,13 +523,40 @@ private[spark] class ExecutorAllocationManager( private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] + // Number of tasks currently running on the cluster. Should be 0 when no stages are active. + private var numRunningTasks: Int = _ + + // stageId to tuple (the number of task with locality preferences, a map where each pair is a + // node and the number of tasks that would like to be scheduled on that node) map, + // maintain the executor placement hints for each stage Id used by resource framework to better + // place the executors. + private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])] override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + initializing = false val stageId = stageSubmitted.stageInfo.stageId val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() + + // Compute the number of tasks requested by the stage on each host + var numTasksPending = 0 + val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]() + stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality => + if (!locality.isEmpty) { + numTasksPending += 1 + locality.foreach { location => + val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1 + hostToLocalTaskCountPerStage(location.host) = count + } + } + } + stageIdToExecutorPlacementHints.put(stageId, + (numTasksPending, hostToLocalTaskCountPerStage.toMap)) + + // Update the executor placement hints + updateExecutorPlacementHints() } } @@ -430,11 +565,19 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToExecutorPlacementHints -= stageId + + // Update the executor placement hints + updateExecutorPlacementHints() // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason if (stageIdToNumTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() + if (numRunningTasks != 0) { + logWarning("No stages are running, but numRunningTasks != 0") + numRunningTasks = 0 + } } } } @@ -446,6 +589,7 @@ private[spark] class ExecutorAllocationManager( val executorId = taskStart.taskInfo.executorId allocationManager.synchronized { + numRunningTasks += 1 // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is // possible because these events are posted in different threads. (see SPARK-4951) @@ -455,14 +599,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -474,8 +612,11 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { - // If the executor is no longer running scheduled any tasks, mark it as idle + numRunningTasks -= 1 + // If the executor is no longer running any scheduled tasks, mark it as idle if (executorIdToTaskIds.contains(executorId)) { executorIdToTaskIds(executorId) -= taskId if (executorIdToTaskIds(executorId).isEmpty) { @@ -483,11 +624,21 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { - val executorId = blockManagerAdded.blockManagerId.executorId + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + val executorId = executorAdded.executorId if (executorId != SparkContext.DRIVER_IDENTIFIER) { // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is @@ -498,9 +649,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerRemoved( - blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { - allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + allocationManager.onExecutorRemoved(executorRemoved.executorId) } /** @@ -515,6 +665,11 @@ private[spark] class ExecutorAllocationManager( }.sum } + /** + * The number of tasks currently running across all stages. + */ + def totalRunningTasks(): Int = numRunningTasks + /** * Return true if an executor is not currently running a task, and false otherwise. * @@ -523,35 +678,56 @@ private[spark] class ExecutorAllocationManager( def isExecutorIdle(executorId: String): Boolean = { !executorIdToTaskIds.contains(executorId) } - } -} + /** + * Update the Executor placement hints (the number of tasks with locality preferences, + * a map where each pair is a node and the number of tasks that would like to be scheduled + * on that node). + * + * These hints are updated when stages arrive and complete, so are not up-to-date at task + * granularity within stages. + */ + def updateExecutorPlacementHints(): Unit = { + var localityAwareTasks = 0 + val localityToCount = new mutable.HashMap[String, Int]() + stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) => + localityAwareTasks += numTasksPending + localities.foreach { case (hostname, count) => + val updatedCount = localityToCount.getOrElse(hostname, 0) + count + localityToCount(hostname) = updatedCount + } + } -private object ExecutorAllocationManager { - val NOT_SET = Long.MaxValue -} + allocationManager.localityAwareTasks = localityAwareTasks + allocationManager.hostToLocalTaskCount = localityToCount.toMap + } + } -/** - * An abstract clock for measuring elapsed time. - */ -private trait Clock { - def getTimeMillis: Long -} + /** + * Metric source for ExecutorAllocationManager to expose its internal executor allocation + * status to MetricsSystem. + * Note: These metrics heavily rely on the internal implementation of + * ExecutorAllocationManager, metrics or value of metrics will be changed when internal + * implementation is changed, so these metrics are not stable across Spark version. + */ + private[spark] class ExecutorAllocationManagerSource extends Source { + val sourceName = "ExecutorAllocationManager" + val metricRegistry = new MetricRegistry() + + private def registerGauge[T](name: String, value: => T, defaultValue: T): Unit = { + metricRegistry.register(MetricRegistry.name("executors", name), new Gauge[T] { + override def getValue: T = synchronized { Option(value).getOrElse(defaultValue) } + }) + } -/** - * A clock backed by a monotonically increasing time source. - * The time returned by this clock does not correspond to any notion of wall-clock time. - */ -private class RealClock extends Clock { - override def getTimeMillis: Long = System.nanoTime / (1000 * 1000) + registerGauge("numberExecutorsToAdd", numExecutorsToAdd, 0) + registerGauge("numberExecutorsPendingToRemove", executorsPendingToRemove.size, 0) + registerGauge("numberAllExecutors", executorIds.size, 0) + registerGauge("numberTargetExecutors", numExecutorsTarget, 0) + registerGauge("numberMaxNeededExecutors", maxNumExecutorsNeeded(), 0) + } } -/** - * A clock that allows the caller to customize the time. - * This is used mainly for testing. - */ -private class TestClock(startTimeMillis: Long) extends Clock { - private var time: Long = startTimeMillis - override def getTimeMillis: Long = time - def tick(ms: Long): Unit = { time += ms } +private object ExecutorAllocationManager { + val NOT_SET = Long.MaxValue } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e97a7375a267..48792a958130 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { @@ -168,7 +168,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - def jobIds = Seq(jobWaiter.jobId) + def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } @@ -276,7 +276,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds = jobs + def jobIds: Seq[Int] = jobs } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 83ae57b7f151..ee60d697d879 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,33 +17,188 @@ package org.apache.spark -import akka.actor.Actor +import java.util.concurrent.{ScheduledFuture, TimeUnit} + +import scala.collection.mutable + import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. + * components to convey liveness or execution information for in-progress tasks. It will also + * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( executorId: String, taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId) +/** + * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is + * created. + */ +private[spark] case object TaskSchedulerIsSet + +private[spark] case object ExpireDeadHosts + +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { - - override def receiveWithLogging = { - case Heartbeat(executorId, taskMetrics, blockManagerId) => - val response = HeartbeatResponse( - !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) - sender ! response +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) + + override val rpcEnv: RpcEnv = sc.env.rpcEnv + + private[spark] var scheduler: TaskScheduler = null + + // executor ID -> timestamp of when the last heartbeat from this executor was received + private val executorLastSeen = new mutable.HashMap[String, Long] + + // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses + // "milliseconds" + private val slaveTimeoutMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") + private val executorTimeoutMs = + sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + + // "spark.network.timeoutInterval" uses "seconds", while + // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" + private val timeoutIntervalMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") + private val checkTimeoutIntervalMs = + sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 + + private var timeoutCheckingTask: ScheduledFuture[_] = null + + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. + private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") + + private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") + + override def onStart(): Unit = { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) + case TaskSchedulerIsSet => + scheduler = sc.taskScheduler + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) + + // Messages received from executors + case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + if (scheduler != null) { + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } + } else { + // Because Executor will sleep several seconds before sending the first "Heartbeat", this + // case rarely happens. However, if it really happens, log it and ask the executor to + // register itself again. + logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. + * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. + */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + + private def expireDeadHosts(): Unit = { + logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") + val now = clock.getTimeMillis() + for ((executorId, lastSeenMs) <- executorLastSeen) { + if (now - lastSeenMs > executorTimeoutMs) { + logWarning(s"Removing executor $executorId with no recent heartbeats: " + + s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms") + scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + + s"timed out after ${now - lastSeenMs} ms")) + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + // Note: we want to get an executor back after expiring this one, + // so do not simply call `sc.killExecutor` here (SPARK-8119) + sc.killAndReplaceExecutor(executorId) + } + }) + executorLastSeen.remove(executorId) + } + } + } + + override def onStop(): Unit = { + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel(true) + } + eventLoopThread.shutdownNow() + killExecutorThread.shutdownNow() + } +} + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" } diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 3f33332a81ea..7cf7bc0dc681 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,6 +50,15 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs + try { + Utils.deleteRecursively(baseDir) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: ${baseDir.getAbsolutePath}", e) + } } def addFile(file: File) : String = { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 09a9ccc22672..8de3a6c04df3 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -160,7 +160,7 @@ private[spark] class HttpServer( throw new ServerStateException("Server is not started") } else { val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" - s"$scheme://${Utils.localIpAddress}:$port" + s"$scheme://${Utils.localHostNameForURI()}:$port" } } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 419d093d5564..f0598816d6c0 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,14 +121,28 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + // scalastyle:off println + if (Utils.isInInterpreter) { + val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" + Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") + System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") + case None => + System.err.println(s"Spark was unable to load $replDefaultLogProps") + } + } else { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } + // scalastyle:on println } } Logging.initialized = true @@ -145,7 +159,7 @@ private object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. - val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7..a38759278385 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,16 +21,14 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await -import scala.collection.JavaConversions._ - -import akka.actor._ -import akka.pattern.ask +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.reflect.ClassTag +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage @@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size @@ -53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." - /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. - * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) - * will ultimately remove this entire code path. */ + /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. + * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ val exception = new SparkException(msg) logError(msg, exception) - throw exception + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) } - sender ! mapOutputStatuses case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() } } @@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val fetching = new HashSet[Int] /** - * Send a message to the trackerActor and get its result within a default timeout, or + * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T: ClassTag](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerEndpoint.askWithRetry[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. Expecting true, got " + response.toString) @@ -128,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and output sizes of the map outputs of - * a given shuffle. + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given reduce task. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. */ - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") + val startTime = System.currentTimeMillis + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") @@ -157,11 +161,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -172,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } + logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + s"${System.currentTimeMillis - startTime} ms") + if (fetchedStatuses != null) { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) @@ -289,6 +295,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return a list of locations that each have fraction of map output greater than the specified + * threshold. + * + * @param shuffleId id of the shuffle + * @param reducerId id of the reduce task + * @param numReducers total number of reducers in the shuffle + * @param fractionThreshold fraction of total map output size that a location must have + * for it to be considered large. + * + * This method is not thread-safe. + */ + def getLocationsWithLargestOutputs( + shuffleId: Int, + reducerId: Int, + numReducers: Int, + fractionThreshold: Double) + : Option[Array[BlockManagerId]] = { + + if (mapStatuses.contains(shuffleId)) { + val statuses = mapStatuses(shuffleId) + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.map(_._1).toArray) + } + } + } + None + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -328,7 +381,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def stop() { sendTracker(StopMapOutputTracker) mapStatuses.clear() - trackerActor = null + trackerEndpoint = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -345,48 +398,72 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { protected val mapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]] + new ConcurrentHashMap[Int, Array[MapStatus]]().asScala } private[spark] object MapOutputTracker extends Logging { + val ENDPOINT_NAME = "MapOutputTracker" + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - objOut.close() out.toByteArray } // Opposite of serializeMapStatuses. def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject().asInstanceOf[Array[MapStatus]] + Utils.tryWithSafeFinally { + objIn.readObject().asInstanceOf[Array[MapStatus]] + } { + objIn.close() + } } - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. + /** + * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block + * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that + * block manager. + * + * If any of the statuses is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. + * + * @param shuffleId Identifier for the shuffle + * @param reduceId Identifier for the reduce task + * @param statuses List of map statuses, indexed by map ID. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ private def convertMapStatuses( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - statuses.map { - status => - if (status == null) { - logError("Missing an output location for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) - } else { - (status.location, status.getSizeForBlock(reduceId)) - } + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.zipWithIndex) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + } else { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + } } + + splitsByAddress.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e53a78ead2c0..4b9d59975bdc 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,7 +56,7 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse - for (r <- bySize if r.partitioner.isDefined) { + for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { @@ -76,7 +76,9 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { case null => 0 @@ -103,7 +105,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { */ class RangePartitioner[K : Ordering : ClassTag, V]( @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], + @transient rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -154,7 +156,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } - def numPartitions = rangeBounds.length + 1 + def numPartitions: Int = rangeBounds.length + 1 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] @@ -185,7 +187,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => + case r: RangePartitioner[_, _] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false @@ -249,7 +251,7 @@ private[spark] object RangePartitioner { * @param sampleSizePerPartition max sample size per partition * @return (total number of items, an array of (partitionId, number of items, sample)) */ - def sketch[K:ClassTag]( + def sketch[K : ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id @@ -272,7 +274,7 @@ private[spark] object RangePartitioner { * @param partitions number of partitions * @return selected bounds */ - def determineBounds[K:Ordering:ClassTag]( + def determineBounds[K : Ordering : ClassTag]( candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af..3b9c885bf97a 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -18,6 +18,10 @@ package org.apache.spark import java.io.File +import java.security.NoSuchAlgorithmException +import javax.net.ssl.SSLContext + +import scala.collection.JavaConverters._ import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +42,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +52,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +68,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -76,7 +81,6 @@ private[spark] case class SSLOptions( * object. It can be used then to compose the ultimate Akka configuration. */ def createAkkaConfig: Option[Config] = { - import scala.collection.JavaConversions._ if (enabled) { Some(ConfigFactory.empty() .withValue("akka.remote.netty.tcp.security.key-store", @@ -94,7 +98,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.asJava)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +106,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. */ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 88d35a4bacc6..673ef49e7c1c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.network.sasl.SecretKeyHolder +import org.apache.spark.util.Utils /** * Spark class responsible for security. @@ -149,8 +150,13 @@ import org.apache.spark.network.sasl.SecretKeyHolder * authorization. If not filter is in place the user is generally null and no authorization * can take place. * - * Connection encryption (SSL) configuration is organized hierarchically. The user can configure - * the default SSL settings which will be used for all the supported communication protocols unless + * When authentication is being used, encryption can also be enabled by setting the option + * spark.authenticate.enableSaslEncryption to true. This is only supported by communication + * channels that use the network-common library, and can be used as an alternative to SSL in those + * cases. + * + * SSL can be used for encryption for certain communication channels. The user can configure the + * default SSL settings which will be used for all the supported communication protocols unless * they are overwritten by protocol specific settings. This way the user can easily provide the * common settings for all the protocols without disabling the ability to configure each one * individually. @@ -186,7 +192,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" - private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) @@ -203,7 +209,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), - Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty) + Utils.getCurrentUserName()) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) @@ -359,10 +365,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) cookie } else { // user must have set spark.authenticate.secret config - sparkConf.getOption("spark.authenticate.secret") match { + // For Master/Worker, auth secret is in conf; for Executors, it is in env variable + sys.env.get(SecurityManager.ENV_AUTH_SECRET) + .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { case Some(value) => value case None => throw new Exception("Error: a secret key must be specified via the " + - "spark.authenticate.secret config") + SecurityManager.SPARK_AUTH_SECRET_CONF + " config") } } sCookie @@ -411,6 +419,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf) */ def isAuthenticationEnabled(): Boolean = authOn + /** + * Checks whether SASL encryption should be enabled. + * @return Whether to enable SASL encryption when connecting to services that support it. + */ + def isSaslEncryptionEnabled(): Boolean = { + sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false) + } + /** * Gets the user used for authenticating HTTP connections. * For now use a single hardcoded user. @@ -435,3 +451,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() } + +private[spark] object SecurityManager { + + val SPARK_AUTH_CONF: String = "spark.authenticate" + val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" + // This is used to set auth secret to an executor's env variable. It should have the same + // value as SPARK_AUTH_SECERET_CONF set in SparkConf + val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" +} diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index 55cb25946c2a..beb2e2725472 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -28,8 +28,10 @@ import org.apache.spark.util.Utils @DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString + + def value: T = t + + override def toString: String = t.toString private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { out.defaultWriteObject() @@ -39,7 +41,7 @@ class SerializableWritable[T <: Writable](@transient var t: T) extends Serializa private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() - ow.setConf(new Configuration()) + ow.setConf(new Configuration(false)) ow.readFields(in) t = ow.get().asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 13aa9960ac33..b344b5e173d6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -22,6 +22,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet +import org.apache.avro.{SchemaNormalization, Schema} + import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -67,6 +69,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } + logDeprecationWarning(key) settings.put(key, value) this } @@ -132,14 +135,16 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Set multiple parameters together */ - def setAll(settings: Traversable[(String, String)]) = { - this.settings.putAll(settings.toMap.asJava) + def setAll(settings: Traversable[(String, String)]): SparkConf = { + settings.foreach { case (k, v) => set(k, v) } this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - settings.putIfAbsent(key, value) + if (settings.putIfAbsent(key, value) == null) { + logDeprecationWarning(key) + } this } @@ -157,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + private final val avroNamespace = "avro.schema." + + /** + * Use Kryo serialization and register the given set of Avro schemas so that the generic + * record serializer can decrease network IO + */ + def registerAvroSchemas(schemas: Schema*): SparkConf = { + for (schema <- schemas) { + set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString) + } + this + } + + /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */ + def getAvroSchema: Map[Long, String] = { + getAll.filter { case (k, v) => k.startsWith(avroNamespace) } + .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) } + .toMap + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) @@ -173,9 +198,118 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getOption(key).getOrElse(defaultValue) } + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then seconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsSeconds(key: String): Long = { + Utils.timeStringAsSeconds(get(key)) + } + + /** + * Get a time parameter as seconds, falling back to a default if not set. If no + * suffix is provided then seconds are assumed. + */ + def getTimeAsSeconds(key: String, defaultValue: String): Long = { + Utils.timeStringAsSeconds(get(key, defaultValue)) + } + + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsMs(key: String): Long = { + Utils.timeStringAsMs(get(key)) + } + + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. + */ + def getTimeAsMs(key: String, defaultValue: String): Long = { + Utils.timeStringAsMs(get(key, defaultValue)) + } + + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. + */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. + */ + def getSizeAsBytes(key: String, defaultValue: Long): Long = { + Utils.byteStringAsBytes(get(key, defaultValue + "B")) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - Option(settings.get(key)) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } /** Get all parameters as a list of pairs */ @@ -255,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -285,7 +420,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate memory fractions val memoryKeys = Seq( "spark.storage.memoryFraction", - "spark.shuffle.memoryFraction", + "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") @@ -342,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. + """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** @@ -351,9 +504,93 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def toDebugString: String = { getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } + } -private[spark] object SparkConf { +private[spark] object SparkConf extends Logging { + + /** + * Maps deprecated config keys to information about the deprecation. + * + * The extra information is logged as a warning when the config is present in the user's + * configuration. + */ + private val deprecatedConfigs: Map[String, DeprecatedConfig] = { + val configs = Seq( + DeprecatedConfig("spark.cache.class", "0.8", + "The spark.cache.class property is no longer being used! Specify storage levels using " + + "the RDD.persist() method instead."), + DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) + } + + /** + * Maps a current config key to alternate keys that were used in previous version of Spark. + * + * The alternates are used in the order defined in this map. If deprecated configs are + * present in the user's configuration, a warning is logged. + */ + private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( + "spark.executor.userClassPathFirst" -> Seq( + AlternateConfig("spark.files.userClassPathFirst", "1.3")), + "spark.history.fs.update.interval" -> Seq( + AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), + AlternateConfig("spark.history.fs.updateInterval", "1.3"), + AlternateConfig("spark.history.updateInterval", "1.3")), + "spark.history.fs.cleaner.interval" -> Seq( + AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), + "spark.history.fs.cleaner.maxAge" -> Seq( + AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), + "spark.yarn.am.waitTime" -> Seq( + AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", + // Translate old value to a duration, with 10s wait time per try. + translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), + "spark.rpc.numRetries" -> Seq( + AlternateConfig("spark.akka.num.retries", "1.4")), + "spark.rpc.retry.wait" -> Seq( + AlternateConfig("spark.akka.retry.wait", "1.4")), + "spark.rpc.askTimeout" -> Seq( + AlternateConfig("spark.akka.askTimeout", "1.4")), + "spark.rpc.lookupTimeout" -> Seq( + AlternateConfig("spark.akka.lookupTimeout", "1.4")), + "spark.streaming.fileStream.minRememberDuration" -> Seq( + AlternateConfig("spark.streaming.minRememberDuration", "1.5")) + ) + + /** + * A view of `configsWithAlternatives` that makes it more efficient to look up deprecated + * config keys. + * + * Maps the deprecated config name to a 2-tuple (new config name, alternate config info). + */ + private val allAlternatives: Map[String, (String, AlternateConfig)] = { + configsWithAlternatives.keys.flatMap { key => + configsWithAlternatives(key).map { cfg => (cfg.key -> (key -> cfg)) } + }.toMap + } + /** * Return whether the given config is an akka config (e.g. akka.actor.provider). * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). @@ -369,7 +606,7 @@ private[spark] object SparkConf { def isExecutorStartupConf(name: String): Boolean = { isAkkaConf(name) || name.startsWith("spark.akka") || - name.startsWith("spark.auth") || + (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || isSparkPortConf(name) } @@ -380,4 +617,59 @@ private[spark] object SparkConf { def isSparkPortConf(name: String): Boolean = { (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") } + + /** + * Looks for available deprecated keys for the given config option, and return the first + * value available. + */ + def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + configsWithAlternatives.get(key).flatMap { alts => + alts.collectFirst { case alt if conf.contains(alt.key) => + val value = conf.get(alt.key) + if (alt.translation != null) alt.translation(value) else value + } + } + } + + /** + * Logs a warning message if the given config key is deprecated. + */ + def logDeprecationWarning(key: String): Unit = { + deprecatedConfigs.get(key).foreach { cfg => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"may be removed in the future. ${cfg.deprecationMessage}") + } + + allAlternatives.get(key).foreach { case (newKey, cfg) => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"and may be removed in the future. Please use the new key '$newKey' instead.") + } + } + + /** + * Holds information about keys that have been deprecated and do not have a replacement. + * + * @param key The deprecated key. + * @param version Version of Spark where key was deprecated. + * @param deprecationMessage Message to include in the deprecation warning. + */ + private case class DeprecatedConfig( + key: String, + version: String, + deprecationMessage: String) + + /** + * Information about an alternate configuration key that has been deprecated. + * + * @param key The deprecated config key. + * @param version The Spark version in which the key was deprecated. + * @param translation A translation function for converting old config values into new ones. + */ + private case class AlternateConfig( + key: String, + version: String, + translation: String => String = null) + } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a16a3165463..f3da04a7f55d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -20,33 +20,44 @@ package org.apache.spark import scala.language.implicitConversions import java.io._ +import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID + +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} -import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, + FixedLengthBinaryInputFormat} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ @@ -85,10 +96,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - @volatile private var stopped: Boolean = false + private val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("Cannot call methods on a stopped SparkContext") } } @@ -103,13 +114,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * :: DeveloperApi :: * Alternative constructor for setting preferred locations where Spark will create executors. * + * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] * from a list of input files or InputFormats for the application. */ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") @DeveloperApi def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { this(config) + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") this.preferredNodeLocationData = preferredNodeLocationData } @@ -132,6 +146,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. + * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. */ def this( master: String, @@ -142,6 +159,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) + if (preferredNodeLocationData.nonEmpty) { + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") + } this.preferredNodeLocationData = preferredNodeLocationData } @@ -182,9 +202,45 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - - private[spark] val conf = config.clone() - conf.validateSettings() + + /* ------------------------------------------------------------------------------------- * + | Private variables. These variables keep the internal state of the context, and are | + | not accessible by the outside world. They're mutable since we want to initialize all | + | of them to some neutral value ahead of time, so that calling "stop()" while the | + | constructor is still running is safe. | + * ------------------------------------------------------------------------------------- */ + + private var _conf: SparkConf = _ + private var _eventLogDir: Option[URI] = None + private var _eventLogCodec: Option[String] = None + private var _env: SparkEnv = _ + private var _metadataCleaner: MetadataCleaner = _ + private var _jobProgressListener: JobProgressListener = _ + private var _statusTracker: SparkStatusTracker = _ + private var _progressBar: Option[ConsoleProgressBar] = None + private var _ui: Option[SparkUI] = None + private var _hadoopConfiguration: Configuration = _ + private var _executorMemory: Int = _ + private var _schedulerBackend: SchedulerBackend = _ + private var _taskScheduler: TaskScheduler = _ + private var _heartbeatReceiver: RpcEndpointRef = _ + @volatile private var _dagScheduler: DAGScheduler = _ + private var _applicationId: String = _ + private var _applicationAttemptId: Option[String] = None + private var _eventLogger: Option[EventLoggingListener] = None + private var _executorAllocationManager: Option[ExecutorAllocationManager] = None + private var _cleaner: Option[ContextCleaner] = None + private var _listenerBusStarted: Boolean = false + private var _jars: Seq[String] = _ + private var _files: Seq[String] = _ + private var _shutdownHookRef: AnyRef = _ + + /* ------------------------------------------------------------------------------------- * + | Accessors and public fields. These provide access to the internal state of the | + | context. | + * ------------------------------------------------------------------------------------- */ + + private[spark] def conf: SparkConf = _conf /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be @@ -192,56 +248,35 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def getConf: SparkConf = conf.clone() - if (!conf.contains("spark.master")) { - throw new SparkException("A master URL must be set in your configuration") - } - if (!conf.contains("spark.app.name")) { - throw new SparkException("An application name must be set in your configuration") - } - - if (conf.getBoolean("spark.logConf", false)) { - logInfo("Spark configuration:\n" + conf.toDebugString) - } + def jars: Seq[String] = _jars + def files: Seq[String] = _files + def master: String = _conf.get("spark.master") + def appName: String = _conf.get("spark.app.name") - // Set Spark driver host and port system properties - conf.setIfMissing("spark.driver.host", Utils.localHostName()) - conf.setIfMissing("spark.driver.port", "0") + private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def eventLogDir: Option[URI] = _eventLogDir + private[spark] def eventLogCodec: Option[String] = _eventLogCodec - val jars: Seq[String] = - conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val files: Seq[String] = - conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val master = conf.get("spark.master") - val appName = conf.get("spark.app.name") - - private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { - if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) - } else { - None - } - } - - // Generate the random name for a temp folder in Tachyon + // Generate the random name for a temp folder in external block store. // Add a timestamp as the suffix here to make it more safe - val tachyonFolderName = "spark-" + randomUUID.toString() - conf.set("spark.tachyonStore.folderName", tachyonFolderName) - - val isLocal = (master == "local" || master.startsWith("local[")) + val externalBlockStoreFolderName = "spark-" + randomUUID.toString() + @deprecated("Use externalBlockStoreFolderName instead.", "1.4.0") + val tachyonFolderName = externalBlockStoreFolderName - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + def isLocal: Boolean = (master == "local" || master.startsWith("local[")) // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) + // This function allows components created by SparkEnv to be mocked in unit tests: + private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + } - // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus) - SparkEnv.set(env) + private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -249,164 +284,310 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) + private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener + def statusTracker: SparkStatusTracker = _statusTracker - private[spark] val jobProgressListener = new JobProgressListener(conf) - listenerBus.addListener(jobProgressListener) + private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar - val statusTracker = new SparkStatusTracker(this) + private[spark] def ui: Option[SparkUI] = _ui - private[spark] val progressBar: Option[ConsoleProgressBar] = - if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { - Some(new ConsoleProgressBar(this)) - } else { - None - } + /** + * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. + */ + def hadoopConfiguration: Configuration = _hadoopConfiguration - // Initialize the Spark UI - private[spark] val ui: Option[SparkUI] = - if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, - env.securityManager,appName)) - } else { - // For tests, do not enable the UI - None - } + private[spark] def executorMemory: Int = _executorMemory - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.foreach(_.bind()) + // Environment variables to pass to our executors. + private[spark] val executorEnvs = HashMap[String, String]() - /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ - val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Utils.getCurrentUserName() - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach(addJar) + private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend + private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { + _schedulerBackend = sb } - if (files != null) { - files.foreach(addFile) + private[spark] def taskScheduler: TaskScheduler = _taskScheduler + private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { + _taskScheduler = ts } + private[spark] def dagScheduler: DAGScheduler = _dagScheduler + private[spark] def dagScheduler_=(ds: DAGScheduler): Unit = { + _dagScheduler = ds + } + + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ + def applicationId: String = _applicationId + def applicationAttemptId: Option[String] = _applicationAttemptId + + def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null + + private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger + + private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = + _executorAllocationManager + + private[spark] def cleaner: Option[ContextCleaner] = _cleaner + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } + + /* ------------------------------------------------------------------------------------- * + | Initialization. This code initializes the context in a manner that is exception-safe. | + | All internal fields holding state are initialized here, and any error prompts the | + | stop() method to be called. | + * ------------------------------------------------------------------------------------- */ + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value } - private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) - .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) - .map(Utils.memoryStringToMb) - .getOrElse(512) - - // Environment variables to pass to our executors. - private[spark] val executorEnvs = HashMap[String, String]() - - // Convert java options to env vars as a work around - // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) - value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { - executorEnvs(envKey) = value - } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => - executorEnvs("SPARK_PREPEND_CLASSES") = v + /** Control our logLevel. This overrides any user-defined log settings. + * @param logLevel The desired log level as a string. + * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + */ + def setLogLevel(logLevel: String) { + val validLevels = Seq("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN") + if (!validLevels.contains(logLevel)) { + throw new IllegalArgumentException( + s"Supplied level $logLevel did not match one of: ${validLevels.mkString(",")}") + } + Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel)) } - // The Mesos scheduler backend relies on this environment variable to set executor memory. - // TODO: Set this only in the Mesos scheduler. - executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" - executorEnvs ++= conf.getExecutorEnv - // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Option { - Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name")) - }.getOrElse { - SparkContext.SPARK_UNKNOWN_USER - } - executorEnvs("SPARK_USER") = sparkUser - - // Create and start the scheduler - private[spark] var (schedulerBackend, taskScheduler) = - SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") - @volatile private[spark] var dagScheduler: DAGScheduler = _ try { - dagScheduler = new DAGScheduler(this) - } catch { - case e: Exception => { - try { - stop() - } finally { - throw new SparkException("Error while constructing DAGScheduler", e) + _conf = config.clone() + _conf.validateSettings() + + if (!_conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!_conf.contains("spark.app.name")) { + throw new SparkException("An application name must be set in your configuration") + } + + // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster + // yarn-standalone is deprecated, but still supported + if ((master == "yarn-cluster" || master == "yarn-standalone") && + !_conf.contains("spark.yarn.app.id")) { + throw new SparkException("Detected yarn-cluster mode, but isn't running on a cluster. " + + "Deployment to YARN is not supported directly by SparkContext. Please use spark-submit.") + } + + if (_conf.getBoolean("spark.logConf", false)) { + logInfo("Spark configuration:\n" + _conf.toDebugString) + } + + // Set Spark driver host and port system properties + _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + _conf.setIfMissing("spark.driver.port", "0") + + _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) + + _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) + .toSeq.flatten + + _eventLogDir = + if (isEventLogEnabled) { + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } + + _eventLogCodec = { + val compress = _conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) + } else { + None } } - } - // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's - // constructor - taskScheduler.start() + _conf.set("spark.externalBlockStore.folderName", externalBlockStoreFolderName) - val applicationId: String = taskScheduler.applicationId() - conf.set("spark.app.id", applicationId) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - env.blockManager.initialize(applicationId) + // "_jobProgressListener" should be set up before creating SparkEnv because when creating + // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. + _jobProgressListener = new JobProgressListener(_conf) + listenerBus.addListener(jobProgressListener) - val metricsSystem = env.metricsSystem + // Create the Spark execution environment (cache, map output tracker, etc) + _env = createSparkEnv(_conf, isLocal, listenerBus) + SparkEnv.set(_env) - // The metrics system for Driver need to be set spark.app.id to app ID. - // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. - metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (isEventLogEnabled) { - val logger = - new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } + _statusTracker = new SparkStatusTracker(this) - // Optionally scale number of executors dynamically based on workload. Exposed for testing. - private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false) - private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = - if (dynamicAllocationEnabled) { - assert(master.contains("yarn") || dynamicAllocationTesting, - "Dynamic allocation of executors is currently only supported in YARN mode") - Some(new ExecutorAllocationManager(this, listenerBus, conf)) - } else { - None + _progressBar = + if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + _ui = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + _env.securityManager, appName, startTime = startTime)) + } else { + // For tests, do not enable the UI + None + } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) + + _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach(addJar) } - executorAllocationManager.foreach(_.start()) - // At this point, all relevant SparkListeners have been registered, so begin releasing events - listenerBus.start() + if (files != null) { + files.foreach(addFile) + } - private[spark] val cleaner: Option[ContextCleaner] = { - if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { - Some(new ContextCleaner(this)) - } else { - None + _executorMemory = _conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")) + .map(warnSparkMem)) + .map(Utils.memoryStringToMb) + .getOrElse(1024) + + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value } - } - cleaner.foreach(_.start()) + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" + executorEnvs ++= _conf.getExecutorEnv + executorEnvs("SPARK_USER") = sparkUser + + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + _heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + + // Create and start the scheduler + val (sched, ts) = SparkContext.createTaskScheduler(this, master) + _schedulerBackend = sched + _taskScheduler = ts + _dagScheduler = new DAGScheduler(this) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) + + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's + // constructor + _taskScheduler.start() + + _applicationId = _taskScheduler.applicationId() + _applicationAttemptId = taskScheduler.applicationAttemptId() + _conf.set("spark.app.id", _applicationId) + _env.blockManager.initialize(_applicationId) + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. + metricsSystem.start() + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. + metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + + _eventLogger = + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, + _conf, _hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else { + None + } - postEnvironmentUpdate() - postApplicationStart() + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } - private[spark] var checkpointDir: Option[String] = None + _executorAllocationManager = + if (dynamicAllocationEnabled) { + Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + } else { + None + } + _executorAllocationManager.foreach(_.start()) - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + _cleaner = + if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) + + setupAndStartListenerBus() + postEnvironmentUpdate() + postApplicationStart() + + // Post init + _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + _executorAllocationManager.foreach { e => + _env.metricsSystem.registerSource(e.executorAllocationManagerSource) + } + + // Make sure the context is stopped if the user forgets about it. This avoids leaving + // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM + // is killed, though. + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + logInfo("Invoking stop() from shutdown hook") + stop() + } + } catch { + case NonFatal(e) => + logError("Error initializing SparkContext.", e) + try { + stop() + } catch { + case NonFatal(inner) => + logError("Error stopping SparkContext after init error.", inner) + } finally { + throw e + } } /** @@ -420,10 +601,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -448,9 +631,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { @@ -463,7 +643,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * [[org.apache.spark.SparkContext.setLocalProperty]]. */ def getLocalProperty(key: String): String = - Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) + Option(localProperties.get).map(_.getProperty(key)).orNull /** Set a human readable description of the current job. */ def setJobDescription(value: String) { @@ -511,18 +691,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } - // Post init - taskScheduler.postStartHook() - - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - private def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() + /** + * Execute a block of code in a scope such that all new RDDs created in this body will + * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) // Methods for creating RDDs @@ -531,24 +706,102 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call * to parallelize and before the first action on the RDD, the resultant RDD will reflect the * modified collection. Pass a copy of the argument to avoid this. + * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an + * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions. */ - def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def parallelize[T: ClassTag]( + seq: Seq[T], + numSlices: Int = defaultParallelism): RDD[T] = withScope { assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } + /** + * Creates a new RDD[Long] containing elements from `start` to `end`(exclusive), increased by + * `step` every element. + * + * @note if we need to cache this RDD, we should make sure each partition does not exceed limit. + * + * @param start the start value. + * @param end the end value. + * @param step the incremental step + * @param numSlices the partition number of the new RDD. + * @return + */ + def range( + start: Long, + end: Long, + step: Long = 1, + numSlices: Int = defaultParallelism): RDD[Long] = withScope { + assertNotStopped() + // when step is 0, range will run infinitely + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + + new Iterator[Long] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + ret + } + } + }) + } + /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. */ - def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def makeRDD[T: ClassTag]( + seq: Seq[T], + numSlices: Int = defaultParallelism): RDD[T] = withScope { parallelize(seq, numSlices) } /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) @@ -558,10 +811,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. */ - def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + def textFile( + path: String, + minPartitions: Int = defaultMinPartitions): RDD[String] = withScope { assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], - minPartitions).map(pair => pair._2.toString).setName(path) + minPartitions).map(pair => pair._2.toString) } /** @@ -588,14 +843,21 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ - def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): - RDD[(String, String)] = { + def wholeTextFiles( + path: String, + minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = withScope { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updateConf = job.getConfiguration new WholeTextFileRDD( this, @@ -606,7 +868,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -622,7 +883,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ @@ -632,16 +893,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental - def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): - RDD[(String, PortableDataStream)] = { + def binaryFiles( + path: String, + minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updateConf = job.getConfiguration new BinaryFileRDD( this, @@ -657,21 +925,33 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * Load data from a flat binary file, assuming the length of each record is constant. * - * @param path Directory to the input data files + * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * has the provided record length. + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param recordLength The length at which to split the records + * @param conf Configuration for setting up the dataset. + * * @return An RDD of data with values, represented as byte arrays */ @Experimental - def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) - : RDD[Array[Byte]] = { + def binaryRecords( + path: String, + recordLength: Int, + conf: Configuration = hadoopConfiguration): RDD[Array[Byte]] = withScope { assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], - conf=conf) - val data = br.map{ case (k, v) => v.getBytes} + conf = conf) + val data = br.map { case (k, v) => + val bytes = v.getBytes + assert(bytes.length == recordLength, "Byte array does not have correct length") + bytes + } data } @@ -680,7 +960,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), * using the older MapReduce API (`org.apache.hadoop.mapred`). * - * @param conf JobConf for setting up the dataset + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. * @param inputFormatClass Class of the InputFormat * @param keyClass Class of the keys * @param valueClass Class of the values @@ -697,8 +980,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minPartitions: Int = defaultMinPartitions - ): RDD[(K, V)] = { + minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) @@ -718,11 +1000,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minPartitions: Int = defaultMinPartitions - ): RDD[(K, V)] = { + minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. - val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) + val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( this, @@ -750,7 +1031,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def hadoopFile[K, V, F <: InputFormat[K, V]] (path: String, minPartitions: Int) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { hadoopFile(path, fm.runtimeClass.asInstanceOf[Class[F]], km.runtimeClass.asInstanceOf[Class[K]], @@ -773,13 +1054,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * copy them using a `map` function. */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { hadoopFile[K, V, F](path, defaultMinPartitions) + } /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]] (path: String) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { newAPIHadoopFile( path, fm.runtimeClass.asInstanceOf[Class[F]], @@ -802,12 +1084,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() - // The call to new NewHadoopJob automatically adds security credentials to conf, + // The call to new NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updatedConf = job.getConfiguration new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -816,6 +1100,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. @@ -826,7 +1118,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], - vClass: Class[V]): RDD[(K, V)] = { + vClass: Class[V]): RDD[(K, V)] = withScope { assertNotStopped() // Add necessary security credentials to the JobConf. Required to access secure HDFS. val jconf = new JobConf(conf) @@ -846,7 +1138,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli keyClass: Class[K], valueClass: Class[V], minPartitions: Int - ): RDD[(K, V)] = { + ): RDD[(K, V)] = withScope { assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) @@ -860,7 +1152,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + def sequenceFile[K, V]( + path: String, + keyClass: Class[K], + valueClass: Class[V]): RDD[(K, V)] = withScope { assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) } @@ -890,16 +1185,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def sequenceFile[K, V] (path: String, minPartitions: Int = defaultMinPartitions) (implicit km: ClassTag[K], vm: ClassTag[V], - kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) - : RDD[(K, V)] = { - assertNotStopped() - val kc = kcf() - val vc = vcf() - val format = classOf[SequenceFileInputFormat[Writable, Writable]] - val writables = hadoopFile(path, format, + kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]): RDD[(K, V)] = { + withScope { + assertNotStopped() + val kc = clean(kcf)() + val vc = clean(vcf)() + val format = classOf[SequenceFileInputFormat[Writable, Writable]] + val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], vc.writableClass(vm).asInstanceOf[Class[Writable]], minPartitions) - writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) } + writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) } + } } /** @@ -912,28 +1208,33 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def objectFile[T: ClassTag]( path: String, - minPartitions: Int = defaultMinPartitions - ): RDD[T] = { + minPartitions: Int = defaultMinPartitions): RDD[T] = withScope { assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } - protected[spark] def checkpointFile[T: ClassTag]( - path: String - ): RDD[T] = { - new CheckpointRDD[T](this, path) + protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope { + new ReliableCheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ - def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { + val partitioners = rdds.flatMap(_.partitioner).toSet + if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, rdds) + } else { + new UnionRDD(this, rdds) + } + } /** Build the union of a list of RDDs passed as variable-length arguments. */ - def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = - new UnionRDD(this, Seq(first) ++ rest) + def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = withScope { + union(Seq(first) ++ rest) + } /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables @@ -941,16 +1242,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = + { + val acc = new Accumulator(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the * driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { - new Accumulator(initialValue, param, Some(name)) + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) + : Accumulator[T] = { + val acc = new Accumulator(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -959,8 +1267,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param) + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the @@ -969,8 +1281,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param, Some(name)) + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { + val acc = new Accumulable(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an accumulator from a "mutable collection" type. @@ -980,8 +1296,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { - val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) + val param = new GrowableAccumulableParam[R, T] + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -1010,12 +1328,48 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ - def addFile(path: String) { + def addFile(path: String): Unit = { + addFile(path, false) + } + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * + * A directory can be given if the recursive option is set to true. Currently directories are only + * supported for Hadoop-supported filesystems. + */ + def addFile(path: String, recursive: Boolean): Unit = { val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case "local" => "file:" + uri.getPath - case _ => path + val schemeCorrectedPath = uri.getScheme match { + case null | "local" => new File(path).getCanonicalFile.toURI.toString + case _ => path + } + + val hadoopPath = new Path(schemeCorrectedPath) + val scheme = new URI(schemeCorrectedPath).getScheme + if (!Array("http", "https", "ftp").contains(scheme)) { + val fs = hadoopPath.getFileSystem(hadoopConfiguration) + if (!fs.exists(hadoopPath)) { + throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") + } + val isDir = fs.getFileStatus(hadoopPath).isDir + if (!isLocal && scheme == "file" && isDir) { + throw new SparkException(s"addFile does not support local directories when not running " + + "local mode.") + } + if (!recursive && isDir) { + throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + + "turned on.") + } + } + + val key = if (!isLocal && scheme == "file") { + env.httpFileServer.addFile(new File(uri.getPath)) + } else { + schemeCorrectedPath } val timestamp = System.currentTimeMillis addedFiles(key) = timestamp @@ -1037,15 +1391,41 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli listenerBus.addListener(listener) } + /** + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] + ): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + false + } + } + /** * :: DeveloperApi :: * Request an additional number of executors from the cluster manager. - * This is currently only supported in Yarn mode. Return whether the request is received. + * @return whether the request is received. */ @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, - "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1058,12 +1438,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. - * This is currently only supported in Yarn mode. Return whether the request is received. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executors it kills + * through this method with new ones, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * + * @return whether the request is received. */ @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, - "Killing executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) @@ -1075,14 +1459,44 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: - * Request that cluster manager the kill the specified executor. - * This is currently only supported in Yarn mode. Return whether the request is received. + * Request that the cluster manager kill the specified executor. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executor it kills + * through this method with a new one, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * + * @return whether the request is received. */ @DeveloperApi override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) + /** + * Request that the cluster manager kill the specified executor without adjusting the + * application resource requirements. + * + * The effect is that a new executor will be launched in place of the one killed by + * this request. This assumes the cluster manager will automatically and eventually + * fulfill all missing application resource requests. + * + * Note: The replace is by no means guaranteed; another application on the same cluster + * can steal the window of opportunity and acquire this application's resources in the + * mean time. + * + * @return whether the request is received. + */ + private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = true) + case _ => + logWarning("Killing executors is only supported in coarse-grained mode") + false + } + } + /** The version of Spark on which this application is running. */ - def version = SPARK_VERSION + def version: String = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining @@ -1132,7 +1546,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getAllPools: Seq[Schedulable] = { assertNotStopped() // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq + taskScheduler.rootPool.schedulableQueue.asScala.toSeq } /** @@ -1176,6 +1590,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { + _executorAllocationManager.foreach { _ => + logWarning( + s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + + s"${rdd.id} will be lost when executors are removed.") + } persistentRdds(rdd.id) = rdd } @@ -1224,7 +1643,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli null } } else { - env.httpFileServer.addJar(new File(uri.getPath)) + try { + env.httpFileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null + case e: Exception => + // For now just log an error but allow to go through so spark examples work. + // The spark examples don't really need the jar distributed since its also + // the app jar. + logError("Error adding jar (" + e + "), was the --addJars option used?") + null + } } // A JAR file which exists locally on every worker node case "local" => @@ -1250,31 +1681,73 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addedJars.clear() } - /** Shut down the SparkContext. */ + // Shut down the SparkContext. def stop() { - SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + // Use the stopping variable to ensure no contention for the stop scenario. + // Still track the stopped variable for use elsewhere in the code. + if (!stopped.compareAndSet(false, true)) { + logInfo("SparkContext already stopped.") + return + } + if (_shutdownHookRef != null) { + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) + } + + Utils.tryLogNonFatalError { postApplicationEnd() - ui.foreach(_.stop()) - if (!stopped) { - stopped = true + } + Utils.tryLogNonFatalError { + _ui.foreach(_.stop()) + } + if (env != null) { + Utils.tryLogNonFatalError { env.metricsSystem.report() + } + } + if (metadataCleaner != null) { + Utils.tryLogNonFatalError { metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) - cleaner.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) + } + } + Utils.tryLogNonFatalError { + _cleaner.foreach(_.stop()) + } + Utils.tryLogNonFatalError { + _executorAllocationManager.foreach(_.stop()) + } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } + if (_listenerBusStarted) { + Utils.tryLogNonFatalError { listenerBus.stop() - eventLogger.foreach(_.stop()) - logInfo("Successfully stopped SparkContext") - SparkContext.clearActiveContext() - } else { - logInfo("SparkContext already stopped") + _listenerBusStarted = false + } + } + Utils.tryLogNonFatalError { + _eventLogger.foreach(_.stop()) + } + if (env != null && _heartbeatReceiver != null) { + Utils.tryLogNonFatalError { + env.rpcEnv.stop(_heartbeatReceiver) + } + } + Utils.tryLogNonFatalError { + _progressBar.foreach(_.stop()) + } + _taskScheduler = null + // TODO: Cache.stop()? + if (_env != null) { + Utils.tryLogNonFatalError { + _env.stop() } + SparkEnv.set(null) } + SparkContext.clearActiveContext() + logInfo("Successfully stopped SparkContext") } @@ -1326,69 +1799,120 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). + * handler function. This is the main entry point for all actions in Spark. */ def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { - if (stopped) { + resultHandler: (Int, U) => Unit): Unit = { + if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) - dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, - resultHandler, localProperties.get) + if (conf.getBoolean("spark.logLineage", false)) { + logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) + } + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and return the results as an array. */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int]): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res) + results + } + + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: Iterator[T] => U, + partitions: Seq[Int]): Array[U] = { + val cleanedFunc = clean(func) + runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) + } + + + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. + */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit): Unit = { + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions, resultHandler) + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. + */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * + * The allowLocal argument is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal) + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** @@ -1399,7 +1923,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler) } /** @@ -1411,7 +1935,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit) { val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler) } /** @@ -1428,7 +1952,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val callSite = getCallSite logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, + val cleanedFunc = clean(func) + val result = dagScheduler.runApproximateJob(rdd, cleanedFunc, evaluator, callSite, timeout, localProperties.get) logInfo( "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") @@ -1455,7 +1980,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (context: TaskContext, iter: Iterator[T]) => cleanF(iter), partitions, callSite, - allowLocal = false, resultHandler, localProperties.get) new SimpleFutureAction(waiter, resultFunc) @@ -1495,7 +2019,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability - * @throws SparkException if checkSerializable is set but f is not + * @throws SparkException if checkSerializable is set but f is not * serializable */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { @@ -1508,6 +2032,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * be a HDFS path if running on a cluster. */ def setCheckpointDir(directory: String) { + + // If we are running on a cluster, log a warning if the directory is local. + // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from + // its own local file system, which is incorrect because the checkpoint files + // are actually on the executor machines. + if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { + logWarning("Checkpoint directory must be non-local " + + "if Spark is running on a cluster: " + directory) + } + checkpointDir = Option(directory).map { dir => val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) @@ -1516,7 +2050,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - def getCheckpointDir = checkpointDir + def getCheckpointDir: Option[String] = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: Int = { @@ -1528,8 +2062,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** - * Default min number of partitions for Hadoop RDDs when not given by user + /** + * Default min number of partitions for Hadoop RDDs when not given by user * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 */ @@ -1544,12 +2078,65 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** + * Registers listeners specified in spark.extraListeners, then starts the listener bus. + * This should be called after all internal listeners have been registered with the listener bus + * (e.g. after the web UI and event logging listeners have been registered). + */ + private def setupAndStartListenerBus(): Unit = { + // Use reflection to instantiate listeners specified via `spark.extraListeners` + try { + val listenerClassNames: Seq[String] = + conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") + for (className <- listenerClassNames) { + // Use reflection to find the right constructor + val constructors = { + val listenerClass = Utils.classForName(className) + listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] + } + val constructorTakingSparkConf = constructors.find { c => + c.getParameterTypes.sameElements(Array(classOf[SparkConf])) + } + lazy val zeroArgumentConstructor = constructors.find { c => + c.getParameterTypes.isEmpty + } + val listener: SparkListener = { + if (constructorTakingSparkConf.isDefined) { + constructorTakingSparkConf.get.newInstance(conf) + } else if (zeroArgumentConstructor.isDefined) { + zeroArgumentConstructor.get.newInstance() + } else { + throw new SparkException( + s"$className did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the listener as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + } + } + listenerBus.addListener(listener) + logInfo(s"Registered listener $className") + } + } catch { + case e: Exception => + try { + stop() + } finally { + throw new SparkException(s"Exception when registering SparkListener", e) + } + } + + listenerBus.start(this) + _listenerBusStarted = true + } + /** Post the application start event */ private def postApplicationStart() { // Note: this code assumes that the task scheduler has been initialized and has contacted // the cluster manager to get an application ID (in case the cluster manager provides one). listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), - startTime, sparkUser)) + startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls)) } /** Post the application end event */ @@ -1563,8 +2150,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val schedulingMode = getSchedulingMode.toString val addedJarPaths = addedJars.keys.toSeq val addedFilePaths = addedFiles.keys.toSeq - val environmentDetails = - SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths) + val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, + addedFilePaths) val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails) listenerBus.post(environmentUpdate) } @@ -1593,11 +2180,12 @@ object SparkContext extends Logging { private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() /** - * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private var activeContext: Option[SparkContext] = None + private val activeContext: AtomicReference[SparkContext] = + new AtomicReference[SparkContext](null) /** * Points to a partially-constructed SparkContext if some thread is in the SparkContext @@ -1632,7 +2220,8 @@ object SparkContext extends Logging { logWarning(warnMsg) } - activeContext.foreach { ctx => + if (activeContext.get() != null) { + val ctx = activeContext.get() val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" @@ -1647,6 +2236,39 @@ object SparkContext extends Logging { } } + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(config: SparkConf): SparkContext = { + // Synchronize to ensure that multiple create requests don't trigger an exception + // from assertNoOtherContextIsRunning within setActiveContext + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + if (activeContext.get() == null) { + setActiveContext(new SparkContext(config), allowMultipleContexts = false) + } + activeContext.get() + } + } + + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * This method allows not passing a SparkConf (useful if just retrieving). + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(): SparkContext = { + getOrCreate(new SparkConf()) + } + /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is * running. Throws an exception if a running context is detected and logs a warning if another @@ -1673,7 +2295,7 @@ object SparkContext extends Logging { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { assertNoOtherContextIsRunning(sc, allowMultipleContexts) contextBeingConstructed = None - activeContext = Some(sc) + activeContext.set(sc) } } @@ -1684,19 +2306,27 @@ object SparkContext extends Logging { */ private[spark] def clearActiveContext(): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - activeContext = None + activeContext.set(null) } } private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" - private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" - private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" + private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" + private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" - private[spark] val SPARK_UNKNOWN_USER = "" + /** + * Executor id for the driver. In earlier versions of Spark, this was ``, but this was + * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see + * SPARK-6716 for more details). + */ + private[spark] val DRIVER_IDENTIFIER = "driver" - private[spark] val DRIVER_IDENTIFIER = "" + /** + * Legacy version of DRIVER_IDENTIFIER, retained for backwards-compatibility. + */ + private[spark] val LEGACY_DRIVER_IDENTIFIER = "" // The following deprecated objects have already been copied to `object AccumulatorParam` to // make the compiler find them automatically. They are duplicate codes only for backward @@ -1707,28 +2337,28 @@ object SparkContext extends Logging { "backward compatibility.", "1.3.0") object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // The following deprecated functions have already been moved to `object RDD` to @@ -1738,49 +2368,71 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = RDD.rddToPairRDDFunctions(rdd) - } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = + RDD.rddToAsyncRDDActions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { + val kf = implicitly[K => Writable] + val vf = implicitly[V => Writable] + // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it + implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf) + implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf) RDD.rddToSequenceFileRDDFunctions(rdd) + } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = RDD.rddToOrderedRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = + RDD.doubleRDDToDoubleRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = RDD.numericRDDToDoubleRDDFunctions(rdd) - // Implicit conversions to common Writable types, for saveAsSequenceFile + // The following deprecated functions have already been moved to `object WritableFactory` to + // make the compiler find them automatically. They are still kept here for backward compatibility. + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def stringToText(s: String): Text = new Text(s) private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) @@ -1915,29 +2567,29 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) + val backend = new LocalBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) case LOCAL_N_REGEX(threads) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. val threadCount = if (threads == "*") localCpuCount else threads.toInt if (threadCount <= 0) { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -1959,7 +2611,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) val masterUrls = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) @@ -1974,7 +2626,7 @@ object SparkContext extends Logging { "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.") } val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { @@ -1986,7 +2638,7 @@ object SparkContext extends Logging { } val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -1999,8 +2651,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { - val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2012,7 +2663,7 @@ object SparkContext extends Logging { val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2030,7 +2681,7 @@ object SparkContext extends Logging { val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url) + new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) } else { new MesosSchedulerBackend(scheduler, sc, url) } @@ -2070,7 +2721,7 @@ object WritableConverter { new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) } - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -2103,3 +2754,46 @@ object WritableConverter { implicit def writableWritableConverter[T <: Writable](): WritableConverter[T] = new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) } + +/** + * A class encapsulating how to convert some type T to Writable. It stores both the Writable class + * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. + * The Writable class will be used in `SequenceFileRDDFunctions`. + */ +private[spark] class WritableFactory[T]( + val writableClass: ClassTag[T] => Class[_ <: Writable], + val convert: T => Writable) extends Serializable + +object WritableFactory { + + private[spark] def simpleWritableFactory[T: ClassTag, W <: Writable : ClassTag](convert: T => W) + : WritableFactory[T] = { + val writableClass = implicitly[ClassTag[W]].runtimeClass.asInstanceOf[Class[W]] + new WritableFactory[T](_ => writableClass, convert) + } + + implicit def intWritableFactory: WritableFactory[Int] = + simpleWritableFactory(new IntWritable(_)) + + implicit def longWritableFactory: WritableFactory[Long] = + simpleWritableFactory(new LongWritable(_)) + + implicit def floatWritableFactory: WritableFactory[Float] = + simpleWritableFactory(new FloatWritable(_)) + + implicit def doubleWritableFactory: WritableFactory[Double] = + simpleWritableFactory(new DoubleWritable(_)) + + implicit def booleanWritableFactory: WritableFactory[Boolean] = + simpleWritableFactory(new BooleanWritable(_)) + + implicit def bytesWritableFactory: WritableFactory[Array[Byte]] = + simpleWritableFactory(new BytesWritable(_)) + + implicit def stringWritableFactory: WritableFactory[String] = + simpleWritableFactory(new Text(_)) + + implicit def writableWritableFactory[T <: Writable: ClassTag]: WritableFactory[T] = + simpleWritableFactory(w => w) + +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f25db7f8de56..0f1e2e069568 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,11 +20,11 @@ package org.apache.spark import java.io.File import java.net.Socket -import scala.collection.JavaConversions._ +import akka.actor.ActorSystem + import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -34,11 +34,15 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -53,7 +57,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -67,8 +71,14 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val executorMemoryManager: ExecutorMemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { + // TODO Remove actorSystem + @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") + val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -76,23 +86,46 @@ class SparkEnv ( // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() + private var driverTmpDirToDelete: Option[String] = None + private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - actorSystem.shutdown() - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } + } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor + } + } } private[spark] @@ -140,7 +173,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. */ - @deprecated("Use SparkEnv.get instead", "1.2") + @deprecated("Use SparkEnv.get instead", "1.2.0") def getThreadLocal: SparkEnv = { env } @@ -151,7 +184,8 @@ object SparkEnv extends Logging { private[spark] def createDriverEnv( conf: SparkConf, isLocal: Boolean, - listenerBus: LiveListenerBus): SparkEnv = { + listenerBus: LiveListenerBus, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") val hostname = conf.get("spark.driver.host") @@ -163,7 +197,8 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, - listenerBus = listenerBus + listenerBus = listenerBus, + mockOutputCommitCoordinator = mockOutputCommitCoordinator ) } @@ -202,7 +237,8 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0): SparkEnv = { + numUsableCores: Int = 0, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -212,21 +248,20 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.address.port.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) + val cls = Utils.classForName(className) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { @@ -257,16 +292,18 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + rpcEnv.setupEndpoint(name, endpointCreator) } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) + RpcUtils.makeDriverRef(name, conf, rpcEnv) } } - val mapOutputTracker = if (isDriver) { + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { new MapOutputTrackerWorker(conf) @@ -274,34 +311,38 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", + "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => new NettyBlockTransferService(conf, securityManager, numUsableCores) case "nio" => + logWarning("NIO-based block transfer service is deprecated, " + + "and will be removed in Spark 1.6.0.") new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) @@ -314,7 +355,7 @@ object SparkEnv extends Logging { val fileServerPort = conf.getInt("spark.fileserver.port", 0) val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) + conf.set("spark.fileserver.uri", server.serverUri) server } else { null @@ -344,15 +385,25 @@ object SparkEnv extends Logging { "." } - // Warn about deprecated spark.cache.class property - if (conf.contains("spark.cache.class")) { - logWarning("The spark.cache.class property is no longer being used! Specify storage " + - "levels using the RDD.persist() method instead.") + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { + new OutputCommitCoordinator(conf, isDriver) + } + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) + + val executorMemoryManager: ExecutorMemoryManager = { + val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { + MemoryAllocator.UNSAFE + } else { + MemoryAllocator.HEAP + } + new ExecutorMemoryManager(allocator) } - new SparkEnv( + val envInstance = new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, @@ -366,7 +417,18 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + executorMemoryManager, + outputCommitCoordinator, conf) + + // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is + // called, and we only need to do it for driver. Because driver may run as a service, and if we + // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. + if (isDriver) { + envInstance.driverTmpDirToDelete = Some(sparkFilesDir) + } + + envInstance } /** diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 2ebd7a7151a5..977a27bdfe1b 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -30,3 +30,10 @@ class SparkException(message: String, cause: Throwable) */ private[spark] class SparkDriverExecutionException(cause: Throwable) extends SparkException("Execution error", cause) + +/** + * Exception thrown when the main user code is run as a child process (e.g. pyspark) and we want + * the parent SparkSubmit process to exit with the same exit code. + */ +private[spark] case class SparkUserAppException(exitCode: Int) + extends SparkException(s"User application exited with $exitCode") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 40237596570d..f5dd36cbcfe6 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. @@ -42,7 +43,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) with Serializable { private val now = new Date() - private val conf = new SerializableWritable(jobConf) + private val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 @@ -50,8 +51,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private var jID: SerializableWritable[JobID] = null private var taID: SerializableWritable[TaskAttemptID] = null - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var writer: RecordWriter[AnyRef, AnyRef] = null + @transient private var format: OutputFormat[AnyRef, AnyRef] = null @transient private var committer: OutputCommitter = null @transient private var jobContext: JobContext = null @transient private var taskContext: TaskAttemptContext = null @@ -103,36 +104,21 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => { - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - } else { - logInfo ("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask( + getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) } def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? val cmtr = getOutputCommitter() cmtr.commitJob(getJobContext()) } // ********* Private Functions ********* - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { if (format == null) { format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + .asInstanceOf[OutputFormat[AnyRef, AnyRef]] } format } @@ -153,7 +139,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } taskContext } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index edbdda8a0bcb..34ee3a48f8e7 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7d7fe1a44631..63cca80b2d73 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,8 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.source.Source +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -31,7 +33,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. + */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc eq null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. @@ -44,6 +59,14 @@ object TaskContext { * Unset the thread local TaskContext. Internal to Spark. */ protected[spark] def unset(): Unit = taskContext.remove() + + /** + * An empty task context that does not represent an actual task. + */ + private[spark] def empty(): TaskContextImpl = { + new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) + } + } @@ -133,4 +156,40 @@ abstract class TaskContext extends Serializable { /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics + + /** + * ::DeveloperApi:: + * Returns all metrics sources with the given name which are associated with the instance + * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + */ + @DeveloperApi + def getMetricsSources(sourceName: String): Seq[Source] + + /** + * Returns the manager for this task's managed memory. + */ + private[spark] def taskMemoryManager(): TaskMemoryManager + + /** + * Register an accumulator that belongs to this task. Accumulators must call this method when + * deserializing in executors. + */ + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit + + /** + * Return the local values of internal accumulators that belong to this task. The key of the Map + * is the accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectInternalAccumulators(): Map[Long, Any] + + /** + * Return the local values of accumulators that belong to this task. The key of the Map is the + * accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectAccumulators(): Map[Long, Any] + + /** + * Accumulators for tracking internal metrics indexed by the name. + */ + private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala deleted file mode 100644 index 4636c4600a01..000000000000 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -/** - * This class exists to restrict the visibility of TaskContext setters. - */ -private [spark] object TaskContextHelper { - - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) - - def unset(): Unit = TaskContext.unset() - -} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 337c8e4ebebc..5df94c6d3a10 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,16 +17,22 @@ package org.apache.spark +import scala.collection.mutable.{ArrayBuffer, HashMap} + import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import scala.collection.mutable.ArrayBuffer - private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + @transient private val metricsSystem: MetricsSystem, + internalAccumulators: Seq[Accumulator[Long]], val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -92,5 +98,28 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = runningLocally override def isInterrupted(): Boolean = interrupted -} + override def getMetricsSources(sourceName: String): Seq[Source] = + metricsSystem.getSourcesByName(sourceName) + + @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] + + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { + accumulators(a.id) = a + } + + private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { + accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + } + + private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { + accumulators.mapValues(_.localValue).toMap + } + + private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { + // Explicitly register internal accumulators here because these are + // not captured in the task closure and are already deserialized + internalAccumulators.foreach(registerAccumulator) + internalAccumulators.map { a => (a.name.get, a) }.toMap + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index af5fd8e0ac00..934d00dc708b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -90,6 +92,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. If a task fails more than + * once (due to retries), `exception` is that one that caused the last failure. */ @DeveloperApi case class ExceptionFailure( @@ -97,11 +103,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +148,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException. + */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before @@ -146,6 +186,16 @@ case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" } +/** + * :: DeveloperApi :: + * Task requested the driver to commit, but was denied. + */ +@DeveloperApi +case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { + override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" +} + /** * :: DeveloperApi :: * The task failed because the executor that it was running on was lost. This may happen because diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index 0bf1e4a5e2cc..fe19f07e32d1 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,7 +27,9 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) + def isFailed(state: TaskState): Boolean = (LOST == state) || (FAILED == state) + + def isFinished(state: TaskState): Boolean = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { case LAUNCHING => MesosTaskState.TASK_STARTING @@ -46,5 +48,6 @@ private[spark] object TaskState extends Enumeration { case MesosTaskState.TASK_FAILED => FAILED case MesosTaskState.TASK_KILLED => KILLED case MesosTaskState.TASK_LOST => LOST + case MesosTaskState.TASK_ERROR => LOST } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 34078142f538..888763a3e8eb 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -17,11 +17,13 @@ package org.apache.spark -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} +import java.nio.charset.StandardCharsets +import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -43,13 +45,38 @@ private[spark] object TestUtils { * Note: if this is used during class loader tests, class names should be unique * in order to avoid interference between tests. */ - def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { + def createJarWithClasses( + classNames: Seq[String], + toStringValue: String = "", + classNamesWithBase: Seq[(String, String)] = Seq(), + classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() - val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) + val files1 = for (name <- classNames) yield { + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + } + val files2 = for ((childName, baseName) <- classNamesWithBase) yield { + createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) + } val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) - createJar(files, jarFile) + createJar(files1 ++ files2, jarFile) } + /** + * Create a jar file containing multiple files. The `files` map contains a mapping of + * file names in the jar file to their contents. + */ + def createJarWithFiles(files: Map[String, String], dir: File = null): URL = { + val tempDir = Option(dir).getOrElse(Utils.createTempDir()) + val jarFile = File.createTempFile("testJar", ".jar", tempDir) + val jarStream = new JarOutputStream(new FileOutputStream(jarFile)) + files.foreach { case (k, v) => + val entry = new JarEntry(k) + jarStream.putNextEntry(entry) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream) + } + jarStream.close() + jarFile.toURI.toURL + } /** * Create a jar file that contains this set of files. All files will be located at the root @@ -79,21 +106,27 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private class JavaSourceFromString(val name: String, val code: String) + private[spark] class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { - override def getCharContent(ignoreEncodingErrors: Boolean) = code + override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ - def createCompiledClass(className: String, destDir: File, value: String = ""): File = { + /** Creates a compiled class with the source file. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + sourceFile: JavaSourceFromString, + classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - val sourceFile = new JavaSourceFromString(className, - "public class " + className + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + value + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. - compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call() + val options = if (classpathUrls.nonEmpty) { + Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) + } else { + Seq() + } + compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) @@ -107,4 +140,18 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") + createCompiledClass(className, destDir, sourceFile, classpathUrls) + } } diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/scala/org/apache/spark/annotation/Private.java new file mode 100644 index 000000000000..9082fcf0c84b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Private.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * A class that is considered private to the internals of Spark -- there is a high-likelihood + * they will be changed in future versions of Spark. + * + * This should be used only when the standard Scala / Java means of protecting classes are + * insufficient. In particular, Java has no equivalent of private[spark], so we use this annotation + * in its place. + * + * NOTE: If there exists a Scaladoc comment that immediately precedes this annotation, the first + * line of the comment must be ":: Private ::" with no trailing blank line. This is because + * of the known issue that Scaladoc displays only either the annotation or the comment, whichever + * comes first. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Private {} diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 000000000000..af483e361e33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation +import scala.annotation.meta._ + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. + * The limitation is that it does not show up in the generated Java API documentation. + */ +@param @field @getter @setter @beanGetter @beanSetter +private[spark] class Since(version: String) extends StaticAnnotation diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 8e8f7f6c4fda..a650df605b92 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -32,7 +32,8 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter import org.apache.spark.util.Utils -class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] { +class JavaDoubleRDD(val srdd: RDD[scala.Double]) + extends AbstractJavaRDDLike[JDouble, JavaDoubleRDD] { override val classTag: ClassTag[JDouble] = implicitly[ClassTag[JDouble]] @@ -136,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ @@ -162,6 +163,20 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja /** Add up the elements in this RDD. */ def sum(): JDouble = srdd.sum() + /** + * Returns the minimum element from this RDD as defined by + * the default comparator natural order. + * @return the minimum of the RDD + */ + def min(): JDouble = min(com.google.common.collect.Ordering.natural()) + + /** + * Returns the maximum element from this RDD as defined by + * the default comparator natural order. + * @return the maximum of the RDD + */ + def max(): JDouble = max(com.google.common.collect.Ordering.natural()) + /** * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and * count of the RDD's elements in one operation. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala index 0ae0b4ec042e..891bcddeac28 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapred.InputSplit @@ -37,7 +37,7 @@ class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala index ec4f3964d75e..0f49279f3e64 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapreduce.InputSplit @@ -37,7 +37,7 @@ class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 7af3538262fd..fb787979c182 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -39,12 +39,13 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.rdd.RDD.rddToPairRDDFunctions +import org.apache.spark.serializer.Serializer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V]) - extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { + extends AbstractJavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) @@ -141,7 +142,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions.asScala, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). @@ -172,7 +173,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) /** * ::Experimental:: @@ -227,24 +228,51 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). + * In addition, users can control the partitioning of the output RDD, the serializer that is use + * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple + * items with the same key). */ def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val ctag: ClassTag[C] = fakeClassTag + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner, + mapSideCombine: Boolean, + serializer: Serializer): JavaPairRDD[K, C] = { + implicit val ctag: ClassTag[C] = fakeClassTag fromRDD(rdd.combineByKey( createCombiner, mergeValue, mergeCombiners, - partitioner + partitioner, + mapSideCombine, + serializer )) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD. This method automatically + * uses map-side aggregation in shuffling the RDD. + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner): JavaPairRDD[K, C] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, true, null) + } + + /** + * Simplified version of combineByKey that hash-partitions the output RDD and uses map-side + * aggregation. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -488,7 +516,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing - * partitioner/parallelism level. + * partitioner/parallelism level and using map-side aggregation. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -633,7 +661,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.call(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } @@ -740,7 +768,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + def lookup(key: K): JList[V] = rdd.lookup(key).asJava /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( @@ -959,30 +987,27 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) object JavaPairRDD { private[spark] def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = { - rddToPairRDDFunctions(rdd).mapValues(asJavaIterable) + rddToPairRDDFunctions(rdd).mapValues(_.asJava) } private[spark] def cogroupResultToJava[K: ClassTag, V, W]( rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = { - rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava)) } private[spark] def cogroupResult2ToJava[K: ClassTag, V, W1, W2]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava)) } private[spark] def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => - (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava, x._4.asJava)) } def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 86fb374bef1e..ed312770ee13 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) - extends JavaRDDLike[T, JavaRDD[T]] { + extends AbstractJavaRDDLike[T, JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) @@ -101,12 +101,23 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param seed seed for the random number generator */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -168,7 +179,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - override def toString = rdd.toString + override def toString: String = rdd.toString /** Assign a name to this RDD */ def setName(name: String): JavaRDD[T] = { @@ -181,7 +192,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x) + def fn: (T) => S = (x: T) => f.call(x) import com.google.common.collect.Ordering // shadows scala.math.Ordering implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] implicit val ctag: ClassTag[S] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0f91c942ecd5..fc817cdd6a3f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,10 +17,10 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} +import java.util.{Comparator, List => JList, Iterator => JIterator} -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -38,6 +38,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaRDDLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This]] + extends JavaRDDLike[T, This] + /** * Defines operations common to several Java RDD implementations. * Note that this trait is not intended to be implemented by user code. @@ -50,10 +58,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + def splits: JList[Partition] = rdd.partitions.toSeq.asJava + /** Set of partitions in this RDD. */ - def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def partitions: JList[Partition] = rdd.partitions.toSeq.asJava + + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -70,7 +81,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * subclasses of RDD. */ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) + rdd.iterator(split, taskContext).asJava // Transformations (return a new RDD) @@ -85,9 +96,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -101,7 +112,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassTag[(K2, V2)]] + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -111,7 +122,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -121,8 +132,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -131,8 +142,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - def cm = implicitly[ClassTag[(K2, V2)]] + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -140,7 +151,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -149,7 +162,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -158,8 +173,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -167,7 +184,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -176,7 +195,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) } @@ -186,7 +207,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -195,14 +218,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f.call(asJavaIterator(x)))) + rdd.foreachPartition((x => f.call(x.asJava))) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + new JavaRDD(rdd.glom().map(_.toSeq.asJava)) /** * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of @@ -242,13 +265,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) + rdd.pipe(command.asScala) /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + rdd.pipe(command.asScala, env.asScala) /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -269,8 +292,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala + } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } @@ -307,28 +331,22 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. */ - def collect(): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.collect().toSeq - new java.util.ArrayList(arr) - } + def collect(): JList[T] = + rdd.collect().toSeq.asJava /** * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. */ - def toLocalIterator(): JIterator[T] = { - import scala.collection.JavaConversions._ - rdd.toLocalIterator - } - + def toLocalIterator(): JIterator[T] = + asJavaIteratorConverter(rdd.toLocalIterator).asJava /** * Return an array that contains all of the elements in this RDD. * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead */ - @Deprecated + @deprecated("use collect()", "1.0.0") def toArray(): JList[T] = collect() /** @@ -337,9 +355,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. - import scala.collection.JavaConversions._ - val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true) - res.map(x => new java.util.ArrayList(x.toSeq)).toArray + val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) + res.map(_.toSeq.asJava) } /** @@ -363,9 +380,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. */ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = rdd.fold(zeroValue)(f) @@ -433,8 +457,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + def countByValue(): java.util.Map[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). @@ -456,20 +480,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. */ - def take(num: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.take(num).toSeq - new java.util.ArrayList(arr) - } + def take(num: Int): JList[T] = + rdd.take(num).toSeq.asJava - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - - def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq - new java.util.ArrayList(arr) - } + + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = + rdd.takeSample(withReplacement, num, seed).toSeq.asJava /** * Return the first element in this RDD. @@ -549,10 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def top(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.top(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -574,10 +589,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -663,7 +675,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { - new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x.asJava)), { x => null.asInstanceOf[Void] }) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 97f5c9f257e0..609496ccdfef 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,8 +21,7 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import scala.collection.JavaConversions -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -104,11 +103,11 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) private[spark] val env = sc.env - def statusTracker = new JavaSparkStatusTracker(sc) + def statusTracker: JavaSparkStatusTracker = new JavaSparkStatusTracker(sc) def isLocal: java.lang.Boolean = sc.isLocal @@ -118,7 +117,7 @@ class JavaSparkContext(val sc: SparkContext) def appName: String = sc.appName - def jars: util.List[String] = sc.jars + def jars: util.List[String] = sc.jars.asJava def startTime: java.lang.Long = sc.startTime @@ -142,7 +141,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag - sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) + sc.parallelize(list.asScala, numSlices) } /** Get an RDD that has no partitions or elements. */ @@ -161,7 +160,7 @@ class JavaSparkContext(val sc: SparkContext) : JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[V] = fakeClassTag - JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) + JavaPairRDD.fromRDD(sc.parallelize(list.asScala, numSlices)) } /** Distribute a local Scala collection to form an RDD. */ @@ -170,8 +169,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), - numSlices)) + JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()), numSlices)) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = @@ -373,6 +371,15 @@ class JavaSparkContext(val sc: SparkContext) * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * etc). * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * @param minPartitions Minimum number of Hadoop Splits to generate. + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -395,6 +402,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -476,6 +491,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -494,7 +517,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[T]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[T] = first.classTag sc.union(rdds) } @@ -502,7 +525,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[(K, V)] = first.classTag implicit val ctagK: ClassTag[K] = first.kClassTag implicit val ctagV: ClassTag[V] = first.vClassTag @@ -511,7 +534,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) + val rdds: Seq[RDD[Double]] = (Seq(first) ++ rest.asScala).map(_.srdd) new JavaDoubleRDD(sc.union(rdds)) } @@ -675,6 +698,9 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { sc.hadoopConfiguration @@ -727,6 +753,14 @@ class JavaSparkContext(val sc: SparkContext) */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) + /** Control our logLevel. This overrides any user-defined log settings. + * @param logLevel The desired log level as a string. + * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + */ + def setLogLevel(logLevel: String) { + sc.setLogLevel(logLevel) + } + /** * Assigns a group ID to all the jobs started by this thread until the group ID is set to a * different value or cleared. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 71b26737b8c0..8f9647eea9e2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.java +import java.util.Map.Entry + import com.google.common.base.Optional import java.{util => ju} @@ -30,8 +32,8 @@ private[spark] object JavaUtils { } // Workaround for SPARK-3926 / SI-8911 - def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = - new SerializableMapWrapper(underlying) + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]): SerializableMapWrapper[A, B] + = new SerializableMapWrapper(underlying) // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper, // but implements java.io.Serializable. It can't just be subclassed to make it @@ -40,36 +42,33 @@ private[spark] object JavaUtils { class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) extends ju.AbstractMap[A, B] with java.io.Serializable { self => - override def size = underlying.size + override def size: Int = underlying.size override def get(key: AnyRef): B = try { - underlying get key.asInstanceOf[A] match { - case None => null.asInstanceOf[B] - case Some(v) => v - } + underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { case ex: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { - def size = self.size + override def size: Int = self.size - def iterator = new ju.Iterator[ju.Map.Entry[A, B]] { + override def iterator: ju.Iterator[ju.Map.Entry[A, B]] = new ju.Iterator[ju.Map.Entry[A, B]] { val ui = underlying.iterator var prev : Option[A] = None - def hasNext = ui.hasNext + def hasNext: Boolean = ui.hasNext - def next() = { - val (k, v) = ui.next + def next(): Entry[A, B] = { + val (k, v) = ui.next() prev = Some(k) new ju.Map.Entry[A, B] { import scala.util.hashing.byteswap32 - def getKey = k - def getValue = v - def setValue(v1 : B) = self.put(k, v1) - override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) - override def equals(other: Any) = other match { + override def getKey: A = k + override def getValue: B = v + override def setValue(v1 : B): B = self.put(k, v1) + override def hashCode: Int = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) + override def equals(other: Any): Boolean = other match { case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue case _ => false } @@ -81,7 +80,7 @@ private[spark] object JavaUtils { case Some(k) => underlying match { case mm: mutable.Map[A, _] => - mm remove k + mm.remove(k) prev = None case _ => throw new UnsupportedOperationException("remove") diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala new file mode 100644 index 000000000000..164e95081583 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.DataOutputStream +import java.net.Socket + +import py4j.GatewayServer + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port + * back to its caller via a callback port specified by the caller. + * + * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). + */ +private[spark] object PythonGatewayServer extends Logging { + def main(args: Array[String]): Unit = Utils.tryOrExit { + // Start a GatewayServer on an ephemeral port + val gatewayServer: GatewayServer = new GatewayServer(null, 0) + gatewayServer.start() + val boundPort: Int = gatewayServer.getListeningPort + if (boundPort == -1) { + logError("GatewayServer failed to bind; exiting") + System.exit(1) + } else { + logDebug(s"Started PythonGatewayServer on port $boundPort") + } + + // Communicate the bound port back to the caller via the caller-specified callback port + val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") + val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt + logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") + val callbackSocket = new Socket(callbackHost, callbackPort) + val dos = new DataOutputStream(callbackSocket.getOutputStream) + dos.writeInt(boundPort) + dos.close() + callbackSocket.close() + + // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: + while (System.in.read() != -1) { + // Do nothing + } + logDebug("Exiting due to broken pipe from Python driver") + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index c9181a29d475..a7dfa1d257cf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,15 +17,17 @@ package org.apache.spark.api.python -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException} +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * :: Experimental :: @@ -61,14 +63,13 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { + conf: Broadcast[SerializableConfiguration]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or * object representation */ private def convertWritable(writable: Writable): Any = { - import collection.JavaConversions._ writable match { case iw: IntWritable => iw.get() case dw: DoubleWritable => dw.get() @@ -89,9 +90,7 @@ private[python] class WritableToJavaConverter( aw.get().map(convertWritable(_)) case mw: MapWritable => val map = new java.util.HashMap[Any, Any]() - mw.foreach { case (k, v) => - map.put(convertWritable(k), convertWritable(v)) - } + mw.asScala.foreach { case (k, v) => map.put(convertWritable(k), convertWritable(v)) } map case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other @@ -122,7 +121,6 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { * supported out-of-the-box. */ private def convertToWritable(obj: Any): Writable = { - import collection.JavaConversions._ obj match { case i: java.lang.Integer => new IntWritable(i) case d: java.lang.Double => new DoubleWritable(d) @@ -134,7 +132,7 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { case null => NullWritable.get() case map: java.util.Map[_, _] => val mapWritable = new MapWritable() - map.foreach { case (k, v) => + map.asScala.foreach { case (k, v) => mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable @@ -161,9 +159,8 @@ private[python] object PythonHadoopUtil { * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]] */ def mapToConf(map: java.util.Map[String, String]): Configuration = { - import collection.JavaConversions._ val conf = new Configuration() - map.foreach{ case (k, v) => conf.set(k, v) } + map.asScala.foreach { case (k, v) => conf.set(k, v) } conf } @@ -172,9 +169,8 @@ private[python] object PythonHadoopUtil { * any matching keys in left */ def mergeConfs(left: Configuration, right: Configuration): Configuration = { - import collection.JavaConversions._ val copy = new Configuration(left) - right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue)) + right.asScala.foreach(entry => copy.set(entry.getKey, entry.getValue)) copy } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b89effc16d36..b4d152b33660 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,25 +19,26 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} +import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} -import org.apache.spark.input.PortableDataStream - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} + import org.apache.spark._ -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} + +import scala.util.control.NonFatal private[spark] class PythonRDD( @transient parent: RDD[_], @@ -46,6 +47,7 @@ private[spark] class PythonRDD( pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { @@ -53,20 +55,22 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = firstParent.partitions + override def getPartitions: Array[Partition] = firstParent.partitions - override val partitioner = if (preservePartitoning) firstParent.partitioner else None + override val partitioner: Option[Partitioner] = { + if (preservePartitoning) firstParent.partitioner else None + } override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { - envVars += ("SPARK_REUSE_WORKER" -> "1") + envVars.put("SPARK_REUSE_WORKER", "1") } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool @volatile var released = false @@ -75,7 +79,6 @@ private[spark] class PythonRDD( context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - writerThread.join() if (!reuse_worker || !released) { try { worker.close() @@ -92,7 +95,7 @@ private[spark] class PythonRDD( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { - def next(): Array[Byte] = { + override def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { _nextObj = read() @@ -147,7 +150,7 @@ private[spark] class PythonRDD( // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) released = true } } @@ -175,7 +178,7 @@ private[spark] class PythonRDD( var _nextObj = read() - def hasNext = _nextObj != null + override def hasNext: Boolean = _nextObj != null } new InterruptibleIterator(context, stdoutIterator) } @@ -204,31 +207,33 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(split.index) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { + dataOut.writeInt(pythonIncludes.size()) + for (include <- pythonIncludes.asScala) { PythonRDD.writeUTF(include, dataOut) } // Broadcast variables val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet + val newBids = broadcastVars.asScala.map(_.id).toSet // number of different broadcasts - val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size dataOut.writeInt(cnt) - for (bid <- oldBids) { - if (!newBids.contains(bid)) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) } - for (broadcast <- broadcastVars) { + for (broadcast <- broadcastVars.asScala) { if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) @@ -248,18 +253,17 @@ private[spark] class PythonRDD( } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) - worker.shutdownOutput() + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - worker.shutdownOutput() - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } } } } @@ -283,7 +287,7 @@ private[spark] class PythonRDD( if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap, worker) + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -300,10 +304,10 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Long, Array[Byte])](prev) { - override def getPartitions = prev.partitions - override def compute(split: Partition, context: TaskContext) = +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { + override def getPartitions: Array[Partition] = prev.partitions + override val partitioner: Option[Partitioner] = prev.partitioner + override def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Byte])] = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) @@ -329,24 +333,44 @@ private[spark] object PythonRDD extends Logging { } } + /** + * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservePartitions=true + * + * This is useful for PySpark to have the partitioner after partitionBy() + */ + def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = { + pair.rdd.mapPartitions(it => it.map(_._2), true) + } + /** * Adapter for calling SparkContext#runJob from Python. * - * This method will return an iterator of an array that contains all elements in the RDD + * This method will serve an iterator of an array that contains all elements in the RDD * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. + * + * @return the port number of a local socket which serves the data collected from this job. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int], - allowLocal: Boolean): Iterator[Array[Byte]] = { + partitions: JArrayList[Int]): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) - flattenedPartition.iterator + serveIterator(flattenedPartition.iterator, + s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") + } + + /** + * A helper function to collect an RDD as an iterator, then serve it via socket. + * + * @return the port number of a local socket which serves the data collected from this job. + */ + def collectAndServe[T](rdd: RDD[T]): Int = { + serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): @@ -410,13 +434,13 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, minSplits: Int, - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -437,12 +461,12 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -463,12 +487,12 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -480,7 +504,7 @@ private[spark] object PythonRDD extends Logging { inputFormatClass: String, keyClass: String, valueClass: String, - conf: Configuration) = { + conf: Configuration): RDD[(K, V)] = { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] @@ -506,12 +530,12 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -532,12 +556,12 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -566,15 +590,43 @@ private[spark] object PythonRDD extends Logging { dataOut.write(bytes) } - def writeToFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeToFile(items.asScala, filename) - } + /** + * Create a socket server and a background thread to serve the data in `items`, + * + * The socket server can only accept one connection, or close if no connection + * in 3 seconds. + * + * Once a connection comes in, it tries to serialize all the data in `items` + * and send them into this connection. + * + * The thread will terminate after all the data are sent or any exceptions happen. + */ + private def serveIterator[T](items: Iterator[T], threadName: String): Int = { + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + // Close the socket if no connection in 3 seconds + serverSocket.setSoTimeout(3000) + + new Thread(threadName) { + setDaemon(true) + override def run() { + try { + val sock = serverSocket.accept() + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + Utils.tryWithSafeFinally { + writeIteratorToStream(items, out) + } { + out.close() + } + } catch { + case NonFatal(e) => + logError(s"Error while sending iterator", e) + } finally { + serverSocket.close() + } + } + }.start() - def writeToFile[T](items: Iterator[T], filename: String) { - val file = new DataOutputStream(new FileOutputStream(filename)) - writeIteratorToStream(items, file) - file.close() + serverSocket.getLocalPort } private def getMergedConf(confAsMap: java.util.HashMap[String, String], @@ -632,7 +684,7 @@ private[spark] object PythonRDD extends Logging { pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { saveAsHadoopFile( pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", null, null, null, null, new java.util.HashMap(), compressionCodecClass) @@ -657,7 +709,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -666,7 +718,7 @@ private[spark] object PythonRDD extends Logging { val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] - converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } /** @@ -687,7 +739,7 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String]): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -712,7 +764,7 @@ private[spark] object PythonRDD extends Logging { confAsMap: java.util.HashMap[String, String], keyConverterClass: String, valueConverterClass: String, - useNewAPI: Boolean) = { + useNewAPI: Boolean): Unit = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), keyConverterClass, valueConverterClass, new JavaToWritableConverter) @@ -740,10 +792,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added - * by the DAGScheduler's single-threaded actor anyway. - */ + * by the DAGScheduler's single-threaded RpcEndpoint anyway. + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { @@ -767,7 +819,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) - for (array <- val2) { + for (array <- val2.asScala) { out.writeInt(array.length) out.write(array) } @@ -786,6 +838,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. */ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -808,9 +861,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial val file = File.createTempFile("broadcast", "", dir) path = file.getAbsolutePath val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { Utils.copyStream(in, out) - } finally { + } { out.close() } } @@ -827,3 +880,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index b7cfc8bd9c54..31e534f160ee 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.api.python -import java.io.{File, InputStream, IOException, OutputStream} +import java.io.File +import java.util.{List => JList} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -29,7 +31,7 @@ private[spark] object PythonUtils { def sparkPythonPath: String = { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { - pythonPath += Seq(sparkHome, "python").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) @@ -44,4 +46,32 @@ private[spark] object PythonUtils { def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { sc.parallelize(List("a", null, "b")) } + + /** + * Convert list of T into seq of T (for calling API with varargs) + */ + def toSeq[T](vs: JList[T]): Seq[T] = { + vs.asScala + } + + /** + * Convert list of T into a (Scala) List of T + */ + def toList[T](vs: JList[T]): List[T] = { + vs.asScala.toList + } + + /** + * Convert list of T into array of T (for calling API with array) + */ + def toArray[T](vs: JList[T]): Array[T] = { + vs.toArray().asInstanceOf[Array[T]] + } + + /** + * Convert java map of K, V into Map of K, V (for calling API with varargs) + */ + def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { + jm.asScala.toMap + } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e314408c067e..7039b734d2e4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.python import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.util.Arrays import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.util.{RedirectThread, Utils} @@ -108,9 +109,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") @@ -151,9 +152,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index fb52a960e076..fd27276e70bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList} import org.apache.spark.api.java.JavaRDD -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure @@ -56,16 +55,13 @@ private[spark] object SerDeUtil extends Logging { // {'\0', 0, 0, 0} /* Sentinel */ // }; // TODO: support Py_UNICODE with 2 bytes - // FIXME: unpickle array of float is wrong in Pyrolite, so we reverse the - // machine code for float/double here to workaround it. - // we should fix this after Pyrolite fix them val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9, - 'L' -> 11, 'l' -> 13, 'f' -> 14, 'd' -> 16, 'u' -> 21 + 'L' -> 11, 'l' -> 13, 'f' -> 15, 'd' -> 17, 'u' -> 21 ) } else { Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8, - 'L' -> 10, 'l' -> 12, 'f' -> 15, 'd' -> 17, 'u' -> 20 + 'L' -> 10, 'l' -> 12, 'f' -> 14, 'd' -> 16, 'u' -> 20 ) } override def construct(args: Array[Object]): Object = { @@ -84,7 +80,7 @@ private[spark] object SerDeUtil extends Logging { private var initialized = false // This should be called before trying to unpickle array.array from Python // In cluster mode, this should be put in closure - def initialize() = { + def initialize(): Unit = { synchronized{ if (!initialized) { Unpickler.registerConstructor("array", "array", new ArrayConstructor()) @@ -217,7 +213,7 @@ private[spark] object SerDeUtil extends Logging { new AutoBatchedPickler(cleaned) } else { val pickle = new Pickler - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index c0cbd28a845b..ee1fb056f0d9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,38 +18,39 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} +import java.{util => ju} + +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + +import org.apache.spark.SparkException import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.{SparkContext, SparkException} /** * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python - * @param str - * @param int - * @param double */ case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { def this() = this("", 0, 0.0) - def getStr = str + def getStr: String = str def setStr(str: String) { this.str = str } - def getInt = int + def getInt: Int = int def setInt(int: Int) { this.int = int } - def getDouble = double + def getDouble: Double = double def setDouble(double: Double) { this.double = double } - def write(out: DataOutput) = { + def write(out: DataOutput): Unit = { out.writeUTF(str) out.writeInt(int) out.writeDouble(double) } - def readFields(in: DataInput) = { + def readFields(in: DataInput): Unit = { str = in.readUTF() int = in.readInt() double = in.readDouble() @@ -57,36 +58,34 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } private[python] class TestInputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Char = { obj.asInstanceOf[IntWritable].get().toChar } } private[python] class TestInputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] - seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + m.keySet.asScala.map(_.asInstanceOf[DoubleWritable].get()).toSeq.asJava } } private[python] class TestOutputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Text = { new Text(obj.asInstanceOf[Int].toString) } } private[python] class TestOutputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ - override def convert(obj: Any) = { - new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + override def convert(obj: Any): DoubleWritable = { + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().iterator().next()) } } private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { - override def convert(obj: Any) = obj match { + override def convert(obj: Any): DoubleArrayWritable = obj match { case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => val daw = new DoubleArrayWritable daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) @@ -107,7 +106,6 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra * given directory (probably a temp directory) */ object WriteInputFormatTestDataGenerator { - import SparkContext._ def main(args: Array[String]) { val path = args(0) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 000000000000..b7e72d4d0ed0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.{Logging, SparkConf} + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + val conf = new SparkConf() + bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + // scalastyle:off println + System.err.println("Usage: RBackend ") + // scalastyle:on println + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 000000000000..bb82f3285f1d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap +import scala.language.existentials + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ +import org.apache.spark.util.Utils + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + // This function is for test-purpose only + case "echo" => + val args = readArgs(numArgs, dis) + assert(numArgs == 1) + + writeInt(dos, 0) + writeObject(dos, args(0)) + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") + } + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Utils.classForName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args : _*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args : _*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed") + writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. + writeString(dos, Utils.exceptionString(e.getCause)) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. + private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 000000000000..9d5bbb5d609f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.{InetAddress, ServerSocket} +import java.util.Arrays +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + dataStream = new DataInputStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val taskContext = TaskContext.get() + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def readData(length: Int): U + + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } + } +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. + */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. + */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private object SpecialLengths { + val TIMING_DATA = -1 +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap.asScala) { + sparkConf.set(name.toString, value.toString) + } + for ((name, value) <- sparkExecutorEnvMap.asScala) { + sparkConf.setExecutorEnv(name.toString, value.toString) + } + + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") + val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala new file mode 100644 index 000000000000..9e807cc52f18 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.File +import java.util.Arrays + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object RUtils { + /** + * Get the SparkR package path in the local spark distribution. + */ + def localSparkRPackagePath: Option[String] = { + val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.test.home")) + sparkHome.map( + Seq(_, "R", "lib").mkString(File.separator) + ) + } + + /** + * Get the SparkR package path in various deployment modes. + * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` + * and environment variable `SPARK_HOME` are set. + */ + def sparkRPackagePath(isDriver: Boolean): String = { + val (master, deployMode) = + if (isDriver) { + (sys.props("spark.master"), sys.props("spark.submit.deployMode")) + } else { + val sparkConf = SparkEnv.get.conf + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + } + + val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master != null && master.contains("yarn") && deployMode == "client" + + // In YARN mode, the SparkR package is distributed as an archive symbolically + // linked to the "sparkr" file in the current directory. Note that this does not apply + // to the driver in client mode because it is run outside of the cluster. + if (isYarnCluster || (isYarnClient && !isDriver)) { + new File("sparkr").getAbsolutePath + } else { + // Otherwise, assume the package is local + // TODO: support this for Mesos + localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + } + } + + /** Check if R is installed before running tests that use R commands. */ + def isRInstalled: Boolean = { + try { + val builder = new ProcessBuilder(Arrays.asList("R", "--version")) + builder.start().waitFor() == 0 + } catch { + case e: Exception => false + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 000000000000..26ad4f1d4697 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Timestamp, Date, Time} + +import scala.collection.JavaConverters._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + assert(bytes(len - 1) == 0) + val str = new String(bytes.dropRight(1), "UTF-8") + str + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + readStringBytes(in, len) + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case 'l' => { + val len = readInt(dis) + (0 until len).map(_ => readList(dis)).toArray + } + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) + keys.zip(values).toMap.asJava + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Float -> double + // Double -> double + // Decimal -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + // Array of primitive types + case "array" => dos.writeByte('a') + // Array of objects + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.Character" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[Character].toString) + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) + case "java.math.BigDecimal" => + writeType(dos, "double") + val javaDecimal = value.asInstanceOf[java.math.BigDecimal] + writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) + case "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "java.lang.Byte" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Byte].toInt) + case "java.lang.Short" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Short].toInt) + case "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) + + // Handle arrays + + // Array of primitive types + + // Special handling for byte array + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + + case "[C" => + writeType(dos, "array") + writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) + case "[S" => + writeType(dos, "array") + writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) + case "[I" => + writeType(dos, "array") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "array") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[F" => + writeType(dos, "array") + writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble)) + case "[D" => + writeType(dos, "array") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "array") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + + // Array of objects, null objects use "void" type + case c if c.startsWith("[") => + writeType(dos, "list") + val array = value.asInstanceOf[Array[Object]] + writeInt(dos, array.length) + array.foreach(elem => writeObject(dos, elem)) + + case _ => + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index a5ea478f231d..12d79f6ed311 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -146,5 +146,5 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo } } - override def toString = "Broadcast(" + id + ")" + override def toString: String = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 8f8a0b11f9f2..fac6666bb341 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -42,7 +43,7 @@ private[spark] class BroadcastManager( conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isDriver, conf, securityManager) @@ -58,7 +59,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 1444c0dd3d2d..b69af639f786 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging { securityManager = securityMgr if (isDriver) { createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) + conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) @@ -160,12 +160,12 @@ private[broadcast] object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) private def write(id: Long, value: Any) { val file = getFile(id) val fileOutputStream = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(fileOutputStream) @@ -175,16 +175,19 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() + Utils.tryWithSafeFinally { + serOut.writeObject(value) + } { + serOut.close() + } files += file - } finally { + } { fileOutputStream.close() } } private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null @@ -212,9 +215,11 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj + Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } } /** @@ -222,7 +227,7 @@ private[broadcast] object HttpBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver * and delete the associated broadcast file. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) if (removeFromDriver) { val file = getFile(id) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c7ef02d572a1..cf3ae36f2794 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = new HttpBroadcast[T](value_, isLocal, id) override def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 94142d33369c..7e3764d802fe 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer -import scala.collection.JavaConversions.asJavaEnumeration +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) @@ -209,7 +210,7 @@ private object TorrentBroadcast extends Logging { compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( - asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) @@ -222,7 +223,7 @@ private object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index fb024c12094f..96d8dd79908c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -30,7 +30,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { new TorrentBroadcast[T](value_, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index ae55b4ff40b7..ae99432f5ce8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,13 +17,18 @@ package org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], - val memoryPerSlave: Int, + val memoryPerExecutorMB: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None) + val eventLogDir: Option[URI] = None, + // short name of compression codec used when writing event logs, if any (e.g. lzf) + val eventLogCodec: Option[String] = None, + val coresPerExecutor: Option[Int] = None) extends Serializable { val user = System.getProperty("user.name", "") @@ -31,11 +36,13 @@ private[spark] class ApplicationDescription( def copy( name: String = name, maxCores: Option[Int] = maxCores, - memoryPerSlave: Int = memoryPerSlave, + memoryPerExecutorMB: Int = memoryPerExecutorMB, command: Command = command, appUiUrl: String = appUiUrl, - eventLogDir: Option[String] = eventLogDir): ApplicationDescription = - new ApplicationDescription(name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir) + eventLogDir: Option[URI] = eventLogDir, + eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = + new ApplicationDescription( + name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 38b3da0b1375..f03875a3e8c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,35 +17,49 @@ package org.apache.spark.deploy -import scala.concurrent._ +import scala.collection.mutable.HashSet +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. + * + * We currently don't support retry if submission fails. In HA mode, client will submit request to + * all masters and see which one could handle it. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - var masterActor: ActorSelection = _ - val timeout = AkkaUtils.askTimeout(conf) - - override def preStart() = { - masterActor = context.actorSelection( - Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) - - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - - println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -68,8 +82,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) .map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) val javaOpts = sparkJavaOpts ++ extraJavaOpts - val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) + val command = new Command(mainClass, + Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions, + sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, @@ -77,39 +92,52 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - masterActor ! RequestSubmitDriver(driverDescription) + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - masterActor ! RequestKillDriver(driverId) + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println(s"... waiting before polling master for driver state") + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") - val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) - + logInfo("... polling master for driver state") + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -117,24 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { + override def receive: PartialFunction[Any, Unit] = { + + case SubmitDriverResponse(master, success, driverId, message) => + logInfo(message) + if (success) { + activeMasterEndpoint = master + pollAndReportStatus(driverId.get) + } else if (!Utils.responseFromBackup(message)) { + System.exit(-1) + } + + + case KillDriverResponse(master, driverId, success, message) => + logInfo(message) + if (success) { + activeMasterEndpoint = master + pollAndReportStatus(driverId) + } else if (!Utils.responseFromBackup(message)) { + System.exit(-1) + } + } - case SubmitDriverResponse(success, driverId, message) => - println(message) - if (success) pollAndReportStatus(driverId.get) else System.exit(-1) + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) + } + } + } - case KillDriverResponse(driverId, success, message) => - println(message) - if (success) pollAndReportStatus(driverId) else System.exit(-1) + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) + } + } + } - case DisassociatedEvent(_, remoteAddress, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - System.exit(-1) + override def onError(cause: Throwable): Unit = { + logError(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - println(s"Cause was: $cause") - System.exit(-1) + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -143,10 +209,12 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) @@ -154,17 +222,17 @@ object Client { if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { conf.set("spark.akka.logLifecycleEvents", "true") } - conf.set("spark.akka.askTimeout", "10") + conf.set("spark.rpc.askTimeout", "10") conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem)) - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). + map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index e5873ce724b9..72cc330a398d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -22,35 +22,33 @@ import java.net.{URI, URISyntaxException} import scala.collection.mutable.ListBuffer import org.apache.log4j.Level - -import org.apache.spark.util.{IntParam, MemoryParam} +import org.apache.spark.util.{IntParam, MemoryParam, Utils} /** * Command-line parser for the driver client. */ -private[spark] class ClientArguments(args: Array[String]) { - val defaultCores = 1 - val defaultMemory = 512 +private[deploy] class ClientArguments(args: Array[String]) { + import ClientArguments._ var cmd: String = "" // 'launch' or 'kill' var logLevel = Level.WARN // launch parameters - var master: String = "" + var masters: Array[String] = null var jarUrl: String = "" var mainClass: String = "" - var supervise: Boolean = false - var memory: Int = defaultMemory - var cores: Int = defaultCores + var supervise: Boolean = DEFAULT_SUPERVISE + var memory: Int = DEFAULT_MEMORY + var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() - def driverOptions = _driverOptions.toSeq + def driverOptions: Seq[String] = _driverOptions.toSeq // kill parameters var driverId: String = "" parse(args.toList) - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--cores" | "-c") :: IntParam(value) :: tail => cores = value parse(tail) @@ -74,20 +72,22 @@ private[spark] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } jarUrl = _jarUrl - master = _master + masters = Utils.parseStandaloneMasterUrls(_master) mainClass = _mainClass _driverOptions ++= tail case "kill" :: _master :: _driverId :: tail => cmd = "kill" - master = _master + masters = Utils.parseStandaloneMasterUrls(_master) driverId = _driverId case _ => @@ -97,7 +97,7 @@ private[spark] class ClientArguments(args: Array[String]) { /** * Print usage and exit JVM with the given exit code. */ - def printUsageAndExit(exitCode: Int) { + private def printUsageAndExit(exitCode: Int) { // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars // separately similar to in the YARN client. val usage = @@ -106,17 +106,24 @@ private[spark] class ClientArguments(args: Array[String]) { |Usage: DriverClient kill | |Options: - | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) - | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY) | -s, --supervise Whether to restart the driver on failure + | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } -object ClientArguments { +private[deploy] object ClientArguments { + val DEFAULT_CORES = 1 + val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB + val DEFAULT_SUPERVISE = false + def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 243d8edb72ed..d8084a57658a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. */ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,20 +94,26 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage + case class UnregisterApplication(appId: String) + case class MasterChangeAcknowledged(appId: String) + case class RequestExecutors(appId: String, requestedTotal: Int) + + case class KillExecutors(appId: String, executorIds: Seq[String]) + // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -121,12 +129,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -140,7 +150,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master @@ -148,15 +158,22 @@ private[deploy] object DeployMessages { // Master to MasterWebUI - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], - activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], - status: MasterState) { + case class MasterStateResponse( + host: String, + port: Int, + restPort: Option[Int], + workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], + completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], + completedDrivers: Array[DriverInfo], + status: MasterState) { Utils.checkHost(host, "Required hostname") assert (port > 0) - def uri = "spark://" + host + ":" + port + def uri: String = "spark://" + host + ":" + port + def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } // WorkerWebUI to Worker diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala index b056a19ce659..659fb434a80f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -private[spark] class DriverDescription( +private[deploy] class DriverDescription( val jarUrl: String, val mem: Int, val cores: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala index 2abf0b69dddb..ec23371b52f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala @@ -22,7 +22,7 @@ package org.apache.spark.deploy * This state is sufficient for the Master to reconstruct its internal data structures during * failover. */ -private[spark] class ExecutorDescription( +private[deploy] class ExecutorDescription( val appId: String, val execId: Int, val cores: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 9f34d01e6db4..efa88c62e1f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -private[spark] object ExecutorState extends Enumeration { +private[deploy] object ExecutorState extends Enumeration { val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala new file mode 100644 index 000000000000..6840a3ae831f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConverters._ + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslServerBootstrap +import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.util.TransportConf +import org.apache.spark.util.Utils + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val blockHandler = newShuffleBlockHandler(transportConf) + private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) + + private var server: TransportServer = _ + + /** Create a new shuffle block handler. Factored out for subclasses to override. */ + protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { + new ExternalShuffleBlockHandler(conf, null) + } + + /** Starts the external shuffle service if the user has configured us to. */ + def startIfEnabled() { + if (enabled) { + start() + } + } + + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + val bootstraps: Seq[TransportServerBootstrap] = + if (useSasl) { + Seq(new SaslServerBootstrap(transportConf, securityManager)) + } else { + Nil + } + server = transportContext.createServer(port, bootstraps.asJava) + } + + /** Clean up all shuffle files associated with an application that has exited. */ + def applicationRemoved(appId: String): Unit = { + blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) + } + + def stop() { + if (server != null) { + server.close() + server = null + } + } +} + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm)) + } + + /** A helper main method that allows the caller to call this with a custom shuffle service. */ + private[spark] def main( + args: Array[String], + newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = newShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 47dbcd87c35b..b4edb6109e83 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -32,7 +32,8 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} -import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.deploy.master.RecoveryState +import org.apache.spark.util.Utils /** * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master. @@ -55,29 +56,29 @@ import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} * - The docker images tagged spark-test-master and spark-test-worker are built from the * docker/ directory. Run 'docker/spark-test/build' to generate these. */ -private[spark] object FaultToleranceTest extends App with Logging { +private object FaultToleranceTest extends App with Logging { - val conf = new SparkConf() - val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + private val conf = new SparkConf() + private val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") - val masters = ListBuffer[TestMasterInfo]() - val workers = ListBuffer[TestWorkerInfo]() - var sc: SparkContext = _ + private val masters = ListBuffer[TestMasterInfo]() + private val workers = ListBuffer[TestWorkerInfo]() + private var sc: SparkContext = _ - val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) - var numPassed = 0 - var numFailed = 0 + private var numPassed = 0 + private var numFailed = 0 - val sparkHome = System.getenv("SPARK_HOME") + private val sparkHome = System.getenv("SPARK_HOME") assertTrue(sparkHome != null, "Run with a valid SPARK_HOME") - val containerSparkHome = "/opt/spark" - val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome) + private val containerSparkHome = "/opt/spark" + private val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome) System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip - def afterEach() { + private def afterEach() { if (sc != null) { sc.stop() sc = null @@ -179,7 +180,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } } - def test(name: String)(fn: => Unit) { + private def test(name: String)(fn: => Unit) { try { fn numPassed += 1 @@ -197,19 +198,19 @@ private[spark] object FaultToleranceTest extends App with Logging { afterEach() } - def addMasters(num: Int) { + private def addMasters(num: Int) { logInfo(s">>>>> ADD MASTERS $num <<<<<") (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) } } - def addWorkers(num: Int) { + private def addWorkers(num: Int) { logInfo(s">>>>> ADD WORKERS $num <<<<<") val masterUrls = getMasterUrls(masters) (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) } } /** Creates a SparkContext, which constructs a Client to interact with our cluster. */ - def createClient() = { + private def createClient() = { logInfo(">>>>> CREATE CLIENT <<<<<") if (sc != null) { sc.stop() } // Counter-hack: Because of a hack in SparkEnv#create() that changes this @@ -218,17 +219,17 @@ private[spark] object FaultToleranceTest extends App with Logging { sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome) } - def getMasterUrls(masters: Seq[TestMasterInfo]): String = { + private def getMasterUrls(masters: Seq[TestMasterInfo]): String = { "spark://" + masters.map(master => master.ip + ":7077").mkString(",") } - def getLeader: TestMasterInfo = { + private def getLeader: TestMasterInfo = { val leaders = masters.filter(_.state == RecoveryState.ALIVE) assertTrue(leaders.size == 1) leaders(0) } - def killLeader(): Unit = { + private def killLeader(): Unit = { logInfo(">>>>> KILL LEADER <<<<<") masters.foreach(_.readState()) val leader = getLeader @@ -236,9 +237,9 @@ private[spark] object FaultToleranceTest extends App with Logging { leader.kill() } - def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) + private def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) - def terminateCluster() { + private def terminateCluster() { logInfo(">>>>> TERMINATE CLUSTER <<<<<") masters.foreach(_.kill()) workers.foreach(_.kill()) @@ -247,7 +248,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } /** This includes Client retry logic, so it may take a while if the cluster is recovering. */ - def assertUsable() = { + private def assertUsable() = { val f = future { try { val res = sc.parallelize(0 until 10).collect() @@ -269,7 +270,7 @@ private[spark] object FaultToleranceTest extends App with Logging { * Asserts that the cluster is usable and that the expected masters and workers * are all alive in a proper configuration (e.g., only one leader). */ - def assertValidClusterState() = { + private def assertValidClusterState() = { logInfo(">>>>> ASSERT VALID CLUSTER STATE <<<<<") assertUsable() var numAlive = 0 @@ -325,7 +326,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } } - def assertTrue(bool: Boolean, message: String = "") { + private def assertTrue(bool: Boolean, message: String = "") { if (!bool) { throw new IllegalStateException("Assertion failed: " + message) } @@ -335,7 +336,7 @@ private[spark] object FaultToleranceTest extends App with Logging { numFailed)) } -private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File) +private class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File) extends Logging { implicit val formats = org.json4s.DefaultFormats @@ -377,7 +378,7 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val format(ip, dockerId.id, logFile.getAbsolutePath, state) } -private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File) +private class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File) extends Logging { implicit val formats = org.json4s.DefaultFormats @@ -390,7 +391,7 @@ private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val "[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath) } -private[spark] object SparkDocker { +private object SparkDocker { def startMaster(mountDir: String): TestMasterInfo = { val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir) val (ip, id, outFile) = startNode(cmd) @@ -405,8 +406,7 @@ private[spark] object SparkDocker { private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = { val ipPromise = promise[String]() - val outFile = File.createTempFile("fault-tolerance-test", "") - outFile.deleteOnExit() + val outFile = File.createTempFile("fault-tolerance-test", "", Utils.createTempDir()) val outStream: FileWriter = new FileWriter(outFile) def findIpAndLog(line: String): Unit = { if (line.startsWith("CONTAINER_IP=")) { @@ -425,11 +425,11 @@ private[spark] object SparkDocker { } } -private[spark] class DockerId(val id: String) { - override def toString = id +private class DockerId(val id: String) { + override def toString: String = id } -private[spark] object Docker extends Logging { +private object Docker extends Logging { def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = { val mountCmd = if (mountDir != "") { " -v " + mountDir } else "" diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 696f32a6f573..ccffb3665298 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -17,14 +17,15 @@ package org.apache.spark.deploy +import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.ExecutorRunner -private[spark] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { +private[deploy] object JsonProtocol { + def writeWorkerInfo(obj: WorkerInfo): JObject = { ("id" -> obj.id) ~ ("host" -> obj.host) ~ ("port" -> obj.port) ~ @@ -39,34 +40,34 @@ private[spark] object JsonProtocol { ("lastheartbeat" -> obj.lastHeartbeat) } - def writeApplicationInfo(obj: ApplicationInfo) = { + def writeApplicationInfo(obj: ApplicationInfo): JObject = { ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ + ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ ("duration" -> obj.duration) } - def writeApplicationDescription(obj: ApplicationDescription) = { + def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ + ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } - def writeExecutorRunner(obj: ExecutorRunner) = { + def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ ("appid" -> obj.appId) ~ ("appdesc" -> writeApplicationDescription(obj.appDesc)) } - def writeDriverInfo(obj: DriverInfo) = { + def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ @@ -74,20 +75,21 @@ private[spark] object JsonProtocol { ("memory" -> obj.desc.mem) } - def writeMasterState(obj: MasterStateResponse) = { + def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ ("status" -> obj.status.toString) } - def writeWorkerState(obj: WorkerStateResponse) = { + def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ ("masterwebuiurl" -> obj.masterWebUiUrl) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c9571..83ccaadfe744 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -33,28 +32,39 @@ import org.apache.spark.util.Utils * fault recovery without spinning up a lot of processes. */ private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) +class LocalSparkCluster( + numWorkers: Int, + coresPerWorker: Int, + memoryPerWorker: Int, + conf: SparkConf) extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() + // exposed for testing + var masterWebUIPort = -1 def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + // Disable REST server on Master in this mode unless otherwise specified + val _conf = conf.clone() + .setIfMissing("spark.master.rest.enabled", "false") + .set("spark.shuffle.service.enabled", "false") + /* Start the Master */ - val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) - masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) + masterWebUIPort = webUiPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, - memoryPerWorker, masters, null, Some(workerNum)) - workerActorSystems += workerSystem + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, + memoryPerWorker, masters, null, Some(workerNum), _conf) + workerRpcEnvs += workerEnv } masters @@ -63,13 +73,9 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def stop() { logInfo("Shutting down local Spark cluster.") // Stop the workers before the master so they don't get upset that it disconnected - // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! - // This is unfortunate, but for now we just comment it out. - workerActorSystems.foreach(_.shutdown()) - // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) - // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 53e18c4bcec2..d85327603f64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,10 +18,13 @@ package org.apache.spark.deploy import java.net.URI +import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.util.Try +import org.apache.spark.SparkUserAppException import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -44,7 +47,20 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such val gatewayServer = new py4j.GatewayServer(null, 0) - gatewayServer.start() + val thread = new Thread(new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions { + gatewayServer.start() + } + }) + thread.setName("py4j-gateway-init") + thread.setDaemon(true) + thread.start() + + // Wait until the gateway server has started, so that we know which port is it bound to. + // `gatewayServer.start()` will start a new thread and run the server code there, after + // initializing the socket, so the thread started above will end as soon as the server is + // ready to serve connections. + thread.join() // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument @@ -55,18 +71,25 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize - val process = builder.start() + try { + val process = builder.start() - new RedirectThread(process.getInputStream, System.out, "redirect output").start() + new RedirectThread(process.getInputStream, System.out, "redirect output").start() - System.exit(process.waitFor()) + val exitCode = process.waitFor() + if (exitCode != 0) { + throw new SparkUserAppException(exitCode) + } + } finally { + gatewayServer.shutdown() + } } /** @@ -81,16 +104,13 @@ object PythonRunner { throw new IllegalArgumentException("Launching Python applications through " + s"spark-submit is currently only supported for local files: $path") } - val windows = Utils.isWindows || testWindows - var formattedPath = if (windows) Utils.formatWindowsPath(path) else path - - // Strip the URI scheme from the path - formattedPath = - new URI(formattedPath).getScheme match { - case null => formattedPath - case Utils.windowsDrive(d) if windows => formattedPath - case _ => new URI(formattedPath).getPath - } + // get path when scheme is file. + val uri = Try(new URI(path)).getOrElse(new File(path).toURI) + var formattedPath = uri.getScheme match { + case null => path + case "file" | "local" => uri.getPath + case _ => null + } // Guard against malformed paths potentially throwing NPE if (formattedPath == null) { @@ -99,7 +119,9 @@ object PythonRunner { // In Windows, the drive should not be prefixed with "/" // For instance, python does not understand "/C:/path/to/sheep.py" - formattedPath = if (windows) formattedPath.stripPrefix("/") else formattedPath + if (Utils.isWindows && formattedPath.matches("/[a-zA-Z]:/.*")) { + formattedPath = formattedPath.stripPrefix("/") + } formattedPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala new file mode 100644 index 000000000000..4b28866dcaa7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.jar.JarFile +import java.util.logging.Level +import java.util.zip.{ZipEntry, ZipOutputStream} + +import scala.collection.JavaConverters._ + +import com.google.common.io.{ByteStreams, Files} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.api.r.RUtils +import org.apache.spark.util.{RedirectThread, Utils} + +private[deploy] object RPackageUtils extends Logging { + + /** The key in the MANIFEST.mf that we look for, in case a jar contains R code. */ + private final val hasRPackage = "Spark-HasRPackage" + + /** Base of the shell command used in order to install R packages. */ + private final val baseInstallCmd = Seq("R", "CMD", "INSTALL", "-l") + + /** R source code should exist under R/pkg in a jar. */ + private final val RJarEntries = "R/pkg" + + /** Documentation on how the R source file layout should be in the jar. */ + private[deploy] final val RJarDoc = + s"""In order for Spark to build R packages that are parts of Spark Packages, there are a few + |requirements. The R source code must be shipped in a jar, with additional Java/Scala + |classes. The jar must be in the following format: + | 1- The Manifest (META-INF/MANIFEST.mf) must contain the key-value: $hasRPackage: true + | 2- The standard R package layout must be preserved under R/pkg/ inside the jar. More + | information on the standard R package layout can be found in: + | http://cran.r-project.org/doc/contrib/Leisch-CreatingPackages.pdf + | An example layout is given below. After running `jar tf $$JAR_FILE | sort`: + | + |META-INF/MANIFEST.MF + |R/ + |R/pkg/ + |R/pkg/DESCRIPTION + |R/pkg/NAMESPACE + |R/pkg/R/ + |R/pkg/R/myRcode.R + |org/ + |org/apache/ + |... + """.stripMargin.trim + + /** Internal method for logging. We log to a printStream in tests, for debugging purposes. */ + private def print( + msg: String, + printStream: PrintStream, + level: Level = Level.FINE, + e: Throwable = null): Unit = { + if (printStream != null) { + // scalastyle:off println + printStream.println(msg) + // scalastyle:on println + if (e != null) { + e.printStackTrace(printStream) + } + } else { + level match { + case Level.INFO => logInfo(msg) + case Level.WARNING => logWarning(msg) + case Level.SEVERE => logError(msg, e) + case _ => logDebug(msg) + } + } + } + + /** + * Checks the manifest of the Jar whether there is any R source code bundled with it. + * Exposed for testing. + */ + private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + val manifest = jar.getManifest.getMainAttributes + manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" + } + + /** + * Runs the standard R package installation code to build the R package from source. + * Multiple runs don't cause problems. + */ + private def rPackageBuilder(dir: File, printStream: PrintStream, verbose: Boolean): Boolean = { + // this code should be always running on the driver. + val pathToSparkR = RUtils.localSparkRPackagePath.getOrElse( + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")) + val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator) + val installCmd = baseInstallCmd ++ Seq(pathToSparkR, pathToPkg) + if (verbose) { + print(s"Building R package with the command: $installCmd", printStream) + } + try { + val builder = new ProcessBuilder(installCmd.asJava) + builder.redirectErrorStream(true) + val env = builder.environment() + env.clear() + val process = builder.start() + new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start() + process.waitFor() == 0 + } catch { + case e: Throwable => + print("Failed to build R package.", printStream, Level.SEVERE, e) + false + } + } + + /** + * Extracts the files under /R in the jar to a temporary directory for building. + */ + private def extractRFolder(jar: JarFile, printStream: PrintStream, verbose: Boolean): File = { + val tempDir = Utils.createTempDir(null) + val jarEntries = jar.entries() + while (jarEntries.hasMoreElements) { + val entry = jarEntries.nextElement() + val entryRIndex = entry.getName.indexOf(RJarEntries) + if (entryRIndex > -1) { + val entryPath = entry.getName.substring(entryRIndex) + if (entry.isDirectory) { + val dir = new File(tempDir, entryPath) + if (verbose) { + print(s"Creating directory: $dir", printStream) + } + dir.mkdirs + } else { + val inStream = jar.getInputStream(entry) + val outPath = new File(tempDir, entryPath) + Files.createParentDirs(outPath) + val outStream = new FileOutputStream(outPath) + if (verbose) { + print(s"Extracting $entry to $outPath", printStream) + } + Utils.copyStream(inStream, outStream, closeStreams = true) + } + } + } + tempDir + } + + /** + * Extracts the files under /R in the jar to a temporary directory for building. + */ + private[deploy] def checkAndBuildRPackage( + jars: String, + printStream: PrintStream = null, + verbose: Boolean = false): Unit = { + jars.split(",").foreach { jarPath => + val file = new File(Utils.resolveURI(jarPath)) + if (file.exists()) { + val jar = new JarFile(file) + if (checkManifestForR(jar)) { + print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) + val rSource = extractRFolder(jar, printStream, verbose) + try { + if (!rPackageBuilder(rSource, printStream, verbose)) { + print(s"ERROR: Failed to build R package in $file.", printStream) + print(RJarDoc, printStream) + } + } finally { + rSource.delete() // clean up + } + } else { + if (verbose) { + print(s"$file doesn't contain R source code, skipping...", printStream) + } + } + } else { + print(s"WARN: $file resolved as dependency, but not found.", printStream, Level.WARNING) + } + } + } + + private def listFilesRecursively(dir: File, excludePatterns: Seq[String]): Set[File] = { + if (!dir.exists()) { + Set.empty[File] + } else { + if (dir.isDirectory) { + val subDir = dir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + !excludePatterns.map(name.contains).reduce(_ || _) // exclude files with given pattern + } + }) + subDir.flatMap(listFilesRecursively(_, excludePatterns)).toSet + } else { + Set(dir) + } + } + } + + /** Zips all the libraries found with SparkR in the R/lib directory for distribution with Yarn. */ + private[deploy] def zipRLibraries(dir: File, name: String): File = { + val filesToBundle = listFilesRecursively(dir, Seq(".zip")) + // create a zip file from scratch, do not append to existing file. + val zipFile = new File(dir, name) + zipFile.delete() + val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false)) + try { + filesToBundle.foreach { file => + // get the relative paths for proper naming in the zip file + val relPath = file.getAbsolutePath.replaceFirst(dir.getAbsolutePath, "") + val fis = new FileInputStream(file) + val zipEntry = new ZipEntry(relPath) + zipOutputStream.putNextEntry(zipEntry) + ByteStreams.copy(fis, zipOutputStream) + zipOutputStream.closeEntry() + fis.close() + } + } finally { + zipOutputStream.close() + } + zipFile + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 000000000000..05b954ce3699 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. + */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir) + env.put("R_PROFILE_USER", + Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + // scalastyle:off println + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala similarity index 77% rename from core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala rename to core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index 4781a80d470e..8d5e716e6aea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.master +package org.apache.spark.deploy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.retry.ExponentialBackoffRetry @@ -25,15 +25,17 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -object SparkCuratorUtil extends Logging { +private[spark] object SparkCuratorUtil extends Logging { - val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 - val ZK_SESSION_TIMEOUT_MILLIS = 60000 - val RETRY_WAIT_MILLIS = 5000 - val MAX_RECONNECT_ATTEMPTS = 3 + private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 + private val ZK_SESSION_TIMEOUT_MILLIS = 60000 + private val RETRY_WAIT_MILLIS = 5000 + private val MAX_RECONNECT_ATTEMPTS = 3 - def newClient(conf: SparkConf): CuratorFramework = { - val ZK_URL = conf.get("spark.deploy.zookeeper.url") + def newClient( + conf: SparkConf, + zkUrlConf: String = "spark.deploy.zookeeper.url"): CuratorFramework = { + val ZK_URL = conf.get(zkUrlConf) val zk = CuratorFrameworkFactory.newClient(ZK_URL, ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) @@ -55,7 +57,7 @@ object SparkCuratorUtil extends Logging { def deleteRecursive(zk: CuratorFramework, path: String) { if (zk.checkExists().forPath(path) != null) { - for (child <- zk.getChildren.forPath(path)) { + for (child <- zk.getChildren.forPath(path).asScala) { zk.delete().forPath(path + "/" + child) } zk.delete().forPath(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index d68854214ef0..f7723ef5bde4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,22 +17,30 @@ package org.apache.spark.deploy +import java.io.{ByteArrayInputStream, DataInputStream} import java.lang.reflect.Method import java.security.PrivilegedExceptionAction +import java.util.{Arrays, Comparator} +import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.control.NonFatal + +import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} +import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import org.apache.spark.{Logging, SparkConf, SparkException} /** * :: DeveloperApi :: @@ -40,7 +48,8 @@ import scala.collection.JavaConversions._ */ @DeveloperApi class SparkHadoopUtil extends Logging { - val conf: Configuration = newConfiguration(new SparkConf()) + private val sparkConf = new SparkConf() + val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) /** @@ -52,27 +61,22 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { - val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER) - if (user != SparkContext.SPARK_UNKNOWN_USER) { - logDebug("running as user: " + user) - val ugi = UserGroupInformation.createRemoteUser(user) - transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) - } else { - logDebug("running as SPARK_UNKNOWN_USER") - func() - } + val user = Utils.getCurrentUserName() + logDebug("running as user: " + user) + val ugi = UserGroupInformation.createRemoteUser(user) + transferCredentials(UserGroupInformation.getCurrentUser(), ugi) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - for (token <- source.getTokens()) { + for (token <- source.getTokens.asScala) { dest.addToken(token) } } - @Deprecated + @deprecated("use newConfiguration with SparkConf argument", "1.2.0") def newConfiguration(): Configuration = newConfiguration(null) /** @@ -171,13 +175,13 @@ class SparkHadoopUtil extends Logging { } private def getFileSystemThreadStatistics(): Seq[AnyRef] = { - val stats = FileSystem.getAllStatistics() - stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + FileSystem.getAllStatistics.asScala.map( + Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") statisticsDataClass.getDeclaredMethod(methodName) } @@ -191,6 +195,188 @@ class SparkHadoopUtil extends Logging { val method = context.getClass.getMethod("getConfiguration") method.invoke(context).asInstanceOf[Configuration] } + + /** + * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly + * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes + * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ + * while it's interface in Hadoop 2.+. + */ + def getTaskAttemptIDFromTaskAttemptContext( + context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + val method = context.getClass.getMethod("getTaskAttemptID") + method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] + } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + listLeafStatuses(fs, fs.getFileStatus(basePath)) + } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { + def recurse(status: FileStatus): Seq[FileStatus] = { + val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f)) + } + + if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus) + } + + def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + listLeafDirStatuses(fs, fs.getFileStatus(basePath)) + } + + def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { + def recurse(status: FileStatus): Seq[FileStatus] = { + val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir) + val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus] + leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir)) + } + + assert(baseStatus.isDir) + recurse(baseStatus) + } + + def globPath(pattern: Path): Seq[Path] = { + val fs = pattern.getFileSystem(conf) + Option(fs.globStatus(pattern)).map { statuses => + statuses.map(_.getPath.makeQualified(fs.getUri, fs.getWorkingDirectory)).toSeq + }.getOrElse(Seq.empty[Path]) + } + + def globPathIfNecessary(pattern: Path): Seq[Path] = { + if (pattern.toString.exists("{}[]*?\\".toSet.contains)) { + globPath(pattern) + } else { + Seq(pattern) + } + } + + /** + * Lists all the files in a directory with the specified prefix, and does not end with the + * given suffix. The returned {{FileStatus}} instances are sorted by the modification times of + * the respective files. + */ + def listFilesSorted( + remoteFs: FileSystem, + dir: Path, + prefix: String, + exclusionSuffix: String): Array[FileStatus] = { + try { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) + } + }) + fileStatuses + } catch { + case NonFatal(e) => + logWarning("Error while attempting to list files from application staging dir", e) + Array.empty + } + } + + /** + * How much time is remaining (in millis) from now to (fraction * renewal time for the token that + * is valid the latest)? + * This will return -ve (or 0) value if the fraction of validity has already expired. + */ + def getTimeFromNowToRenewal( + sparkConf: SparkConf, + fraction: Double, + credentials: Credentials): Long = { + val now = System.currentTimeMillis() + + val renewalInterval = + sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) + + credentials.getAllTokens.asScala + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + .map { t => + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) + } + + + private[spark] def getSuffixForCredentialsPath(credentialsPath: Path): Int = { + val fileName = credentialsPath.getName + fileName.substring( + fileName.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) + 1).toInt + } + + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored + + /** + * Substitute variables by looking them up in Hadoop configs. Only variables that match the + * ${hadoopconf- .. } pattern are substituted. + */ + def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { + text match { + case HADOOP_CONF_PATTERN(matched) => { + logDebug(text + " matched " + HADOOP_CONF_PATTERN) + val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } + val eval = Option[String](hadoopConf.get(key)) + .map { value => + logDebug("Substituted " + matched + " with " + value) + text.replace(matched, value) + } + if (eval.isEmpty) { + // The variable was not found in Hadoop configs, so return text as is. + text + } else { + // Continue to substitute more variables. + substituteHadoopVariables(eval.get, hadoopConf) + } + } + case _ => { + logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) + text + } + } + } + + /** + * Start a thread to periodically update the current user's credentials with new delegation + * tokens so that writes to HDFS do not fail. + */ + private[spark] def startExecutorDelegationTokenRenewer(conf: SparkConf) {} + + /** + * Stop the thread that does the delegation token updates. + */ + private[spark] def stopExecutorDelegationTokenRenewer() {} + + /** + * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. + * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. + */ + private[spark] def getConfBypassingFSCache( + hadoopConf: Configuration, + scheme: String): Configuration = { + val newConf = new Configuration(hadoopConf) + val confKey = s"fs.${scheme}.impl.disable.cache" + newConf.setBoolean(confKey, true) + newConf + } } object SparkHadoopUtil { @@ -200,7 +386,7 @@ object SparkHadoopUtil { System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") .newInstance() .asInstanceOf[SparkHadoopUtil] } catch { @@ -211,6 +397,10 @@ object SparkHadoopUtil { } } + val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp" + + val SPARK_YARN_CREDS_COUNTER_DELIM = "-" + def get: SparkHadoopUtil = { hadoop } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 02021be9f93d..86fcf942c2c4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -18,15 +18,41 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} -import java.lang.reflect.{Modifier, InvocationTargetException} +import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.security.UserGroupInformation +import org.apache.ivy.Ivy +import org.apache.ivy.core.LogOptions +import org.apache.ivy.core.module.descriptor._ +import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId} +import org.apache.ivy.core.report.ResolveReport +import org.apache.ivy.core.resolve.ResolveOptions +import org.apache.ivy.core.retrieve.RetrieveOptions +import org.apache.ivy.core.settings.IvySettings +import org.apache.ivy.plugins.matcher.GlobPatternMatcher +import org.apache.ivy.plugins.repository.file.FileRepository +import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} -import org.apache.spark.executor.ExecutorURLClassLoader -import org.apache.spark.util.Utils +import org.apache.spark.{SparkUserAppException, SPARK_VERSION} +import org.apache.spark.api.r.RUtils +import org.apache.spark.deploy.rest._ +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + + +/** + * Whether to submit, kill, or request the status of an application. + * The latter two operations are currently supported only for standalone cluster mode. + */ +private[deploy] object SparkSubmitAction extends Enumeration { + type SparkSubmitAction = Value + val SUBMIT, KILL, REQUEST_STATUS = Value +} /** * Main gateway of launching a Spark application. @@ -55,39 +81,143 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(-1) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String) = { + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() + exitFn(1) } + private[spark] def printVersionAndExit(): Unit = { + printStream.println("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + printStream.println("Type --help for more information.") + exitFn(0) + } + // scalastyle:on println - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println + } + appArgs.action match { + case SparkSubmitAction.SUBMIT => submit(appArgs) + case SparkSubmitAction.KILL => kill(appArgs) + case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) } - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) - launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) } /** - * @return a tuple containing - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a list of system properties and env vars, and - * (4) the main class for the child + * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + */ + private def kill(args: SparkSubmitArguments): Unit = { + new RestSubmissionClient(args.master) + .killSubmission(args.submissionToKill) + } + + /** + * Request the status of an existing submission using the REST protocol. + * Standalone and Mesos cluster mode only. */ - private[spark] def createLaunchEnv(args: SparkSubmitArguments) - : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { + private def requestStatus(args: SparkSubmitArguments): Unit = { + new RestSubmissionClient(args.master) + .requestSubmissionStatus(args.submissionToRequestStatusFor) + } + + /** + * Submit the application using the provided parameters. + * + * This runs in two steps. First, we prepare the launch environment by setting up + * the appropriate classpath, system properties, and application arguments for + * running the child main class based on the cluster manager and the deploy mode. + * Second, we use this launch environment to invoke the main method of the child + * main class. + */ + private def submit(args: SparkSubmitArguments): Unit = { + val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + + def doRunMain(): Unit = { + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + // scalastyle:off println + printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println + exitFn(1) + } else { + throw e + } + } + } else { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + } - // Values to return + // In standalone cluster mode, there are two submission gateways: + // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper + // (2) The new REST-based gateway introduced in Spark 1.3 + // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over + // to use the legacy gateway if the master endpoint turns out to be not a REST server. + if (args.isStandaloneCluster && args.useRest) { + try { + // scalastyle:off println + printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println + doRunMain() + } catch { + // Fail over to use the legacy submission gateway + case e: SubmitRestConnectionException => + printWarning(s"Master endpoint ${args.master} was not a REST server. " + + "Falling back to legacy submission gateway instead.") + args.useRest = false + submit(args) + } + // In all other modes, just run the main class as prepared + } else { + doRunMain() + } + } + + /** + * Prepare the environment for submitting an application. + * This returns a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * Exposed for testing. + */ + private[deploy] def prepareSubmitEnvironment(args: SparkSubmitArguments) + : (Seq[String], Seq[String], Map[String, String], String) = { + // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() @@ -136,7 +266,37 @@ object SparkSubmit { } } + // Update args.deployMode if it is null. It will be passed down as a Spark property later. + (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER + + // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files + // too for packages that include Python code + val exclusions: Seq[String] = + if (!StringUtils.isBlank(args.packagesExclusions)) { + args.packagesExclusions.split(",") + } else { + Nil + } + val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, + Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) + if (args.isPython) { + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) + } + } + + // install any R packages that may have been passed through --jars or --packages. + // Spark Packages may contain R source code inside the jar. + if (args.isR && !StringUtils.isBlank(args.jars)) { + RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) + } // Require all python files to be local, so we can add them to the PYTHONPATH // In YARN cluster mode, python files are distributed as regular files, which can be non-local @@ -150,13 +310,24 @@ object SparkSubmit { } } + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (MESOS, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -169,26 +340,61 @@ object SparkSubmit { // If we're running a python app, set the main class to our specific python runner if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { - args.mainClass = "py4j.GatewayServer" - args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0") + args.mainClass = "org.apache.spark.api.python.PythonGatewayServer" } else { // If a python file is provided, add it to the child arguments and list of files to deploy. // Usage: PythonAppRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs - args.files = mergeFileLists(args.files, args.primaryResource) + if (clusterManager != YARN) { + // The YARN backend distributes the primary file differently, so don't merge it. + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + if (clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) } - args.files = mergeFileLists(args.files, args.pyFiles) if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } } - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // In YARN mode for an R app, add the SparkR package archive to archives // that can be distributed with the job - if (args.isPython && isYarnCluster) { + if (args.isR && clusterManager == YARN) { + val rPackagePath = RUtils.localSparkRPackagePath + if (rPackagePath.isEmpty) { + printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + } + val rPackageFile = + RPackageUtils.zipRLibraries(new File(rPackagePath.get), SPARKR_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + } + val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + + // Assigns a symbol link name "sparkr" to the shipped package. + args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + } + + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster && args.isR) { + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) } // Special flag to avoid deprecation warnings at the client @@ -200,8 +406,11 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, @@ -211,37 +420,45 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Standalone cluster only - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), - // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options + OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") ) // In client mode, launch the application main class directly @@ -255,7 +472,6 @@ object SparkSubmit { if (args.childArgs != null) { childArgs ++= args.childArgs } } - // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { if (opt.value != null && @@ -268,8 +484,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -277,32 +493,50 @@ object SparkSubmit { sysProps.put("spark.jars", jars.mkString(",")) } - // In standalone-cluster mode, use Client as a wrapper around the user class - if (clusterManager == STANDALONE && deployMode == CLUSTER) { - childMainClass = "org.apache.spark.deploy.Client" - if (args.supervise) { - childArgs += "--supervise" + // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). + // All Spark parameters are expected to be passed to the client through system properties. + if (args.isStandaloneCluster) { + if (args.useRest) { + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + } else { + // In legacy standalone cluster mode, use Client as a wrapper around the user class + childMainClass = "org.apache.spark.deploy.Client" + if (args.supervise) { childArgs += "--supervise" } + Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } + Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } + childArgs += "launch" + childArgs += (args.master, args.primaryResource, args.mainClass) } - childArgs += "launch" - childArgs += (args.master, args.primaryResource, args.mainClass) if (args.childArgs != null) { childArgs ++= args.childArgs } } + // Let YARN know it's a pyspark app, so it distributes needed libraries. + if (clusterManager == YARN) { + if (args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when the keytab is specified") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { - val mainPyFile = new Path(args.primaryResource).getName - childArgs += ("--primary-py-file", mainPyFile) + childArgs += ("--primary-py-file", args.primaryResource) if (args.pyFiles != null) { - // These files will be distributed to each machine's working directory, so strip the - // path prefix - val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") - childArgs += ("--py-files", pyFilesNames) + childArgs += ("--py-files", args.pyFiles) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) @@ -314,6 +548,15 @@ object SparkSubmit { } } + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) @@ -321,7 +564,7 @@ object SparkSubmit { // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= ("spark.driver.host") + sysProps -= "spark.driver.host" } // Resolve paths in certain spark properties @@ -350,12 +593,19 @@ object SparkSubmit { (childArgs, childClasspath, sysProps, childMainClass) } - private def launch( - childArgs: ArrayBuffer[String], - childClasspath: ArrayBuffer[String], + /** + * Run the main method of the child class using the provided launch environment. + * + * Note that this main class will not be the one provided by the user if we're + * running cluster deploy mode or python applications. + */ + private def runMain( + childArgs: Seq[String], + childClasspath: Seq[String], sysProps: Map[String, String], childMainClass: String, - verbose: Boolean = false) { + verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -363,9 +613,16 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println - val loader = new ExecutorURLClassLoader(new Array[URL](0), - Thread.currentThread.getContextClassLoader) + val loader = + if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } else { + new MutableURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } Thread.currentThread.setContextClassLoader(loader) for (jar <- childClasspath) { @@ -379,13 +636,15 @@ object SparkSubmit { var mainClass: Class[_] = null try { - mainClass = Class.forName(childMainClass, true, loader) + mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { - println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:off println + printStream.println(s"Failed to load main class $childMainClass.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -399,17 +658,31 @@ object SparkSubmit { if (!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") } + + def findCause(t: Throwable): Throwable = t match { + case e: UndeclaredThrowableException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: InvocationTargetException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: Throwable => + e + } + try { mainMethod.invoke(null, childArgs.toArray) } catch { - case e: InvocationTargetException => e.getCause match { - case cause: Throwable => throw cause - case null => throw e - } + case t: Throwable => + findCause(t) match { + case SparkUserAppException(exitCode) => + System.exit(exitCode) + + case t: Throwable => + throw t + } } } - private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) { + private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -427,59 +700,344 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[spark] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** * Return whether the given main class represents a sql shell. */ - private[spark] def isSqlShell(mainClass: String): Boolean = { + private[deploy] def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } /** * Return whether the given main class represents a thrift server. */ - private[spark] def isThriftServer(mainClass: String): Boolean = { + private def isThriftServer(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" } /** * Return whether the given primary resource requires running python. */ - private[spark] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL } - private[spark] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL + } + + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. */ - private[spark] def mergeFileLists(lists: String*): String = { - val merged = lists.filter(_ != null) + private def mergeFileLists(lists: String*): String = { + val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") if (merged == "") null else merged } } +/** Provides utility functions to be used inside SparkSubmit. */ +private[spark] object SparkSubmitUtils { + + // Exposed for testing + var printStream = SparkSubmit.printStream + + /** + * Represents a Maven Coordinate + * @param groupId the groupId of the coordinate + * @param artifactId the artifactId of the coordinate + * @param version the version of the coordinate + */ + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } + +/** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ + def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { + coordinates.split(",").map { p => + val splits = p.replace("/", ":").split(":") + require(splits.length == 3, s"Provided Maven Coordinates must be in the form " + + s"'groupId:artifactId:version'. The coordinate provided is: $p") + require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " + + s"be whitespace. The groupId provided is: ${splits(0)}") + require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " + + s"be whitespace. The artifactId provided is: ${splits(1)}") + require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " + + s"be whitespace. The version provided is: ${splits(2)}") + new MavenCoordinate(splits(0), splits(1), splits(2)) + } + } + + /** Path of the local Maven cache. */ + private[spark] def m2Path: File = { + if (Utils.isTesting) { + // test builds delete the maven cache, and this can cause flakiness + new File("dummy", ".m2" + File.separator + "repository") + } else { + new File(System.getProperty("user.home"), ".m2" + File.separator + "repository") + } + } + + /** + * Extracts maven coordinates from a comma-delimited string + * @param remoteRepos Comma-delimited string of remote repositories + * @param ivySettings The Ivy settings for this session + * @return A ChainResolver used by Ivy to search for and resolve dependencies. + */ + def createRepoResolvers(remoteRepos: Option[String], ivySettings: IvySettings): ChainResolver = { + // We need a chain resolver if we want to check multiple repositories + val cr = new ChainResolver + cr.setName("list") + + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + // scalastyle:off println + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println + } + } + + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + localM2.setRoot(m2Path.toURI.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + + val localIvy = new FileSystemResolver + val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") + localIvy.setLocal(true) + localIvy.setRepository(new FileRepository(localIvyRoot)) + val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", + "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern) + localIvy.setName("local-ivy-cache") + cr.add(localIvy) + + // the biblio resolver resolves POM declared dependencies + val br: IBiblioResolver = new IBiblioResolver + br.setM2compatible(true) + br.setUsepoms(true) + br.setName("central") + cr.add(br) + + val sp: IBiblioResolver = new IBiblioResolver + sp.setM2compatible(true) + sp.setUsepoms(true) + sp.setRoot("http://dl.bintray.com/spark-packages/maven") + sp.setName("spark-packages") + cr.add(sp) + cr + } + + /** + * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath + * (will append to jars in SparkSubmit). + * @param artifacts Sequence of dependencies that were resolved and retrieved + * @param cacheDirectory directory where jars are cached + * @return a comma-delimited list of paths for the dependencies + */ + def resolveDependencyPaths( + artifacts: Array[AnyRef], + cacheDirectory: File): String = { + artifacts.map { artifactInfo => + val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId + cacheDirectory.getAbsolutePath + File.separator + + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar" + }.mkString(",") + } + + /** Adds the given maven coordinates to Ivy's module descriptor. */ + def addDependenciesToIvy( + md: DefaultModuleDescriptor, + artifacts: Seq[MavenCoordinate], + ivyConfName: String): Unit = { + artifacts.foreach { mvn => + val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) + val dd = new DefaultDependencyDescriptor(ri, false, false) + dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println + printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println + md.addDependency(dd) + } + } + + /** Add exclusion rules for dependencies already included in the spark-assembly */ + def addExclusionRules( + ivySettings: IvySettings, + ivyConfName: String, + md: DefaultModuleDescriptor): Unit = { + // Add scala exclusion rule + md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) + + // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and + // other spark-streaming utility components. Underscore is there to differentiate between + // spark-streaming_2.1x and spark-streaming-kafka-assembly_2.1x + val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") + + components.foreach { comp => + md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, + ivyConfName)) + } + } + + /** A nice function to use in tests as well. Values are dummy strings. */ + def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + + /** + * Resolves any dependencies that were supplied through maven coordinates + * @param coordinates Comma-delimited string of maven coordinates + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @param exclusions Exclusions to apply when resolving transitive dependencies + * @return The comma-delimited path to the jars of the given maven artifacts including their + * transitive dependencies + */ + def resolveMavenCoordinates( + coordinates: String, + remoteRepos: Option[String], + ivyPath: Option[String], + exclusions: Seq[String] = Nil, + isTest: Boolean = false): String = { + if (coordinates == null || coordinates.trim.isEmpty) { + "" + } else { + val sysOut = System.out + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + // scalastyle:off println + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } + + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + + md.setDefaultConf(ivyConfName) + + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) + exclusions.foreach { e => + md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) + } + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) + } + } + } + + private[deploy] def createExclusion( + coords: String, + ivySettings: IvySettings, + ivyConfName: String): ExcludeRule = { + val c = extractMavenCoordinates(coords)(0) + val id = new ArtifactId(new ModuleId(c.groupId, c.artifactId), "*", "*", "*") + val rule = new DefaultExcludeRule(id, ivySettings.getMatcher("glob"), null) + rule.addConfiguration(ivyConfName) + rule + } + +} + /** * Provides an indirection layer for passing arguments as system properties or flags to * the user's driver program or to downstream launcher tools. */ -private[spark] case class OptionAssigner( +private case class OptionAssigner( value: String, clusterManager: Int, deployMode: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 73e921fd83ef..18a1c52ae53f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,18 +17,26 @@ package org.apache.spark.deploy +import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.reflect.InvocationTargetException import java.net.URI +import java.util.{List => JList} import java.util.jar.JarFile +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Source +import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.util.Utils /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. */ -private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) { +private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) + extends SparkSubmitArgumentsParser { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -39,8 +47,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverExtraClassPath: String = null var driverExtraLibraryPath: String = null var driverExtraJavaOptions: String = null - var driverCores: String = null - var supervise: Boolean = false var queue: String = null var numExecutors: String = null var files: String = null @@ -50,36 +56,57 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var name: String = null var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var jars: String = null + var packages: String = null + var repositories: String = null + var ivyRepoPath: String = null + var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false + var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() + var proxyUser: String = null + var principal: String = null + var keytab: String = null + + // Standalone cluster mode only + var supervise: Boolean = false + var driverCores: String = null + var submissionToKill: String = null + var submissionToRequestStatusFor: String = null + var useRest: Boolean = true // used internally /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => - if (k.startsWith("spark.")) { - defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") - } else { - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") - } + defaultProperties(k) = v + if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } // Set parameters from command line arguments - parseOpts(args.toList) + try { + parse(args.asJava) + } catch { + case e: IllegalArgumentException => + SparkSubmit.printErrorAndExit(e.getMessage()) + } // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() + // Remove keys that don't start with "spark." from `sparkProperties`. + ignoreNonSparkProperties() // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() - checkRequiredArguments() + validateArguments() /** * Merge values from the default properties file with those specified through --conf. @@ -96,6 +123,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } + /** + * Remove keys that don't start with "spark." from `sparkProperties`. + */ + private def ignoreNonSparkProperties(): Unit = { + sparkProperties.foreach { case (k, v) => + if (!k.startsWith("spark.")) { + sparkProperties -= k + SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + } + } + } + /** * Load arguments from environment variables, Spark properties etc. */ @@ -104,6 +143,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.master")) .orElse(env.get("MASTER")) .orNull + driverExtraClassPath = Option(driverExtraClassPath) + .orElse(sparkProperties.get("spark.driver.extraClassPath")) + .orNull + driverExtraJavaOptions = Option(driverExtraJavaOptions) + .orElse(sparkProperties.get("spark.driver.extraJavaOptions")) + .orNull + driverExtraLibraryPath = Option(driverExtraLibraryPath) + .orElse(sparkProperties.get("spark.driver.extraLibraryPath")) + .orNull driverMemory = Option(driverMemory) .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) @@ -117,18 +165,25 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) .orNull name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull + ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull + packagesExclusions = Option(packagesExclusions) + .orElse(sparkProperties.get("spark.jars.excludes")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -162,17 +217,28 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (name == null && primaryResource != null) { name = Utils.stripDirectory(primaryResource) } + + // Action should be SUBMIT unless otherwise specified + action = Option(action).getOrElse(SUBMIT) } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments(): Unit = { + private def validateArguments(): Unit = { + action match { + case SUBMIT => validateSubmitArguments() + case KILL => validateKillArguments() + case REQUEST_STATUS => validateStatusRequestArguments() + } + } + + private def validateSubmitArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { @@ -188,7 +254,31 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - override def toString = { + private def validateKillArguments(): Unit = { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Killing submissions is only supported in standalone or Mesos mode!") + } + if (submissionToKill == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + } + } + + private def validateStatusRequestArguments(): Unit = { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Requesting submission statuses is only supported in standalone or Mesos mode!") + } + if (submissionToRequestStatusFor == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + } + } + + def isStandaloneCluster: Boolean = { + master.startsWith("spark://") && deployMode == "cluster" + } + + override def toString: String = { s"""Parsed arguments: | master $master | deployMode $deployMode @@ -212,6 +302,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | name $name | childArgs [${childArgs.mkString(" ")}] | jars $jars + | packages $packages + | packagesExclusions $packagesExclusions + | repositories $repositories | verbose $verbose | |Spark properties used, including those specified through @@ -220,145 +313,169 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St """.stripMargin } - /** - * Fill in values by parsing user options. - * NOTE: Any changes here must be reflected in YarnClientSchedulerBackend. - */ - private def parseOpts(opts: Seq[String]): Unit = { - val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r - - // Delineates parsing of Spark options from parsing of user options. - parse(opts) - - /** - * NOTE: If you add or remove spark-submit options, - * modify NOT ONLY this file but also utils.sh - */ - def parse(opts: Seq[String]): Unit = opts match { - case ("--name") :: value :: tail => + /** Fill in values by parsing user options. */ + override protected def handle(opt: String, value: String): Boolean = { + opt match { + case NAME => name = value - parse(tail) - case ("--master") :: value :: tail => + case MASTER => master = value - parse(tail) - case ("--class") :: value :: tail => + case CLASS => mainClass = value - parse(tail) - case ("--deploy-mode") :: value :: tail => + case DEPLOY_MODE => if (value != "client" && value != "cluster") { SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value - parse(tail) - case ("--num-executors") :: value :: tail => + case NUM_EXECUTORS => numExecutors = value - parse(tail) - case ("--total-executor-cores") :: value :: tail => + case TOTAL_EXECUTOR_CORES => totalExecutorCores = value - parse(tail) - case ("--executor-cores") :: value :: tail => + case EXECUTOR_CORES => executorCores = value - parse(tail) - case ("--executor-memory") :: value :: tail => + case EXECUTOR_MEMORY => executorMemory = value - parse(tail) - case ("--driver-memory") :: value :: tail => + case DRIVER_MEMORY => driverMemory = value - parse(tail) - case ("--driver-cores") :: value :: tail => + case DRIVER_CORES => driverCores = value - parse(tail) - case ("--driver-class-path") :: value :: tail => + case DRIVER_CLASS_PATH => driverExtraClassPath = value - parse(tail) - case ("--driver-java-options") :: value :: tail => + case DRIVER_JAVA_OPTIONS => driverExtraJavaOptions = value - parse(tail) - case ("--driver-library-path") :: value :: tail => + case DRIVER_LIBRARY_PATH => driverExtraLibraryPath = value - parse(tail) - case ("--properties-file") :: value :: tail => + case PROPERTIES_FILE => propertiesFile = value - parse(tail) - case ("--supervise") :: tail => + case KILL_SUBMISSION => + submissionToKill = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + } + action = KILL + + case STATUS => + submissionToRequestStatusFor = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + } + action = REQUEST_STATUS + + case SUPERVISE => supervise = true - parse(tail) - case ("--queue") :: value :: tail => + case QUEUE => queue = value - parse(tail) - case ("--files") :: value :: tail => + case FILES => files = Utils.resolveURIs(value) - parse(tail) - case ("--py-files") :: value :: tail => + case PY_FILES => pyFiles = Utils.resolveURIs(value) - parse(tail) - case ("--archives") :: value :: tail => + case ARCHIVES => archives = Utils.resolveURIs(value) - parse(tail) - case ("--jars") :: value :: tail => + case JARS => jars = Utils.resolveURIs(value) - parse(tail) - case ("--conf" | "-c") :: value :: tail => + case PACKAGES => + packages = value + + case PACKAGES_EXCLUDE => + packagesExclusions = value + + case REPOSITORIES => + repositories = value + + case CONF => value.split("=", 2).toSeq match { case Seq(k, v) => sparkProperties(k) = v case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") } - parse(tail) - case ("--help" | "-h") :: tail => + case PROXY_USER => + proxyUser = value + + case PRINCIPAL => + principal = value + + case KEYTAB => + keytab = value + + case HELP => printUsageAndExit(0) - case ("--verbose" | "-v") :: tail => + case VERBOSE => verbose = true - parse(tail) - case EQ_SEPARATED_OPT(opt, value) :: tail => - parse(opt :: value :: tail) + case VERSION => + SparkSubmit.printVersionAndExit() - case value :: tail if value.startsWith("-") => - SparkSubmit.printErrorAndExit(s"Unrecognized option '$value'.") + case USAGE_ERROR => + printUsageAndExit(1) - case value :: tail => - primaryResource = - if (!SparkSubmit.isShell(value) && !SparkSubmit.isInternal(value)) { - Utils.resolveURI(value).toString - } else { - value - } - isPython = SparkSubmit.isPython(value) - childArgs ++= tail + case _ => + throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + } + true + } - case Nil => + /** + * Handle unrecognized command line options. + * + * The first unrecognized option is treated as the "primary resource". Everything else is + * treated as application arguments. + */ + override protected def handleUnknown(opt: String): Boolean = { + if (opt.startsWith("-")) { + SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") } + + primaryResource = + if (!SparkSubmit.isShell(opt) && !SparkSubmit.isInternal(opt)) { + Utils.resolveURI(opt).toString + } else { + opt + } + isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) + false + } + + override protected def handleExtraArgs(extra: JList[String]): Unit = { + childArgs ++= extra.asScala } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } + val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( + """Usage: spark-submit [options] [app arguments] + |Usage: spark-submit --kill [submission ID] --master [spark://...] + |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + outStream.println(command) + + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB outStream.println( - """Usage: spark-submit [options] [app options] + s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -368,6 +485,16 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --name NAME A name of your application. | --jars JARS Comma-separated list of local jars to include on the driver | and executor classpaths. + | --packages Comma-separated list of maven coordinates of jars to include + | on the driver and executor classpaths. Will search the local + | maven repo, then maven central and any additional remote + | repositories given by --repositories. The format for the + | coordinates should be groupId:artifactId:version. + | --exclude-packages Comma-separated list of groupId:artifactId, to exclude while + | resolving the dependencies provided in --packages to avoid + | dependency conflicts. + | --repositories Comma-separated list of additional remote repositories to + | search for the maven coordinates given with --packages. | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working @@ -377,7 +504,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M). | --driver-java-options Extra Java options to pass to the driver. | --driver-library-path Extra library path entries to pass to the driver. | --driver-class-path Extra class path entries to pass to the driver. Note that @@ -386,26 +513,102 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | + | --proxy-user NAME User to impersonate when submitting the application. + | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output + | --version, Print the version of current Spark | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). + | + | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. + | --kill SUBMISSION_ID If given, kills the driver specified. + | --status SUBMISSION_ID If given, requests the status of the driver specified. | | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | + | Spark standalone and YARN only: + | --executor-cores NUM Number of cores per executor. (Default: 1 in YARN mode, + | or all available cores on the worker in standalone mode) + | | YARN-only: | --driver-cores NUM Number of cores used by the driver, only in cluster mode | (Default: 1). - | --executor-cores NUM Number of cores per executor (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | --archives ARCHIVES Comma separated list of archives to be extracted into the | working directory of each executor. + | --principal PRINCIPAL Principal to be used to login to KDC, while running on + | secure HDFS. + | --keytab KEYTAB The full path to the file that contains the keytab for the + | principal specified above. This keytab will be copied to + | the node running the Application Master via the Secure + | Distributed Cache, for renewing the login tickets and the + | delegation tokens periodically. """.stripMargin ) - SparkSubmit.exitFn() + + if (SparkSubmit.isSqlShell(mainClass)) { + outStream.println("CLI options:") + outStream.println(getSqlShellOptions()) + } + // scalastyle:on println + + SparkSubmit.exitFn(exitCode) + } + + /** + * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter + * the results to remove unwanted lines. + * + * Since the CLI will call `System.exit()`, we install a security manager to prevent that call + * from working, and restore the original one afterwards. + */ + private def getSqlShellOptions(): String = { + val currentOut = System.out + val currentErr = System.err + val currentSm = System.getSecurityManager() + try { + val out = new ByteArrayOutputStream() + val stream = new PrintStream(out) + System.setOut(stream) + System.setErr(stream) + + val sm = new SecurityManager() { + override def checkExit(status: Int): Unit = { + throw new SecurityException() + } + + override def checkPermission(perm: java.security.Permission): Unit = {} + } + System.setSecurityManager(sm) + + try { + Utils.classForName(mainClass).getMethod("main", classOf[Array[String]]) + .invoke(null, Array(HELP)) + } catch { + case e: InvocationTargetException => + // Ignore SecurityException, since we throw it above. + if (!e.getCause().isInstanceOf[SecurityException]) { + throw e + } + } + + stream.flush() + + // Get the output and discard any unnecessary lines from it. + Source.fromString(new String(out.toByteArray())).getLines + .filter { line => + !line.startsWith("log4j") && !line.startsWith("usage") + } + .mkString("\n") + } finally { + System.setSecurityManager(currentSm) + System.setOut(currentOut) + System.setErr(currentErr) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala deleted file mode 100644 index 2eab9981845e..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy - -import java.io.File - -import scala.collection.JavaConversions._ - -import org.apache.spark.util.{RedirectThread, Utils} - -/** - * Launch an application through Spark submit in client mode with the appropriate classpath, - * library paths, java options and memory. These properties of the JVM must be set before the - * driver JVM is launched. The sole purpose of this class is to avoid handling the complexity - * of parsing the properties file for such relevant configs in Bash. - * - * Usage: org.apache.spark.deploy.SparkSubmitDriverBootstrapper - */ -private[spark] object SparkSubmitDriverBootstrapper { - - // Note: This class depends on the behavior of `bin/spark-class` and `bin/spark-submit`. - // Any changes made there must be reflected in this file. - - def main(args: Array[String]): Unit = { - - // This should be called only from `bin/spark-class` - if (!sys.env.contains("SPARK_CLASS")) { - System.err.println("SparkSubmitDriverBootstrapper must be called from `bin/spark-class`!") - System.exit(1) - } - - val submitArgs = args - val runner = sys.env("RUNNER") - val classpath = sys.env("CLASSPATH") - val javaOpts = sys.env("JAVA_OPTS") - val defaultDriverMemory = sys.env("OUR_JAVA_MEM") - - // Spark submit specific environment variables - val deployMode = sys.env("SPARK_SUBMIT_DEPLOY_MODE") - val propertiesFile = sys.env("SPARK_SUBMIT_PROPERTIES_FILE") - val bootstrapDriver = sys.env("SPARK_SUBMIT_BOOTSTRAP_DRIVER") - val submitDriverMemory = sys.env.get("SPARK_SUBMIT_DRIVER_MEMORY") - val submitLibraryPath = sys.env.get("SPARK_SUBMIT_LIBRARY_PATH") - val submitClasspath = sys.env.get("SPARK_SUBMIT_CLASSPATH") - val submitJavaOpts = sys.env.get("SPARK_SUBMIT_OPTS") - - assume(runner != null, "RUNNER must be set") - assume(classpath != null, "CLASSPATH must be set") - assume(javaOpts != null, "JAVA_OPTS must be set") - assume(defaultDriverMemory != null, "OUR_JAVA_MEM must be set") - assume(deployMode == "client", "SPARK_SUBMIT_DEPLOY_MODE must be \"client\"!") - assume(propertiesFile != null, "SPARK_SUBMIT_PROPERTIES_FILE must be set") - assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") - - // Parse the properties file for the equivalent spark.driver.* configs - val properties = Utils.getPropertiesFromFile(propertiesFile) - val confDriverMemory = properties.get("spark.driver.memory") - val confLibraryPath = properties.get("spark.driver.extraLibraryPath") - val confClasspath = properties.get("spark.driver.extraClassPath") - val confJavaOpts = properties.get("spark.driver.extraJavaOptions") - - // Favor Spark submit arguments over the equivalent configs in the properties file. - // Note that we do not actually use the Spark submit values for library path, classpath, - // and Java opts here, because we have already captured them in Bash. - - val newDriverMemory = submitDriverMemory - .orElse(confDriverMemory) - .getOrElse(defaultDriverMemory) - - val newClasspath = - if (submitClasspath.isDefined) { - classpath - } else { - classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") - } - - val newJavaOpts = - if (submitJavaOpts.isDefined) { - // SPARK_SUBMIT_OPTS is already captured in JAVA_OPTS - javaOpts - } else { - javaOpts + confJavaOpts.map(" " + _).getOrElse("") - } - - val filteredJavaOpts = Utils.splitCommandString(newJavaOpts) - .filterNot(_.startsWith("-Xms")) - .filterNot(_.startsWith("-Xmx")) - - // Build up command - val command: Seq[String] = - Seq(runner) ++ - Seq("-cp", newClasspath) ++ - filteredJavaOpts ++ - Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ - Seq("org.apache.spark.deploy.SparkSubmit") ++ - submitArgs - - // Print the launch command. This follows closely the format used in `bin/spark-class`. - if (sys.env.contains("SPARK_PRINT_LAUNCH_COMMAND")) { - System.err.print("Spark Command: ") - System.err.println(command.mkString(" ")) - System.err.println("========================================\n") - } - - // Start the driver JVM - val filteredCommand = command.filter(_.nonEmpty) - val builder = new ProcessBuilder(filteredCommand) - val env = builder.environment() - - if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) { - val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName) - env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator"))) - } - - val process = builder.start() - - // If we kill an app while it's running, its sub-process should be killed too. - Runtime.getRuntime().addShutdownHook(new Thread() { - override def run() = { - if (process != null) { - process.destroy() - process.waitFor() - } - } - }) - - // Redirect stdout and stderr from the child JVM - val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") - val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") - stdoutThread.start() - stderrThread.start() - - // Redirect stdin to child JVM only if we're not running Windows. This is because the - // subprocess there already reads directly from our stdin, so we should avoid spawning a - // thread that contends with the subprocess in reading from System.in. - val isWindows = Utils.isWindows - val isSubprocess = sys.env.contains("IS_SUBPROCESS") - if (!isWindows) { - val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin", - propagateEof = true) - stdinThread.start() - // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on - // broken pipe, signaling that the parent process has exited. This is the case if the - // application is launched directly from python, as in the PySpark shell. In Windows, - // the termination logic is handled in java_gateway.py - if (isSubprocess) { - stdinThread.join() - process.destroy() - } - } - val returnCode = process.waitFor() - sys.exit(returnCode) - } - -} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index ffe940fbda2f..25ea6925434a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. */ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + + private val REGISTRATION_TIMEOUT_SECONDS = 20 + private val REGISTRATION_RETRIES = 3 + + private var endpoint: RpcEndpointRef = null + private var appId: String = null + @volatile private var registered = false - val REGISTRATION_TIMEOUT = 20.seconds - val REGISTRATION_RETRIES = 3 + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { - var masterAddress: Address = null - var actor: ActorRef = null - var appId: String = null - var registered = false - var activeMasterUrl: String = null + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null - class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") + + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. + */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,23 +184,48 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() - case DisassociatedEvent(_, address, _) if address == masterAddress => + case r: RequestExecutors => + master match { + case Some(m) => context.reply(m.askWithRetry[Boolean](r)) + case None => + logWarning("Attempted to request executors before registering with Master.") + context.reply(false) + } + + case k: KillExecutors => + master match { + case Some(m) => context.reply(m.askWithRetry[Boolean](k)) + case None => + logWarning("Attempted to kill executors before registering with Master.") + context.reply(false) + } + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - sender ! true - context.stop(self) + } } /** @@ -178,28 +245,61 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { - // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) + // Just launch an rpcEndpoint; it will call back into the listener. + endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = AkkaUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + val timeout = RpcUtils.askRpcTimeout(conf) + timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null + } + } + + /** + * Request executors from the Master by specifying the total number desired, + * including existing pending and running executors. + * + * @return whether the request is acknowledged. + */ + def requestTotalExecutors(requestedTotal: Int): Boolean = { + if (endpoint != null && appId != null) { + endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal)) + } else { + logWarning("Attempted to request executors before driver fully initialized.") + false } } + + /** + * Kill the given list of executors through the Master. + * @return whether the kill request is acknowledged. + */ + def killExecutors(executorIds: Seq[String]): Boolean = { + if (endpoint != null && appId != null) { + endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds)) + } else { + logWarning("Attempted to kill executors before driver fully initialized.") + false + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 88a0862b96af..1c79089303e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,13 +17,14 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { - class TestListener extends AppClientListener with Logging { + private class TestListener extends AppClientListener with Logging { def connected(id: String) { logInfo("Connected to master, got app ID " + id) } @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d..a98b1fa8f83a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 553bf3cb945a..5f5e0fe1c34d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,18 +17,25 @@ package org.apache.spark.deploy.history +import java.util.zip.ZipOutputStream + +import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI -private[spark] case class ApplicationHistoryInfo( - id: String, - name: String, +private[spark] case class ApplicationAttemptInfo( + attemptId: Option[String], startTime: Long, endTime: Long, lastUpdated: Long, sparkUser: String, completed: Boolean = false) -private[spark] abstract class ApplicationHistoryProvider { +private[spark] case class ApplicationHistoryInfo( + id: String, + name: String, + attempts: List[ApplicationAttemptInfo]) + +private[history] abstract class ApplicationHistoryProvider { /** * Returns a list of applications available for the history server to show. @@ -41,9 +48,10 @@ private[spark] abstract class ApplicationHistoryProvider { * Returns the Spark UI for a specific application. * * @param appId The application ID. + * @param attemptId The application attempt ID (or None if there is no attempt ID). * @return The application's UI, or None if application is not found. */ - def getAppUI(appId: String): Option[SparkUI] + def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] /** * Called when the server is shutting down. @@ -57,4 +65,12 @@ private[spark] abstract class ApplicationHistoryProvider { */ def getConfig(): Map[String, String] = Map() + /** + * Writes out the event logs to the output stream provided. The logs will be compressed into a + * single zip file and written out. + * @throws SparkException if the logs for the app id cannot be found. + */ + @throws(classOf[SparkException]) + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 0ae45f4ad913..e573ff16c50a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,44 +17,58 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, InputStream} +import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable -import org.apache.hadoop.fs.{FileStatus, Path} +import com.google.common.io.ByteStreams +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.AccessControlException -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A class that provides application history from event logs stored in the file system. * This provider checks for new finished applications in the background periodically and * renders the history application UI by parsing the associated event logs. */ -private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider - with Logging { +private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) + extends ApplicationHistoryProvider with Logging { + + def this(conf: SparkConf) = { + this(conf, new SystemClock()) + } import FsHistoryProvider._ private val NOT_STARTED = "" // Interval between each check for event log updates - private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", - conf.getInt("spark.history.updateInterval", 10)) * 1000 + private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") + + // Interval between each cleaner checks for event logs to delete + private val CLEAN_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.cleaner.interval", "1d") private val logDir = conf.getOption("spark.history.fs.logDirectory") .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) - // A timestamp of when the disk was last accessed to check for log updates - private var lastLogCheckTimeMs = -1L + // Used by check event thread and clean log thread. + // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs + // and applications between check task and clean task. + private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() + .setNameFormat("spark-history-task-%d").setDaemon(true).build()) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that @@ -66,36 +80,32 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = "SPARK_VERSION_" - private[history] val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" + // List of application logs to be deleted by event log cleaner. + private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] /** - * A background thread that periodically checks for event log updates on disk. - * - * If a log check is invoked manually in the middle of a period, this thread re-adjusts the - * time at which it performs the next log check to maintain the same period as before. - * - * TODO: Add a mechanism to update manually. + * Return a runnable that performs the given operation on the event logs. + * This operation is expected to be executed periodically. */ - private val logCheckingThread = new Thread("LogCheckingThread") { - override def run() = Utils.logUncaughtExceptions { - while (true) { - val now = getMonotonicTimeMs() - if (now - lastLogCheckTimeMs > UPDATE_INTERVAL_MS) { - Thread.sleep(UPDATE_INTERVAL_MS) - } else { - // If the user has manually checked for logs recently, wait until - // UPDATE_INTERVAL_MS after the last check time - Thread.sleep(lastLogCheckTimeMs + UPDATE_INTERVAL_MS - now) - } - checkForLogs() + private def getRunner(operateFun: () => Unit): Runnable = { + new Runnable() { + override def run(): Unit = Utils.tryOrExit { + operateFun() } } } + /** + * An Executor to fetch and parse log files. + */ + private val replayExecutor: ExecutorService = { + if (!conf.contains("spark.testing")) { + ThreadUtils.newDaemonSingleThreadExecutor("log-replay-executor") + } else { + MoreExecutors.sameThreadExecutor() + } + } + initialize() private def initialize(): Unit = { @@ -104,7 +114,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } @@ -113,42 +123,47 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis "Logging directory specified is not a directory: %s".format(logDir)) } - checkForLogs() - // Disable the background thread during tests. if (!conf.contains("spark.testing")) { - logCheckingThread.setDaemon(true) - logCheckingThread.start() + // A task that periodically checks for event log updates on disk. + pool.scheduleWithFixedDelay(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) + + if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { + // A task that periodically cleans event logs on disk. + pool.scheduleWithFixedDelay(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) + } } } - override def getListing() = applications.values + override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values - override def getAppUI(appId: String): Option[SparkUI] = { + override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { - applications.get(appId).map { info => - val replayBus = new ReplayListenerBus() - val ui = { - val conf = this.conf.clone() - val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, - s"${HistoryServer.UI_PATH_PREFIX}/$appId") - // Do not call ui.bind() to avoid creating a new server for each application + applications.get(appId).flatMap { appInfo => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => + val replayBus = new ReplayListenerBus() + val ui = { + val conf = this.conf.clone() + val appSecManager = new SecurityManager(conf) + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) + // Do not call ui.bind() to avoid creating a new server for each application + } + val appListener = new ApplicationEventListener() + replayBus.addListener(appListener) + val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } - - val appListener = new ApplicationEventListener() - replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, info.logPath)), replayBus) - - ui.setAppName(s"${appInfo.name} ($appId)") - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(appInfo.sparkUser, - appListener.viewAcls.getOrElse("")) - ui } } catch { case e: FileNotFoundException => None @@ -163,19 +178,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * applications that haven't been updated since last time the logs were checked. */ private[history] def checkForLogs(): Unit = { - lastLogCheckTimeMs = getMonotonicTimeMs() - logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) - try { - var newLastModifiedTime = lastModifiedTime val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - val logInfos = statusList + var newLastModifiedTime = lastModifiedTime + val logInfos: Seq[FileStatus] = statusList .filter { entry => try { - val modTime = getModificationTime(entry) - newLastModifiedTime = math.max(newLastModifiedTime, modTime) - modTime >= lastModifiedTime + getModificationTime(entry).map { time => + newLastModifiedTime = math.max(newLastModifiedTime, time) + time >= lastModifiedTime + }.getOrElse(false) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -184,58 +197,254 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis false } } - .flatMap { entry => + .flatMap { entry => Some(entry) } + .sortWith { case (entry1, entry2) => + val mod1 = getModificationTime(entry1).getOrElse(-1L) + val mod2 = getModificationTime(entry2).getOrElse(-1L) + mod1 >= mod2 + } + + logInfos.grouped(20) + .map { batch => + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(batch) + }) + } + .foreach { task => try { - Some(replay(entry, new ReplayListenerBus())) + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. + task.get() } catch { + case e: InterruptedException => + throw e case e: Exception => - logError(s"Failed to load application log data from $entry.", e) - None + logError("Exception while merging application listings", e) } } - .sortBy { info => (-info.endTime, -info.startTime) } lastModifiedTime = newLastModifiedTime + } catch { + case e: Exception => logError("Exception in checking for event log updates", e) + } + } + + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + + /** + * This method compresses the files passed in, and writes the compressed data out into the + * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being + * the name of the file being compressed. + */ + def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { + val fs = FileSystem.get(hadoopConf) + val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer + try { + outputStream.putNextEntry(new ZipEntry(entryName)) + ByteStreams.copy(inputStream, outputStream) + outputStream.closeEntry() + } finally { + inputStream.close() + } + } - // When there are new logs, merge the new list with the existing one, maintaining - // the expected ordering (descending end time). Maintaining the order is important - // to avoid having to sort the list every time there is a request for the log list. - if (!logInfos.isEmpty) { - val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() - def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id) || - newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && - !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { - newApps += (info.id -> info) + applications.get(appId) match { + case Some(appInfo) => + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + appInfo.attempts.filter { attempt => + attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get + }.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + // If this is a legacy directory, then add the directory to the zipStream and add + // each file to that directory. + if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { + val files = fs.listStatus(logPath) + zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) + zipStream.closeEntry() + files.foreach { file => + val path = file.getPath + zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) + } + } else { + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) + } } + } finally { + zipStream.close() } + case None => throw new SparkException(s"Logs for $appId not found.") + } + } - val newIterator = logInfos.iterator.buffered - val oldIterator = applications.values.iterator.buffered - while (newIterator.hasNext && oldIterator.hasNext) { - if (newIterator.head.endTime > oldIterator.head.endTime) { - addIfAbsent(newIterator.next) - } else { - addIfAbsent(oldIterator.next) - } + + /** + * Replay the log files in the list and merge the list of old applications with new ones + */ + private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { + val newAttempts = logs.flatMap { fileStatus => + try { + val bus = new ReplayListenerBus() + val res = replay(fileStatus, bus) + res match { + case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + } + res + } catch { + case e: Exception => + logError( + s"Exception encountered when attempting to load application log ${fileStatus.getPath}", + e) + None + } + } + + if (newAttempts.isEmpty) { + return + } + + // Build a map containing all apps that contain new attempts. The app information in this map + // contains both the new app attempt, and those that were already loaded in the existing apps + // map. If an attempt has been updated, it replaces the old attempt in the list. + val newAppMap = new mutable.HashMap[String, FsApplicationHistoryInfo]() + newAttempts.foreach { attempt => + val appInfo = newAppMap.get(attempt.appId) + .orElse(applications.get(attempt.appId)) + .map { app => + val attempts = + app.attempts.filter(_.attemptId != attempt.attemptId).toList ++ List(attempt) + new FsApplicationHistoryInfo(attempt.appId, attempt.name, + attempts.sortWith(compareAttemptInfo)) + } + .getOrElse(new FsApplicationHistoryInfo(attempt.appId, attempt.name, List(attempt))) + newAppMap(attempt.appId) = appInfo + } + + // Merge the new app list with the existing one, maintaining the expected ordering (descending + // end time). Maintaining the order is important to avoid having to sort the list every time + // there is a request for the log list. + val newApps = newAppMap.values.toSeq.sortWith(compareAppInfo) + val mergedApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + def addIfAbsent(info: FsApplicationHistoryInfo): Unit = { + if (!mergedApps.contains(info.id)) { + mergedApps += (info.id -> info) + } + } + + val newIterator = newApps.iterator.buffered + val oldIterator = applications.values.iterator.buffered + while (newIterator.hasNext && oldIterator.hasNext) { + if (newAppMap.contains(oldIterator.head.id)) { + oldIterator.next() + } else if (compareAppInfo(newIterator.head, oldIterator.head)) { + addIfAbsent(newIterator.next()) + } else { + addIfAbsent(oldIterator.next()) + } + } + newIterator.foreach(addIfAbsent) + oldIterator.foreach(addIfAbsent) + + applications = mergedApps + } + + /** + * Delete event logs from the log directory according to the clean policy defined by the user. + */ + private[history] def cleanLogs(): Unit = { + try { + val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 + + val now = clock.getTimeMillis() + val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + + def shouldClean(attempt: FsApplicationAttemptInfo): Boolean = { + now - attempt.lastUpdated > maxAge && attempt.completed + } + + // Scan all logs from the log directory. + // Only completed applications older than the specified max age will be deleted. + applications.values.foreach { app => + val (toClean, toRetain) = app.attempts.partition(shouldClean) + attemptsToClean ++= toClean + + if (toClean.isEmpty) { + appsToRetain += (app.id -> app) + } else if (toRetain.nonEmpty) { + appsToRetain += (app.id -> + new FsApplicationHistoryInfo(app.id, app.name, toRetain.toList)) } - newIterator.foreach(addIfAbsent) - oldIterator.foreach(addIfAbsent) + } + + applications = appsToRetain - applications = newApps + val leftToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + attemptsToClean.foreach { attempt => + try { + val path = new Path(logDir, attempt.logPath) + if (fs.exists(path)) { + fs.delete(path, true) + } + } catch { + case e: AccessControlException => + logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") + case t: IOException => + logError(s"IOException in cleaning ${attempt.logPath}", t) + leftToClean += attempt + } } + + attemptsToClean = leftToClean } catch { - case e: Exception => logError("Exception in checking for event log updates", e) + case t: Exception => logError("Exception in cleaning logs", t) } } + /** + * Comparison function that defines the sort order for the application listing. + * + * @return Whether `i1` should precede `i2`. + */ + private def compareAppInfo( + i1: FsApplicationHistoryInfo, + i2: FsApplicationHistoryInfo): Boolean = { + val a1 = i1.attempts.head + val a2 = i2.attempts.head + if (a1.endTime != a2.endTime) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime + } + + /** + * Comparison function that defines the sort order for application attempts within the same + * application. Order is: attempts are sorted by descending start time. + * Most recent attempt state matches with current state of the app. + * + * Normally applications should have a single running attempt; but failure to call sc.stop() + * may cause multiple running attempts to show up. + * + * @return Whether `a1` should precede `a2`. + */ + private def compareAttemptInfo( + a1: FsApplicationAttemptInfo, + a2: FsApplicationAttemptInfo): Boolean = { + a1.startTime >= a2.startTime + } + /** * Replays the events in the specified log file and returns information about the associated - * application. + * application. Return `None` if the application ID cannot be located. */ - private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = { + private def replay( + eventLog: FileStatus, + bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() - val (logInput, sparkVersion) = + logInfo(s"Replaying log path: $logPath") + val logInput = if (isLegacyLogDirectory(eventLog)) { openLegacyEventLog(logPath) } else { @@ -243,17 +452,27 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, sparkVersion) - new FsApplicationHistoryInfo( - logPath.getName(), - appListener.appId.getOrElse(logPath.getName()), - appListener.appName.getOrElse(NOT_STARTED), - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog), - appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + bus.replay(logInput, logPath.toString, !appCompleted) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted)) + } else { + None + } } finally { logInput.close() } @@ -264,30 +483,24 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * log file (along with other metadata files), which is the case for directories generated by * the code in previous releases. * - * @return 2-tuple of (input stream of the events, version of Spark which wrote the log) + * @return input stream that holds one JSON record per line. */ - private[history] def openLegacyEventLog(dir: Path): (InputStream, String) = { + private[history] def openLegacyEventLog(dir: Path): InputStream = { val children = fs.listStatus(dir) var eventLogPath: Path = null var codecName: Option[String] = None - var sparkVersion: String = null children.foreach { child => child.getPath().getName() match { case name if name.startsWith(LOG_PREFIX) => eventLogPath = child.getPath() - case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) => codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length())) - - case version if version.startsWith(SPARK_VERSION_PREFIX) => - sparkVersion = version.substring(SPARK_VERSION_PREFIX.length()) - case _ => } } - if (eventLogPath == null || sparkVersion == null) { + if (eventLogPath == null) { throw new IllegalArgumentException(s"$dir is not a Spark application log directory.") } @@ -299,7 +512,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } val in = new BufferedInputStream(fs.open(eventLogPath)) - (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion) + codec.map(_.compressedInputStream(in)).getOrElse(in) } /** @@ -310,17 +523,19 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() - private def getModificationTime(fsEntry: FileStatus): Long = { - if (fsEntry.isDir) { - fs.listStatus(fsEntry.getPath).map(_.getModificationTime()).max + /** + * Returns the modification time of the given event log. If the status points at an empty + * directory, `None` is returned, indicating that there isn't an event log at that location. + */ + private def getModificationTime(fsEntry: FileStatus): Option[Long] = { + if (isLegacyLogDirectory(fsEntry)) { + val statusList = fs.listStatus(fsEntry.getPath) + if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None } else { - fsEntry.getModificationTime() + Some(fsEntry.getModificationTime()) } } - /** Returns the system's mononotically increasing time. */ - private def getMonotonicTimeMs(): Long = System.nanoTime() / (1000 * 1000) - /** * Return true when the application has completed. */ @@ -332,19 +547,51 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } + /** + * Returns whether the version of Spark that generated logs records app IDs. App IDs were added + * in Spark 1.1. + */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. + val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } -private class FsApplicationHistoryInfo( +private class FsApplicationAttemptInfo( val logPath: String, - id: String, - name: String, + val name: String, + val appId: String, + attemptId: Option[String], startTime: Long, endTime: Long, lastUpdated: Long, sparkUser: String, completed: Boolean = true) - extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser, completed) + extends ApplicationAttemptInfo( + attemptId, startTime, endTime, lastUpdated, sparkUser, completed) + +private class FsApplicationHistoryInfo( + id: String, + override val name: String, + override val attempts: List[FsApplicationAttemptInfo]) + extends ApplicationHistoryInfo(id, name, attempts) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index e4e7bc221601..0830cc1ba124 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} -private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { +private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { private val pageSize = 20 private val plusOrMinus = 2 @@ -34,18 +34,28 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { val requestedIncomplete = Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean - val allApps = parent.getApplicationList().filter(_.completed != requestedIncomplete) - val actualFirst = if (requestedFirst < allApps.size) requestedFirst else 0 - val apps = allApps.slice(actualFirst, Math.min(actualFirst + pageSize, allApps.size)) + val allApps = parent.getApplicationList() + .filter(_.attempts.head.completed != requestedIncomplete) + val allAppsSize = allApps.size + + val actualFirst = if (requestedFirst < allAppsSize) requestedFirst else 0 + val appsToShow = allApps.slice(actualFirst, actualFirst + pageSize) val actualPage = (actualFirst / pageSize) + 1 - val last = Math.min(actualFirst + pageSize, allApps.size) - 1 - val pageCount = allApps.size / pageSize + (if (allApps.size % pageSize > 0) 1 else 0) + val last = Math.min(actualFirst + pageSize, allAppsSize) - 1 + val pageCount = allAppsSize / pageSize + (if (allAppsSize % pageSize > 0) 1 else 0) val secondPageFromLeft = 2 val secondPageFromRight = pageCount - 1 - val appTable = UIUtils.listingTable(appHeader, appRow, apps) + val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) + val appTable = + if (hasMultipleAttempts) { + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + } else { + UIUtils.listingTable(appHeader, appRow, appsToShow) + } + val providerConfig = parent.getProviderConfig() val content =
@@ -59,14 +69,15 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { // to the first and last page. If the current page +/- `plusOrMinus` is greater // than the 2nd page from the first page or less than the 2nd page from the last // page, `...` will be displayed. - if (allApps.size > 0) { + if (allAppsSize > 0) { val leftSideIndices = - rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _) + rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _, requestedIncomplete) val rightSideIndices = - rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount) + rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount, + requestedIncomplete)

- Showing {actualFirst + 1}-{last + 1} of {allApps.size} + Showing {actualFirst + 1}-{last + 1} of {allAppsSize} {if (requestedIncomplete) "(Incomplete applications)"} { @@ -89,6 +100,8 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {

++ appTable + } else if (requestedIncomplete) { +

No incomplete applications found!

} else {

No completed applications found!

++

Did you specify the correct logging directory? @@ -122,28 +135,85 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Spark User", "Last Updated") - private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => {nextPage} ) + private val appWithAttemptHeader = Seq( + "App ID", + "App Name", + "Attempt ID", + "Started", + "Completed", + "Duration", + "Spark User", + "Last Updated") + + private def rangeIndices( + range: Seq[Int], + condition: Int => Boolean, + showIncomplete: Boolean): Seq[Node] = { + range.filter(condition).map(nextPage => + {nextPage} ) } - private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}" - val startTime = UIUtils.formatDate(info.startTime) - val endTime = if (info.endTime > 0) UIUtils.formatDate(info.endTime) else "-" + private def attemptRow( + renderAttemptIdColumn: Boolean, + info: ApplicationHistoryInfo, + attempt: ApplicationAttemptInfo, + isFirst: Boolean): Seq[Node] = { + val uiAddress = HistoryServer.getAttemptURI(info.id, attempt.attemptId) + val startTime = UIUtils.formatDate(attempt.startTime) + val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-" val duration = - if (info.endTime > 0) UIUtils.formatDuration(info.endTime - info.startTime) else "-" - val lastUpdated = UIUtils.formatDate(info.lastUpdated) + if (attempt.endTime > 0) { + UIUtils.formatDuration(attempt.endTime - attempt.startTime) + } else { + "-" + } + val lastUpdated = UIUtils.formatDate(attempt.lastUpdated) - {info.id} - {info.name} - {startTime} - {endTime} - {duration} - {info.sparkUser} - {lastUpdated} + { + if (isFirst) { + if (info.attempts.size > 1 || renderAttemptIdColumn) { + + {info.id} + + {info.name} + } else { + {info.id} + {info.name} + } + } else { + Nil + } + } + { + if (renderAttemptIdColumn) { + if (info.attempts.size > 1 && attempt.attemptId.isDefined) { + + {attempt.attemptId.get} + } else { +   + } + } else { + Nil + } + } + {startTime} + {endTime} + + {duration} + {attempt.sparkUser} + {lastUpdated} } + private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { + attemptRow(false, info, info.attempts.head, true) + } + + private def appWithAttemptRow(info: ApplicationHistoryInfo): Seq[Node] = { + attemptRow(true, info, info.attempts.head, true) ++ + info.attempts.drop(1).flatMap(attemptRow(true, info, _, false)) + } + private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = { "/?" + Array( "page=" + linkPage, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index fa9bfe5426b6..d4f327cc588f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ @@ -25,9 +26,11 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, + UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -45,14 +48,18 @@ class HistoryServer( provider: ApplicationHistoryProvider, securityManager: SecurityManager, port: Int) - extends WebUI(securityManager, port, conf) with Logging { + extends WebUI(securityManager, port, conf) with Logging with UIRoot { // How many applications to retain private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) private val appLoader = new CacheLoader[String, SparkUI] { override def load(key: String): SparkUI = { - val ui = provider.getAppUI(key).getOrElse(throw new NoSuchElementException()) + val parts = key.split("/") + require(parts.length == 1 || parts.length == 2, s"Invalid app key $key") + val ui = provider + .getAppUI(parts(0), if (parts.length > 1) Some(parts(1)) else None) + .getOrElse(throw new NoSuchElementException(s"no app with key $key")) attachSparkUI(ui) ui } @@ -61,7 +68,7 @@ class HistoryServer( private val appCache = CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(new RemovalListener[String, SparkUI] { - override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = { detachSparkUI(rm.getValue()) } }) @@ -69,6 +76,8 @@ class HistoryServer( private val loaderServlet = new HttpServlet { protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + // Parse the URI created by getAttemptURI(). It contains an app ID and an optional + // attempt ID (separated by a slash). val parts = Option(req.getPathInfo()).getOrElse("").split("/") if (parts.length < 2) { res.sendError(HttpServletResponse.SC_BAD_REQUEST, @@ -77,27 +86,36 @@ class HistoryServer( } val appId = parts(1) + val attemptId = if (parts.length >= 3) Some(parts(2)) else None + + // Since we may have applications with multiple attempts mixed with applications with a + // single attempt, we need to try both. Try the single-attempt route first, and if an + // error is raised, then try the multiple attempt route. + if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { + val msg =

Application {appId} not found.
+ res.setStatus(HttpServletResponse.SC_NOT_FOUND) + UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + res.getWriter().write(n.toString) + } + return + } // Note we don't use the UI retrieved from the cache; the cache loader above will register // the app's UI, and all we need to do is redirect the user to the same URI that was // requested, and the proper data should be served at that point. - try { - appCache.get(appId) - res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) - } catch { - case e: Exception => e.getCause() match { - case nsee: NoSuchElementException => - val msg =
Application {appId} not found.
- res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach( - n => res.getWriter().write(n.toString)) - - case cause: Exception => throw cause - } - } + res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) + } + + // SPARK-5983 ensure TRACE is not supported + protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) } } + def getSparkUI(appKey: String): Option[SparkUI] = { + Option(appCache.get(appKey)) + } + initialize() /** @@ -108,6 +126,9 @@ class HistoryServer( */ def initialize() { attachPage(new HistoryPage(this)) + + attachHandler(ApiRootResource.getServletHandler(this)) + attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) val contextHandler = new ServletContextHandler @@ -145,14 +166,41 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList() = provider.getListing() + def getApplicationList(): Iterable[ApplicationHistoryInfo] = { + provider.getListing() + } + + def getApplicationInfoList: Iterator[ApplicationInfo] = { + getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + } + + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + provider.writeEventLogs(appId, attemptId, zipStream) + } /** * Returns the provider configuration to show in the listing page. * * @return A map with the provider's configuration. */ - def getProviderConfig() = provider.getConfig() + def getProviderConfig(): Map[String, String] = provider.getConfig() + + private def loadAppUi(appId: String, attemptId: Option[String]): Boolean = { + try { + appCache.get(appId + attemptId.map { id => s"/$id" }.getOrElse("")) + true + } catch { + case e: Exception => e.getCause() match { + case nsee: NoSuchElementException => + false + + case cause: Exception => throw cause + } + } + } } @@ -174,13 +222,13 @@ object HistoryServer extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) - initSecurity() new HistoryServerArguments(conf, argStrings) + initSecurity() val securityManager = new SecurityManager(conf) val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Class.forName(providerName) + val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] @@ -190,11 +238,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run() = { - server.stop() - } - }) + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } @@ -213,4 +257,9 @@ object HistoryServer extends Logging { } } + private[history] def getAttemptURI(appId: String, attemptId: Option[String]): String = { + val attemptSuffix = attemptId.map { id => s"/$id" }.getOrElse("") + s"${HistoryServer.UI_PATH_PREFIX}/${appId}${attemptSuffix}" + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index b1270ade9f75..18265df9faa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,8 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) + extends Logging { private var propertiesFile: String = null parse(args.toList) @@ -55,6 +56,7 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -83,6 +85,7 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ede0a9dbefb8..b40d20f9f786 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,8 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +31,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { @@ -44,6 +42,11 @@ private[spark] class ApplicationInfo( @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ + // A cap on the number of executors this application can have at any given time. + // By default, this is infinite. Only after the first allocation request is issued by the + // application will this be set to a finite value. This is used for dynamic allocation. + @transient private[master] var executorLimit: Int = _ + @transient private var nextExecutorId: Int = _ init() @@ -61,6 +64,7 @@ private[spark] class ApplicationInfo( appSource = new ApplicationSource(this) nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] + executorLimit = Integer.MAX_VALUE } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -75,14 +79,17 @@ private[spark] class ApplicationInfo( } } - def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = { - val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + private[master] def addExecutor( + worker: WorkerInfo, + cores: Int, + useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerExecutorMB) executors(exec.id) = exec coresGranted += cores exec } - def removeExecutor(exec: ExecutorDesc) { + private[master] def removeExecutor(exec: ExecutorDesc) { if (executors.contains(exec.id)) { removedExecutors += executors(exec.id) executors -= exec.id @@ -90,26 +97,36 @@ private[spark] class ApplicationInfo( } } - private val myMaxCores = desc.maxCores.getOrElse(defaultCores) + private val requestedCores = desc.maxCores.getOrElse(defaultCores) - def coresLeft: Int = myMaxCores - coresGranted + private[master] def coresLeft: Int = requestedCores - coresGranted private var _retryCount = 0 - def retryCount = _retryCount + private[master] def retryCount = _retryCount - def incrementRetryCount() = { + private[master] def incrementRetryCount() = { _retryCount += 1 _retryCount } - def resetRetryCount() = _retryCount = 0 + private[master] def resetRetryCount() = _retryCount = 0 - def markFinished(endState: ApplicationState.Value) { + private[master] def markFinished(endState: ApplicationState.Value) { state = endState endTime = System.currentTimeMillis() } + private[master] def isFinished: Boolean = { + state != ApplicationState.WAITING && state != ApplicationState.RUNNING + } + + /** + * Return the limit on the number of executors this application can have. + * For testing only. + */ + private[deploy] def getExecutorLimit: Int = executorLimit + def duration: Long = { if (endTime != -1) { endTime - startTime diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index 38db02cd2421..017e8b55cbe7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -21,7 +21,7 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source -class ApplicationSource(val application: ApplicationInfo) extends Source { +private[master] class ApplicationSource(val application: ApplicationInfo) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "%s.%s.%s".format("application", application.desc.name, System.currentTimeMillis()) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index 67e6c5d66af0..37bfcdfdf477 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -17,11 +17,11 @@ package org.apache.spark.deploy.master -private[spark] object ApplicationState extends Enumeration { +private[master] object ApplicationState extends Enumeration { type ApplicationState = Value - val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value + val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value val MAX_NUM_RETRY = 10 } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala index 9d3d7938c6cc..b197dbcbfe29 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.DriverDescription import org.apache.spark.util.Utils -private[spark] class DriverInfo( +private[deploy] class DriverInfo( val startTime: Long, val id: String, val desc: DriverDescription, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala index 26a68bade3c6..35ff33a61653 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object DriverState extends Enumeration { +private[deploy] object DriverState extends Enumeration { type DriverState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala index 5d620dfcabad..fc62b094def6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} -private[spark] class ExecutorDesc( +private[master] class ExecutorDesc( val id: Int, val application: ApplicationInfo, val worker: WorkerInfo, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 36a2e2c6a634..aa379d4cd61e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,9 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} +import org.apache.spark.util.Utils /** @@ -31,11 +31,11 @@ import org.apache.spark.Logging * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. */ -private[spark] class FileSystemPersistenceEngine( +private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -48,7 +48,7 @@ private[spark] class FileSystemPersistenceEngine( new File(dir + File.separator + name).delete() } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) files.map(deserializeFromFile[T]) } @@ -56,27 +56,31 @@ private[spark] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) - try { - out.write(serialized) - } finally { - out.close() + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null + Utils.tryWithSafeFinally { + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) + } { + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index cf77c86d760c..70f21fbe0de8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait LeaderElectionAgent { - val masterActor: LeaderElectable + val masterInstance: LeaderElectable def stop() {} // to avoid noops in implementations. } @@ -37,7 +37,7 @@ trait LeaderElectable { } /** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ -private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable) +private[spark] class MonarchyLeaderAgent(val masterInstance: LeaderElectable) extends LeaderElectionAgent { - masterActor.electedLeader() + masterInstance.electedLeader() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 5eeb9fe52624..26904d39a9be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,15 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -43,92 +38,115 @@ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI +import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -private[spark] class Master( - host: String, - port: Int, +private[deploy] class Master( + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, - val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + val securityMgr: SecurityManager, + val conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - val conf = new SparkConf - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 - val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) - val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) - val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) - val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val workers = new HashSet[WorkerInfo] - val idToWorker = new HashMap[String, WorkerInfo] - val addressToWorker = new HashMap[Address, WorkerInfo] + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 + private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) + private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) + private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) + private val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") - val apps = new HashSet[ApplicationInfo] + val workers = new HashSet[WorkerInfo] val idToApp = new HashMap[String, ApplicationInfo] - val actorToApp = new HashMap[ActorRef, ApplicationInfo] - val addressToApp = new HashMap[Address, ApplicationInfo] val waitingApps = new ArrayBuffer[ApplicationInfo] - val completedApps = new ArrayBuffer[ApplicationInfo] - var nextAppNumber = 0 - val appIdToUI = new HashMap[String, SparkUI] + val apps = new HashSet[ApplicationInfo] - val drivers = new HashSet[DriverInfo] - val completedDrivers = new ArrayBuffer[DriverInfo] - val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling - var nextDriverNumber = 0 + private val idToWorker = new HashMap[String, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - Utils.checkHost(host, "Expected hostname") + private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] + private val completedApps = new ArrayBuffer[ApplicationInfo] + private var nextAppNumber = 0 + private val appIdToUI = new HashMap[String, SparkUI] - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, + private val drivers = new HashSet[DriverInfo] + private val completedDrivers = new ArrayBuffer[DriverInfo] + // Drivers currently spooled for scheduling + private val waitingDrivers = new ArrayBuffer[DriverInfo] + private var nextDriverNumber = 0 + + Utils.checkHost(address.host, "Expected hostname") + + private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) + private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) - val masterSource = new MasterSource(this) + private val masterSource = new MasterSource(this) - val webUi = new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null - val masterPublicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + private val masterPublicAddress = { + val envVar = conf.getenv("SPARK_PUBLIC_DNS") + if (envVar != null) envVar else address.host } - val masterUrl = "spark://" + host + ":" + port - var masterWebUiUrl: String = _ + private val masterUrl = address.toSparkURL + private var masterWebUiUrl: String = _ + + private var state = RecoveryState.STANDBY - var state = RecoveryState.STANDBY + private var persistenceEngine: PersistenceEngine = _ - var persistenceEngine: PersistenceEngine = _ + private var leaderElectionAgent: LeaderElectionAgent = _ - var leaderElectionAgent: LeaderElectionAgent = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ - private var recoveryCompletionTask: Cancellable = _ + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. - val spreadOutApps = conf.getBoolean("spark.deploy.spreadOut", true) + private val spreadOutApps = conf.getBoolean("spark.deploy.spreadOut", true) // Default maxCores for applications that don't specify it (i.e. pass Int.MaxValue) - val defaultCores = conf.getInt("spark.deploy.defaultCores", Int.MaxValue) + private val defaultCores = conf.getInt("spark.deploy.defaultCores", Int.MaxValue) if (defaultCores < 1) { throw new SparkException("spark.deploy.defaultCores must be positive") } - override def preStart() { + // Alternative application submission gateway that is stable across Spark versions + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private var restServer: Option[StandaloneRestServer] = None + private var restServerBoundPort: Option[Int] = None + + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) + + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) + } + restServerBoundPort = restServer.map(_.start()) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -138,20 +156,21 @@ private[spark] class Master( masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => - val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) - .newInstance(conf, SerializationExtension(context.system)) + val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -161,19 +180,19 @@ private[spark] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) + } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) } + forwardMessageThread.shutdownNow() webUi.stop() + restServer.foreach(_.stop()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -181,16 +200,16 @@ private[spark] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { @@ -199,8 +218,11 @@ private[spark] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -211,103 +233,42 @@ private[spark] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) - } - } - } - - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." - sender ! SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"Can only kill drivers in ALIVE state. Current state: $state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RequestDriverStatus(driverId) => { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) - } - } - - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -319,11 +280,15 @@ private[spark] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") - appInfo.removeExecutor(exec) + // If an application has already finished, preserve its + // state to display its information properly on the UI + if (!appInfo.isFinished) { + appInfo.removeExecutor(exec) + } exec.worker.removeExecutor(exec) val normalExit = exitStatus == Some(0) @@ -356,7 +321,7 @@ private[spark] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -364,7 +329,7 @@ private[spark] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -412,40 +377,126 @@ private[spark] class Master( if (canCompleteRecovery) { completeRecovery() } } - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case UnregisterApplication(applicationId) => + logInfo(s"Received unregister request from application $applicationId") + idToApp.get(applicationId).foreach(finishApplication) + + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } } - case RequestWebUIPort => { - sender ! WebUIPortResponse(webUi.boundPort) + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } + + case BoundPortsRequest => { + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) + } + + case RequestExecutors(appId, requestedTotal) => + context.reply(handleRequestExecutors(appId, requestedTotal)) + + case KillExecutors(appId, executorIds) => + val formattedExecutorIds = formatExecutorIds(executorIds) + context.reply(handleKillExecutors(appId, formattedExecutorIds)) } - def canCompleteRecovery = + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 - def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo], + private def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo], storedWorkers: Seq[WorkerInfo]) { for (app <- storedApps) { logInfo("Trying to recover app: " + app.id) try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -462,19 +513,17 @@ private[spark] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } } } - def completeRecovery() { + private def completeRecovery() { // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) @@ -498,99 +547,166 @@ private[spark] class Master( } /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). + * Schedule executors to be launched on the workers. + * Returns an array containing number of cores assigned to each worker. + * + * There are two modes of launching executors. The first attempts to spread out an application's + * executors on as many workers as possible, while the second does the opposite (i.e. launch them + * on as few workers as possible). The former is usually better for data locality purposes and is + * the default. + * + * The number of cores assigned to each executor is configurable. When this is explicitly set, + * multiple executors from the same application may be launched on the same worker if the worker + * has enough cores and memory. Otherwise, each executor grabs all the cores available on the + * worker by default, in which case only one executor may be launched on each worker. + * + * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core + * at a time). Consider the following example: cluster has 4 workers with 16 cores each. + * User requests 3 executors (spark.cores.max = 48, spark.executor.cores = 16). If 1 core is + * allocated at a time, 12 cores from each worker would be assigned to each executor. + * Since 12 < 16, no executors would launch [SPARK-8881]. */ - def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) + private def scheduleExecutorsOnWorkers( + app: ApplicationInfo, + usableWorkers: Array[WorkerInfo], + spreadOutApps: Boolean): Array[Int] = { + val coresPerExecutor = app.desc.coresPerExecutor + val minCoresPerExecutor = coresPerExecutor.getOrElse(1) + val oneExecutorPerWorker = coresPerExecutor.isEmpty + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val numUsable = usableWorkers.length + val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker + val assignedExecutors = new Array[Int](numUsable) // Number of new executors on each worker + var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) + + /** Return whether the specified worker can launch an executor for this app. */ + def canLaunchExecutor(pos: Int): Boolean = { + val keepScheduling = coresToAssign >= minCoresPerExecutor + val enoughCores = usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor + + // If we allow multiple executors per worker, then we can always launch new executors. + // Otherwise, if there is already an executor on this worker, just give it more cores. + val launchingNewExecutor = !oneExecutorPerWorker || assignedExecutors(pos) == 0 + if (launchingNewExecutor) { + val assignedMemory = assignedExecutors(pos) * memoryPerExecutor + val enoughMemory = usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor + val underLimit = assignedExecutors.sum + app.executors.size < app.executorLimit + keepScheduling && enoughCores && enoughMemory && underLimit + } else { + // We're adding cores to an existing executor, so no need + // to check memory and executor limits + keepScheduling && enoughCores + } + } + + // Keep launching executors until no more workers can accommodate any + // more executors, or if we have reached this application's limits + var freeWorkers = (0 until numUsable).filter(canLaunchExecutor) + while (freeWorkers.nonEmpty) { + freeWorkers.foreach { pos => + var keepScheduling = true + while (keepScheduling && canLaunchExecutor(pos)) { + coresToAssign -= minCoresPerExecutor + assignedCores(pos) += minCoresPerExecutor + + // If we are launching one executor per worker, then every iteration assigns 1 core + // to the executor. Otherwise, every iteration assigns cores to a new executor. + if (oneExecutorPerWorker) { + assignedExecutors(pos) = 1 + } else { + assignedExecutors(pos) += 1 + } + + // Spreading out an application means spreading out its executors across as + // many workers as possible. If we are not spreading out, then we should keep + // scheduling executors on this worker until we use all of its resources. + // Otherwise, just move on to the next worker. + if (spreadOutApps) { + keepScheduling = false + } + } + } + freeWorkers = freeWorkers.filter(canLaunchExecutor) + } + assignedCores + } + + /** + * Schedule and launch executors on workers + */ + private def startExecutorsOnWorkers(): Unit = { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. + for (app <- waitingApps if app.coresLeft > 0) { + val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) + } + } + } + + /** + * Allocate a worker's resources to one or more executors. + * @param app the info of the application which the executors belong to + * @param assignedCores number of cores on this worker for this application + * @param coresPerExecutor number of cores per executor + * @param worker the worker info + */ + private def allocateWorkerResourceToExecutors( + app: ApplicationInfo, + assignedCores: Int, + coresPerExecutor: Option[Int], + worker: WorkerInfo): Unit = { + // If the number of cores per executor is specified, we divide the cores assigned + // to this worker evenly among the executors with no remainder. + // Otherwise, we launch a single executor that grabs all the assignedCores on this worker. + val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1) + val coresToAssign = coresPerExecutor.getOrElse(assignedCores) + for (i <- 1 to numExecutors) { + val exec = app.addExecutor(worker, coresToAssign) + launchExecutor(worker, exec) + app.state = ApplicationState.RUNNING + } } /** * Schedule the currently available resources among waiting apps. This method will be called * every time a new app joins or resource availability changes. */ - private def schedule() { + private def schedule(): Unit = { if (state != RecoveryState.ALIVE) { return } - - // First schedule drivers, they take strict precedence over applications - // Randomization helps balance drivers - val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) - val numWorkersAlive = shuffledAliveWorkers.size - var curPos = 0 - - for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers - // We assign workers to each waiting driver in a round-robin fashion. For each driver, we - // start from the last worker that was assigned a driver, and continue onwards until we have - // explored all alive workers. - var launched = false - var numWorkersVisited = 0 - while (numWorkersVisited < numWorkersAlive && !launched) { - val worker = shuffledAliveWorkers(curPos) - numWorkersVisited += 1 + // Drivers take strict precedence over executors + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { launchDriver(worker, driver) waitingDrivers -= driver - launched = true - } - curPos = (curPos + 1) % numWorkersAlive - } - } - - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc. - if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 - } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) - app.state = ApplicationState.RUNNING - } - } - } - } else { - // Pack each app into as few nodes as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) - app.state = ApplicationState.RUNNING - } - } } } } + startExecutorsOnWorkers() } - def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { + private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } - def registerWorker(worker: WorkerInfo): Boolean = { + private def registerWorker(worker: WorkerInfo): Boolean = { // There may be one or more refs to dead workers on this same node (w/ different ID's), // remove them. workers.filter { w => @@ -599,7 +715,7 @@ private[spark] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -618,15 +734,15 @@ private[spark] class Master( true } - def removeWorker(worker: WorkerInfo) { + private def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -641,22 +757,23 @@ private[spark] class Master( persistenceEngine.removeWorker(worker) } - def relaunchDriver(driver: DriverInfo) { + private def relaunchDriver(driver: DriverInfo) { driver.worker = None driver.state = DriverState.RELAUNCHING waitingDrivers += driver schedule() } - def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } - def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address - if (addressToWorker.contains(appAddress)) { + private def registerApplication(app: ApplicationInfo): Unit = { + val appAddress = app.driver.address + if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return } @@ -664,12 +781,12 @@ private[spark] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } - def finishApplication(app: ApplicationInfo) { + private def finishApplication(app: ApplicationInfo) { removeApplication(app, ApplicationState.FINISHED) } @@ -678,8 +795,8 @@ private[spark] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -695,58 +812,142 @@ private[spark] class Master( rebuildSparkUI(app) for (exec <- app.executors.values) { - exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) - exec.state = ExecutorState.KILLED + killExecutor(exec) } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) } } } + /** + * Handle a request to set the target number of executors for this application. + * + * If the executor limit is adjusted upwards, new executors will be launched provided + * that there are workers with sufficient resources. If it is adjusted downwards, however, + * we do not kill existing executors until we explicitly receive a kill request. + * + * @return whether the application has previously registered with this Master. + */ + private def handleRequestExecutors(appId: String, requestedTotal: Int): Boolean = { + idToApp.get(appId) match { + case Some(appInfo) => + logInfo(s"Application $appId requested to set total executors to $requestedTotal.") + appInfo.executorLimit = requestedTotal + schedule() + true + case None => + logWarning(s"Unknown application $appId requested $requestedTotal total executors.") + false + } + } + + /** + * Handle a kill request from the given application. + * + * This method assumes the executor limit has already been adjusted downwards through + * a separate [[RequestExecutors]] message, such that we do not launch new executors + * immediately after the old ones are removed. + * + * @return whether the application has previously registered with this Master. + */ + private def handleKillExecutors(appId: String, executorIds: Seq[Int]): Boolean = { + idToApp.get(appId) match { + case Some(appInfo) => + logInfo(s"Application $appId requests to kill executors: " + executorIds.mkString(", ")) + val (known, unknown) = executorIds.partition(appInfo.executors.contains) + known.foreach { executorId => + val desc = appInfo.executors(executorId) + appInfo.removeExecutor(desc) + killExecutor(desc) + } + if (unknown.nonEmpty) { + logWarning(s"Application $appId attempted to kill non-existent executors: " + + unknown.mkString(", ")) + } + schedule() + true + case None => + logWarning(s"Unregistered application $appId requested us to kill executors!") + false + } + } + + /** + * Cast the given executor IDs to integers and filter out the ones that fail. + * + * All executors IDs should be integers since we launched these executors. However, + * the kill interface on the driver side accepts arbitrary strings, so we need to + * handle non-integer executor IDs just to be safe. + */ + private def formatExecutorIds(executorIds: Seq[String]): Seq[Int] = { + executorIds.flatMap { executorId => + try { + Some(executorId.toInt) + } catch { + case e: NumberFormatException => + logError(s"Encountered executor with a non-integer ID: $executorId. Ignoring") + None + } + } + } + + /** + * Ask the worker on which the specified executor is launched to kill the executor. + */ + private def killExecutor(exec: ExecutorDesc): Unit = { + exec.worker.removeExecutor(exec) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) + exec.state = ExecutorState.KILLED + } + /** * Rebuild a new SparkUI from the given application's event logs. - * Return whether this is successful. + * Return the UI if successful, else None */ - def rebuildSparkUI(app: ApplicationInfo): Boolean = { + private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { val appName = app.desc.name val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" try { - val eventLogFile = app.desc.eventLogDir - .map { dir => EventLoggingListener.getLogPath(dir, app.id) } + val eventLogDir = app.desc.eventLogDir .getOrElse { // Event logging is not enabled for this application app.desc.appUiUrl = notFoundBasePath - return false + return None } - - val fs = Utils.getHadoopFileSystem(eventLogFile, hadoopConf) - if (fs.exists(new Path(eventLogFile + EventLoggingListener.IN_PROGRESS))) { + val eventLogFilePrefix = EventLoggingListener.getLogPath( + eventLogDir, app.id, app.desc.eventLogCodec) + val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) + val inProgressExists = fs.exists(new Path(eventLogFilePrefix + + EventLoggingListener.IN_PROGRESS)) + + if (inProgressExists) { // Event logging is enabled for this application, but the application is still in progress - val title = s"Application history not found (${app.id})" - var msg = s"Application $appName is still in progress." - logWarning(msg) - msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" - return false + logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") } - val (logInput, sparkVersion) = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) + val (eventLogFile, status) = if (inProgressExists) { + (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)") + } else { + (eventLogFilePrefix, " (completed)") + } + + val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") + appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { - replayBus.replay(logInput, sparkVersion) + replayBus.replay(logInput, eventLogFile, maybeTruncated) } finally { logInput.close() } @@ -754,17 +955,17 @@ private[spark] class Master( webUi.attachSparkUI(ui) // Application UI is successfully rebuilt, so link the Master UI to it app.desc.appUiUrl = ui.basePath - true + Some(ui) } catch { case fnf: FileNotFoundException => // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" - var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir}." + var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir.get}." logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" - false + None case e: Exception => // Relay exception message to application UI page val title = s"Application history load error (${app.id})" @@ -773,56 +974,59 @@ private[spark] class Master( logError(msg, e) msg = URLEncoder.encode(msg, "UTF-8") app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title" - false + None } } /** Generate a new app ID given a app's submission date */ - def newApplicationId(submitDate: Date): String = { + private def newApplicationId(submitDate: Date): String = { val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 appId } /** Check for, and remove, any timed-out workers */ - def timeOutDeadWorkers() { + private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } } } - def newDriverId(submitDate: Date): String = { + private def newDriverId(submitDate: Date): String = { val appId = "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber) nextDriverNumber += 1 appId } - def createDriver(desc: DriverDescription): DriverInfo = { + private def createDriver(desc: DriverDescription): DriverInfo = { val now = System.currentTimeMillis() val date = new Date(now) new DriverInfo(now, newDriverId(date), desc, date) } - def launchDriver(worker: WorkerInfo, driver: DriverInfo) { + private def launchDriver(worker: WorkerInfo, driver: DriverInfo) { logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } - def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) { + private def removeDriver( + driverId: String, + finalState: DriverState, + exception: Option[Exception]) { drivers.find(d => d.id == driverId) match { case Some(driver) => logInfo(s"Removing driver: $driverId") @@ -843,51 +1047,34 @@ private[spark] class Master( } } -private[spark] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" +private[deploy] object Master extends Logging { + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) - } - - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) - val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] - (actorSystem, boundPort, resp.webUIBoundPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index e34bee785429..44cefbc77f08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.{IntParam, Utils} /** * Command-line parser for the master. */ -private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { +private[master] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 @@ -49,7 +49,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { webUiPort = conf.get("spark.master.ui.port").toInt } - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) host = value @@ -84,7 +84,8 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { /** * Print usage and exit JVM with the given exit code. */ - def printUsageAndExit(exitCode: Int) { + private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index db72d8ae9bda..a952cee36eb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut @@ -36,7 +36,7 @@ private[master] object MasterMessages { case object CompleteRecovery - case object RequestWebUIPort + case object BoundPortsRequest - case class WebUIPortResponse(webUIBoundPort: Int) + case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 9c3f79f1244b..66a9ff38678c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -30,6 +30,11 @@ private[spark] class MasterSource(val master: Master) extends Source { override def getValue: Int = master.workers.size }) + // Gauge for alive worker numbers in cluster + metricRegistry.register(MetricRegistry.name("aliveWorkers"), new Gauge[Int]{ + override def getValue: Int = master.workers.filter(_.state == WorkerState.ALIVE).size + }) + // Gauge for application numbers in cluster metricRegistry.register(MetricRegistry.name("apps"), new Gauge[Int] { override def getValue: Int = master.apps.size diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index 2e0e1e7036ac..58a00bceee6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -33,7 +34,7 @@ import scala.reflect.ClassTag * The implementation of this trait defines how name-object pairs are stored or retrieved. */ @DeveloperApi -trait PersistenceEngine { +abstract class PersistenceEngine { /** * Defines how the object is serialized and persisted. Implementation will @@ -80,14 +81,17 @@ trait PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} } -private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { +private[master] class BlackHolePersistenceEngine extends PersistenceEngine { override def persist(name: String, obj: Object): Unit = {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 1096eb036835..c4c3283fb73f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,22 +48,29 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") - def createPersistenceEngine() = { + def createPersistenceEngine(): PersistenceEngine = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } - def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } } -private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) - def createLeaderElectionAgent(master: LeaderElectable) = + def createPersistenceEngine(): PersistenceEngine = { + new ZooKeeperPersistenceEngine(conf, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { new ZooKeeperLeaderElectionAgent(master, conf) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala index 256a5a7c28e4..aa0f02fa625c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object RecoveryState extends Enumeration { +private[deploy] object RecoveryState extends Enumeration { type MasterState = Value val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e94aae93e449..f75196660520 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { @@ -104,7 +102,9 @@ private[spark] class WorkerInfo( "http://" + this.publicAddress + ":" + this.webUiPort } - def setState(state: WorkerState.Value) = { + def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala index 0b36ef60051f..b60baaadfb4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object WorkerState extends Enumeration { +private[master] object WorkerState extends Enumeration { type WorkerState = Value val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 8eaa0ad94851..d317206a614f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,14 +17,12 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.spark.deploy.SparkCuratorUtil -private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, +private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -35,7 +33,7 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectab start() - def start() { + private def start() { logInfo("Starting ZooKeeper LeaderElection agent") zk = SparkCuratorUtil.newClient(conf) leaderLatch = new LeaderLatch(zk, WORKING_DIR) @@ -72,13 +70,13 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectab } } - def updateLeadershipStatus(isLeader: Boolean) { + private def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER - masterActor.electedLeader() + masterInstance.electedLeader() } else if (!isLeader && status == LeadershipStatus.LEADER) { status = LeadershipStatus.NOT_LEADER - masterActor.revokedLeadership() + masterInstance.revokedLeadership() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index e11ac031fb9c..540e802420ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,23 +17,25 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine - with Logging -{ - val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" - val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) + with Logging { + + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" + private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) SparkCuratorUtil.mkdir(zk, WORKING_DIR) @@ -46,9 +48,9 @@ private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializati zk.delete().forPath(WORKING_DIR + "/" + name) } - override def read[T: ClassTag](prefix: String) = { - val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) - file.map(deserializeFromFile[T]).flatten + override def read[T: ClassTag](prefix: String): Seq[T] = { + zk.getChildren.forPath(WORKING_DIR).asScala + .filter(_.startsWith(prefix)).map(deserializeFromFile[T]).flatten } override def close() { @@ -56,17 +58,16 @@ private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializati } private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } - def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { + private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 3aae2b95d739..e28e7e379ac9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,42 +19,29 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask -import org.json4s.JValue - -import org.apache.spark.deploy.{ExecutorState, JsonProtocol} +import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils -private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { +private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout - - /** Executor details for a particular application */ - override def renderJson(request: HttpServletRequest): JValue = { - val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) - val app = state.activeApps.find(_.id == appId).getOrElse({ - state.completedApps.find(_.id == appId).getOrElse(null) - }) - JsonProtocol.writeApplicationInfo(app) - } + private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) + if (app == null) { + val msg =
No running application with ID {appId}
+ return UIUtils.basicSparkPage(msg, "Not Found") + } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq @@ -85,7 +72,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
  • Executor Memory: - {Utils.megabytesToString(app.desc.memoryPerSlave)} + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
  • Submit Date: {app.submitDate}
  • State: {app.state}
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala index d8daff3e7fb9..e021f1eef794 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.ui.{UIUtils, WebUIPage} -private[spark] class HistoryNotFoundPage(parent: MasterWebUI) +private[ui] class HistoryNotFoundPage(parent: MasterWebUI) extends WebUIPage("history/not-found") { /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7ca3b08a2872..c3e20ebf8d6e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,46 +19,72 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout +private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { + private val master = parent.masterEndpointRef + + def getMasterState: MasterStateResponse = { + master.askWithRetry[MasterStateResponse](RequestMasterState) + } override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) - JsonProtocol.writeMasterState(state) + JsonProtocol.writeMasterState(getMasterState) + } + + def handleAppKillRequest(request: HttpServletRequest): Unit = { + handleKillRequest(request, id => { + parent.master.idToApp.get(id).foreach { app => + parent.master.removeApplication(app, ApplicationState.KILLED) + } + }) + } + + def handleDriverKillRequest(request: HttpServletRequest): Unit = { + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) + } + + private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { + if (parent.killEnabled && + parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { + val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean + val id = Option(request.getParameter("id")) + if (id.isDefined && killFlag) { + action(id.get) + } + + Thread.sleep(100) + } } /** Index view listing applications and executors */ def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = getMasterState - val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") + val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", - "State", "Duration") + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) val completedApps = state.completedApps.sortBy(_.endTime).reverse val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) - val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", - "Main Class") + val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", + "Memory", "Main Class") val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers) val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse @@ -66,19 +92,27 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + def hasDrivers: Boolean = activeDrivers.length > 0 || completedDrivers.length > 0 val content =
    • URL: {state.uri}
    • -
    • Workers: {state.workers.size}
    • -
    • Cores: {state.workers.map(_.cores).sum} Total, - {state.workers.map(_.coresUsed).sum} Used
    • -
    • Memory: - {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, - {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
    • + { + state.restUri.map { uri => +
    • + REST URL: {uri} + (cluster mode) +
    • + }.getOrElse { Seq.empty } + } +
    • Alive Workers: {aliveWorkers.size}
    • +
    • Cores in use: {aliveWorkers.map(_.cores).sum} Total, + {aliveWorkers.map(_.coresUsed).sum} Used
    • +
    • Memory in use: + {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total, + {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
    • Applications: {state.activeApps.size} Running, {state.completedApps.size} Completed
    • @@ -155,9 +189,21 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } private def appRow(app: ApplicationInfo): Seq[Node] = { + val killLink = if (parent.killEnabled && + (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { + val confirm = + s"if (window.confirm('Are you sure you want to kill application ${app.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
      + + + (kill) +
      + } {app.id} + {killLink} {app.desc.name} @@ -165,8 +211,8 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.coresGranted} - - {Utils.megabytesToString(app.desc.memoryPerSlave)} + + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} {UIUtils.formatDate(app.submitDate)} {app.desc.user} @@ -176,8 +222,21 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } private def driverRow(driver: DriverInfo): Seq[Node] = { + val killLink = if (parent.killEnabled && + (driver.state == DriverState.RUNNING || + driver.state == DriverState.SUBMITTED || + driver.state == DriverState.RELAUNCHING)) { + val confirm = + s"if (window.confirm('Are you sure you want to kill driver ${driver.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
      + + + (kill) +
      + } - {driver.id} + {driver.id} {killLink} {driver.submitDate} {driver.worker.map(w => {w.id.toString}).getOrElse("None")} @@ -188,7 +247,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(driver.desc.mem.toLong)} - {driver.desc.command.arguments(1)} + {driver.desc.command.arguments(2)} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 73400c5affb5..6174fc11f83d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -19,28 +19,38 @@ package org.apache.spark.deploy.master.ui import org.apache.spark.Logging import org.apache.spark.deploy.master.Master +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo, + UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils /** * Web UI server for the standalone master. */ -private[spark] +private[master] class MasterWebUI(val master: Master, requestedPort: Int) - extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { + extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging + with UIRoot { - val masterActorRef = master.self - val timeout = AkkaUtils.askTimeout(master.conf) + val masterEndpointRef = master.self + val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) + + val masterPage = new MasterPage(this) initialize() /** Initialize all components of the server. */ def initialize() { + val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(new HistoryNotFoundPage(this)) - attachPage(new MasterPage(this)) + attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + attachHandler(ApiRootResource.getServletHandler(this)) + attachHandler(createRedirectHandler( + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) + attachHandler(createRedirectHandler( + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST"))) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ @@ -54,8 +64,25 @@ class MasterWebUI(val master: Master, requestedPort: Int) assert(serverInfo.isDefined, "Master UI must be bound to a server before detaching SparkUIs") ui.getHandlers.foreach(detachHandler) } + + def getApplicationInfoList: Iterator[ApplicationInfo] = { + val state = masterPage.getMasterState + val activeApps = state.activeApps.sortBy(_.startTime).reverse + val completedApps = state.completedApps.sortBy(_.endTime).reverse + activeApps.iterator.map { ApplicationsListResource.convertApplicationInfo(_, false) } ++ + completedApps.iterator.map { ApplicationsListResource.convertApplicationInfo(_, true) } + } + + def getSparkUI(appId: String): Option[SparkUI] = { + val state = masterPage.getMasterState + val activeApps = state.activeApps.sortBy(_.startTime).reverse + val completedApps = state.completedApps.sortBy(_.endTime).reverse + (activeApps ++ completedApps).find { _.id == appId }.flatMap { + master.rebuildSparkUI + } + } } -private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +private[master] object MasterWebUI { + private val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala new file mode 100644 index 000000000000..5d4e5b899dfd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.deploy.mesos.ui.MesosClusterUI +import org.apache.spark.deploy.rest.mesos.MesosRestServer +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.util.SignalLogger +import org.apache.spark.{Logging, SecurityManager, SparkConf} + +/* + * A dispatcher that is responsible for managing and launching drivers, and is intended to be + * used for Mesos cluster mode. The dispatcher is a long-running process started by the user in + * the cluster independently of Spark applications. + * It contains a [[MesosRestServer]] that listens for requests to submit drivers and a + * [[MesosClusterScheduler]] that processes these requests by negotiating with the Mesos master + * for resources. + * + * A typical new driver lifecycle is the following: + * - Driver submitted via spark-submit talking to the [[MesosRestServer]] + * - [[MesosRestServer]] queues the driver request to [[MesosClusterScheduler]] + * - [[MesosClusterScheduler]] gets resource offers and launches the drivers that are in queue + * + * This dispatcher supports both Mesos fine-grain or coarse-grain mode as the mode is configurable + * per driver launched. + * This class is needed since Mesos doesn't manage frameworks, so the dispatcher acts as + * a daemon to launch drivers as Mesos frameworks upon request. The dispatcher is also started and + * stopped by sbin/start-mesos-dispatcher and sbin/stop-mesos-dispatcher respectively. + */ +private[mesos] class MesosClusterDispatcher( + args: MesosClusterDispatcherArguments, + conf: SparkConf) + extends Logging { + + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) + private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) + + private val engineFactory = recoveryMode match { + case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory + case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf) + case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode) + } + + private val scheduler = new MesosClusterScheduler(engineFactory, conf) + + private val server = new MesosRestServer(args.host, args.port, conf, scheduler) + private val webUi = new MesosClusterUI( + new SecurityManager(conf), + args.webUiPort, + conf, + publicAddress, + scheduler) + + private val shutdownLatch = new CountDownLatch(1) + + def start(): Unit = { + webUi.bind() + scheduler.frameworkUrl = webUi.activeWebUiUrl + scheduler.start() + server.start() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + def stop(): Unit = { + webUi.stop() + server.stop() + scheduler.stop() + shutdownLatch.countDown() + } +} + +private[mesos] object MesosClusterDispatcher extends Logging { + def main(args: Array[String]) { + SignalLogger.register(log) + val conf = new SparkConf + val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + conf.setMaster(dispatcherArgs.masterUrl) + conf.setAppName(dispatcherArgs.name) + dispatcherArgs.zookeeperUrl.foreach { z => + conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.mesos.deploy.zookeeper.url", z) + } + val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) + dispatcher.start() + val shutdownHook = new Thread() { + override def run() { + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + dispatcher.awaitShutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala new file mode 100644 index 000000000000..5accaf78d0a5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, Utils} + + +private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { + var host = Utils.localHostName() + var port = 7077 + var name = "Spark Cluster" + var webUiPort = 8081 + var masterUrl: String = _ + var zookeeperUrl: Option[String] = None + var propertiesFile: String = _ + + parse(args.toList) + + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + private def parse(args: List[String]): Unit = args match { + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--webui-port" | "-p") :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--zk" | "-z") :: value :: tail => + zookeeperUrl = Some(value) + parse(tail) + + case ("--master" | "-m") :: value :: tail => + if (!value.startsWith("mesos://")) { + // scalastyle:off println + System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + // scalastyle:on println + System.exit(1) + } + masterUrl = value.stripPrefix("mesos://") + parse(tail) + + case ("--name") :: value :: tail => + name = value + parse(tail) + + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => + printUsageAndExit(0) + + case Nil => { + if (masterUrl == null) { + // scalastyle:off println + System.err.println("--master is required") + // scalastyle:on println + printUsageAndExit(1) + } + } + + case _ => + printUsageAndExit(1) + } + + private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off println + System.err.println( + "Usage: MesosClusterDispatcher [options]\n" + + "\n" + + "Options:\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + + " --name NAME Framework name to show in Mesos UI\n" + + " -m --master MASTER URI for connecting to Mesos master\n" + + " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + + " Zookeeper for persistence\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") + // scalastyle:on println + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala new file mode 100644 index 000000000000..1948226800af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.Date + +import org.apache.spark.deploy.Command +import org.apache.spark.scheduler.cluster.mesos.MesosClusterRetryState + +/** + * Describes a Spark driver that is submitted from the + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]], to be launched by + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * @param jarUrl URL to the application jar + * @param mem Amount of memory for the driver + * @param cores Number of cores for the driver + * @param supervise Supervise the driver for long running app + * @param command The command to launch the driver. + * @param schedulerProperties Extra properties to pass the Mesos scheduler + */ +private[spark] class MesosDriverDescription( + val name: String, + val jarUrl: String, + val mem: Int, + val cores: Double, + val supervise: Boolean, + val command: Command, + val schedulerProperties: Map[String, String], + val submissionId: String, + val submissionDate: Date, + val retryState: Option[MesosClusterRetryState] = None) + extends Serializable { + + def copy( + name: String = name, + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Double = cores, + supervise: Boolean = supervise, + command: Command = command, + schedulerProperties: Map[String, String] = schedulerProperties, + submissionId: String = submissionId, + submissionDate: Date = submissionDate, + retryState: Option[MesosClusterRetryState] = retryState): MesosDriverDescription = { + new MesosDriverDescription(name, jarUrl, mem, cores, supervise, command, schedulerProperties, + submissionId, submissionDate, retryState) + } + + override def toString: String = s"MesosDriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala new file mode 100644 index 000000000000..12337a940a41 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.net.SocketAddress + +import scala.collection.mutable + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver +import org.apache.spark.network.util.TransportConf + +/** + * An RPC endpoint that receives registration requests from Spark drivers running on Mesos. + * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. + */ +private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { + + // Stores a map of driver socket addresses to app ids + private val connectedApps = new mutable.HashMap[SocketAddress, String] + + protected override def handleMessage( + message: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): Unit = { + message match { + case RegisterDriverParam(appId) => + val address = client.getSocketAddress + logDebug(s"Received registration request from app $appId (remote address $address).") + if (connectedApps.contains(address)) { + val existingAppId = connectedApps(address) + if (!existingAppId.equals(appId)) { + logError(s"A new app '$appId' has connected to existing address $address, " + + s"removing previously registered app '$existingAppId'.") + applicationRemoved(existingAppId, true) + } + } + connectedApps(address) = appId + callback.onSuccess(new Array[Byte](0)) + case _ => super.handleMessage(message, client, callback) + } + } + + /** + * On connection termination, clean up shuffle files written by the associated application. + */ + override def connectionTerminated(client: TransportClient): Unit = { + val address = client.getSocketAddress + if (connectedApps.contains(address)) { + val appId = connectedApps(address) + logInfo(s"Application $appId disconnected (address was $address).") + applicationRemoved(appId, true /* cleanupLocalDirs */) + connectedApps.remove(address) + } else { + logWarning(s"Unknown $address disconnected.") + } + } + + /** An extractor object for matching [[RegisterDriver]] message. */ + private object RegisterDriverParam { + def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId) + } +} + +/** + * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers + * to associate with. This allows the shuffle service to detect when a driver is terminated + * and can clean up the associated shuffle files. + */ +private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager) + extends ExternalShuffleService(conf, securityManager) { + + protected override def newShuffleBlockHandler( + conf: TransportConf): ExternalShuffleBlockHandler = { + new MesosExternalShuffleBlockHandler(conf) + } +} + +private[spark] object MesosExternalShuffleService extends Logging { + + def main(args: Array[String]): Unit = { + ExternalShuffleService.main(args, + (conf: SparkConf, sm: SecurityManager) => new MesosExternalShuffleService(conf, sm)) + } +} + + diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala new file mode 100644 index 000000000000..e8ef60bd5428 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.{MesosClusterSubmissionState, MesosClusterRetryState} +import org.apache.spark.ui.{UIUtils, WebUIPage} + + +private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") { + + override def render(request: HttpServletRequest): Seq[Node] = { + val driverId = request.getParameter("id") + require(driverId != null && driverId.nonEmpty, "Missing id parameter") + + val state = parent.scheduler.getDriverState(driverId) + if (state.isEmpty) { + val content = +
      +

      Cannot find driver {driverId}

      +
      + return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + } + val driverState = state.get + val driverHeaders = Seq("Driver property", "Value") + val schedulerHeaders = Seq("Scheduler property", "Value") + val commandEnvHeaders = Seq("Command environment variable", "Value") + val launchedHeaders = Seq("Launched property", "Value") + val commandHeaders = Seq("Comamnd property", "Value") + val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") + val driverDescription = Iterable.apply(driverState.description) + val submissionState = Iterable.apply(driverState.submissionState) + val command = Iterable.apply(driverState.description.command) + val schedulerProperties = Iterable.apply(driverState.description.schedulerProperties) + val commandEnv = Iterable.apply(driverState.description.command.environment) + val driverTable = + UIUtils.listingTable(driverHeaders, driverRow, driverDescription) + val commandTable = + UIUtils.listingTable(commandHeaders, commandRow, command) + val commandEnvTable = + UIUtils.listingTable(commandEnvHeaders, propertiesRow, commandEnv) + val schedulerTable = + UIUtils.listingTable(schedulerHeaders, propertiesRow, schedulerProperties) + val launchedTable = + UIUtils.listingTable(launchedHeaders, launchedRow, submissionState) + val retryTable = + UIUtils.listingTable( + retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) + val content = +

      Driver state information for driver id {driverId}

      + Back to Drivers +
      +
      +

      Driver state: {driverState.state}

      +

      Driver properties

      + {driverTable} +

      Driver command

      + {commandTable} +

      Driver command environment

      + {commandEnvTable} +

      Scheduler properties

      + {schedulerTable} +

      Launched state

      + {launchedTable} +

      Retry state

      + {retryTable} +
      +
      ; + + UIUtils.basicSparkPage(content, s"Details for Job $driverId") + } + + private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { + submissionState.map { state => + + Mesos Slave ID + {state.slaveId.getValue} + + + Mesos Task ID + {state.taskId.getValue} + + + Launch Time + {state.startDate} + + + Finish Time + {state.finishDate.map(_.toString).getOrElse("")} + + + Last Task Status + {state.mesosTaskStatus.map(_.toString).getOrElse("")} + + }.getOrElse(Seq[Node]()) + } + + private def propertiesRow(properties: collection.Map[String, String]): Seq[Node] = { + properties.map { case (k, v) => + + {k}{v} + + }.toSeq + } + + private def commandRow(command: Command): Seq[Node] = { + + Main class{command.mainClass} + + + Arguments{command.arguments.mkString(" ")} + + + Class path entries{command.classPathEntries.mkString(" ")} + + + Java options{command.javaOpts.mkString((" "))} + + + Library path entries{command.libraryPathEntries.mkString((" "))} + + } + + private def driverRow(driver: MesosDriverDescription): Seq[Node] = { + + Name{driver.name} + + + Id{driver.submissionId} + + + Cores{driver.cores} + + + Memory{driver.mem} + + + Submitted{driver.submissionDate} + + + Supervise{driver.supervise} + + } + + private def retryRow(retryState: Option[MesosClusterRetryState]): Seq[Node] = { + retryState.map { state => + + + {state.lastFailureStatus} + + + {state.nextRetry} + + + {state.retries} + + + }.getOrElse(Seq[Node]()) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala new file mode 100644 index 000000000000..7419fa969964 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + val state = parent.scheduler.getSchedulerState() + val queuedHeaders = Seq("Driver ID", "Submit Date", "Main Class", "Driver Resources") + val driverHeaders = queuedHeaders ++ + Seq("Start Date", "Mesos Slave ID", "State") + val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ + Seq("Last Failed Status", "Next Retry Time", "Attempt Count") + val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) + val launchedTable = UIUtils.listingTable(driverHeaders, driverRow, state.launchedDrivers) + val finishedTable = UIUtils.listingTable(driverHeaders, driverRow, state.finishedDrivers) + val retryTable = UIUtils.listingTable(retryHeaders, retryRow, state.pendingRetryDrivers) + val content = +

      Mesos Framework ID: {state.frameworkId}

      +
      +
      +

      Queued Drivers:

      + {queuedTable} +

      Launched Drivers:

      + {launchedTable} +

      Finished Drivers:

      + {finishedTable} +

      Supervise drivers waiting for retry:

      + {retryTable} +
      +
      ; + UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + } + + private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { + val id = submission.submissionId + + {id} + {submission.submissionDate} + {submission.command.mainClass} + cpus: {submission.cores}, mem: {submission.mem} + + } + + private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + val id = state.driverDescription.submissionId + + {id} + {state.driverDescription.submissionDate} + {state.driverDescription.command.mainClass} + cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} + {state.startDate} + {state.slaveId.getValue} + {stateString(state.mesosTaskStatus)} + + } + + private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + val id = submission.submissionId + + {id} + {submission.submissionDate} + {submission.command.mainClass} + {submission.retryState.get.lastFailureStatus} + {submission.retryState.get.nextRetry} + {submission.retryState.get.retries} + + } + + private def stateString(status: Option[TaskStatus]): String = { + if (status.isEmpty) { + return "" + } + val sb = new StringBuilder + val s = status.get + sb.append(s"State: ${s.getState}") + if (status.get.hasMessage) { + sb.append(s", Message: ${s.getMessage}") + } + if (status.get.hasHealthy) { + sb.append(s", Healthy: ${s.getHealthy}") + } + if (status.get.hasSource) { + sb.append(s", Source: ${s.getSource}") + } + if (status.get.hasReason) { + sb.append(s", Reason: ${s.getReason}") + } + if (status.get.hasTimestamp) { + sb.append(s", Time: ${s.getTimestamp}") + } + sb.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala new file mode 100644 index 000000000000..3f693545a034 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.{SparkUI, WebUI} + +/** + * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] + */ +private[spark] class MesosClusterUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + dispatcherPublicAddress: String, + val scheduler: MesosClusterScheduler) + extends WebUI(securityManager, port, conf) { + + initialize() + + def activeWebUiUrl: String = "http://" + dispatcherPublicAddress + ":" + boundPort + + override def initialize() { + attachPage(new MesosClusterPage(this)) + attachPage(new DriverPage(this)) + attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + } +} + +private object MesosClusterUI { + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala new file mode 100644 index 000000000000..1fe956320a1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.{DataOutputStream, FileNotFoundException} +import java.net.{ConnectException, HttpURLConnection, SocketException, URL} +import javax.servlet.http.HttpServletResponse + +import scala.collection.mutable +import scala.io.Source + +import com.fasterxml.jackson.core.JsonProcessingException +import com.google.common.base.Charsets + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A client that submits applications to a [[RestSubmissionServer]]. + * + * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], + * where [action] can be one of create, kill, or status. Each type of request is represented in + * an HTTP message sent to the following prefixes: + * (1) submit - POST to /submissions/create + * (2) kill - POST /submissions/kill/[submissionId] + * (3) status - GET /submissions/status/[submissionId] + * + * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields. + * Otherwise, the URL fully specifies the intended action of the client. + * + * Since the protocol is expected to be stable across Spark versions, existing fields cannot be + * added or removed, though new optional fields can be added. In the rare event that forward or + * backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2). + * + * The client and the server must communicate using the same version of the protocol. If there + * is a mismatch, the server will respond with the highest protocol version it supports. A future + * implementation of this client can use that information to retry using the version specified + * by the server. + */ +private[spark] class RestSubmissionClient(master: String) extends Logging { + import RestSubmissionClient._ + + private val supportedMasterPrefixes = Seq("spark://", "mesos://") + + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } + + // Set of masters that lost contact with us, used to keep track of + // whether there are masters still alive for us to communicate with + private val lostMasters = new mutable.HashSet[String] + + /** + * Submit an application specified by the parameters in the provided request. + * + * If the submission was successful, poll the status of the submission and report + * it to the user. Otherwise, report the error message provided by the server. + */ + def createSubmission(request: CreateSubmissionRequest): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to launch an application in $master.") + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getSubmitUrl(m) + try { + response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => + if (s.success) { + reportSubmissionStatus(s) + handleRestResponse(s) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } + } + response + } + + /** Request that the server kill the specified submission. */ + def killSubmission(submissionId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill submission $submissionId in $master.") + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getKillUrl(m, submissionId) + try { + response = post(url) + response match { + case k: KillSubmissionResponse => + if (!Utils.responseFromBackup(k.message)) { + handleRestResponse(k) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } + } + response + } + + /** Request the status of a submission from the server. */ + def requestSubmissionStatus( + submissionId: String, + quiet: Boolean = false): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request for the status of submission $submissionId in $master.") + + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getStatusUrl(m, submissionId) + try { + response = get(url) + response match { + case s: SubmissionStatusResponse if s.success => + if (!quiet) { + handleRestResponse(s) + } + handled = true + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } + } + response + } + + /** Construct a message that captures the specified parameters for submitting an application. */ + def constructSubmitRequest( + appResource: String, + mainClass: String, + appArgs: Array[String], + sparkProperties: Map[String, String], + environmentVariables: Map[String, String]): CreateSubmissionRequest = { + val message = new CreateSubmissionRequest + message.clientSparkVersion = sparkVersion + message.appResource = appResource + message.mainClass = mainClass + message.appArgs = appArgs + message.sparkProperties = sparkProperties + message.environmentVariables = environmentVariables + message.validate() + message + } + + /** Send a GET request to the specified URL. */ + private def get(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending GET request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + readResponse(conn) + } + + /** Send a POST request to the specified URL. */ + private def post(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + readResponse(conn) + } + + /** Send a POST request with the given JSON as the body to the specified URL. */ + private def postJson(url: URL, json: String): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url:\n$json") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + try { + val out = new DataOutputStream(conn.getOutputStream) + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } + } catch { + case e: ConnectException => + throw new SubmitRestConnectionException("Connect Exception when connect to server", e) + } + readResponse(conn) + } + + /** + * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]]. + * If the response represents an error, report the embedded message to the user. + * Exposed for testing. + */ + private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + try { + val dataStream = + if (connection.getResponseCode == HttpServletResponse.SC_OK) { + connection.getInputStream + } else { + connection.getErrorStream + } + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") + } + } catch { + case unreachable @ (_: FileNotFoundException | _: SocketException) => + throw new SubmitRestConnectionException("Unable to connect to server", unreachable) + case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + throw new SubmitRestProtocolException("Malformed response received from server", malformed) + } + } + + /** Return the REST URL for creating a new submission. */ + private def getSubmitUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/create") + } + + /** Return the REST URL for killing an existing submission. */ + private def getKillUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/kill/$submissionId") + } + + /** Return the REST URL for requesting the status of an existing submission. */ + private def getStatusUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/status/$submissionId") + } + + /** Return the base URL for communicating with the server, including the protocol version. */ + private def getBaseUrl(master: String): String = { + var masterUrl = master + supportedMasterPrefixes.foreach { prefix => + if (master.startsWith(prefix)) { + masterUrl = master.stripPrefix(prefix) + } + } + masterUrl = masterUrl.stripSuffix("/") + s"http://$masterUrl/$PROTOCOL_VERSION/submissions" + } + + /** Throw an exception if this is not standalone mode. */ + private def validateMaster(master: String): Unit = { + val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) } + if (!valid) { + throw new IllegalArgumentException( + "This REST client only supports master URLs that start with " + + "one of the following: " + supportedMasterPrefixes.mkString(",")) + } + } + + /** Report the status of a newly created submission. */ + private def reportSubmissionStatus( + submitResponse: CreateSubmissionResponse): Unit = { + if (submitResponse.success) { + val submissionId = submitResponse.submissionId + if (submissionId != null) { + logInfo(s"Submission successfully created as $submissionId. Polling submission state...") + pollSubmissionStatus(submissionId) + } else { + // should never happen + logError("Application successfully submitted, but submission ID was not provided!") + } + } else { + val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") + logError(s"Application submission failed$failMessage") + } + } + + /** + * Poll the status of the specified submission and log it. + * This retries up to a fixed number of times before giving up. + */ + private def pollSubmissionStatus(submissionId: String): Unit = { + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val response = requestSubmissionStatus(submissionId, quiet = true) + val statusResponse = response match { + case s: SubmissionStatusResponse => s + case _ => return // unexpected type, let upstream caller handle it + } + if (statusResponse.success) { + val driverState = Option(statusResponse.driverState) + val workerId = Option(statusResponse.workerId) + val workerHostPort = Option(statusResponse.workerHostPort) + val exception = Option(statusResponse.message) + // Log driver state, if present + driverState match { + case Some(state) => logInfo(s"State of driver $submissionId is now $state.") + case _ => logError(s"State of driver $submissionId was not found!") + } + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) + } + logError(s"Error: Master did not recognize driver $submissionId.") + } + + /** Log the response sent by the server in the REST application submission protocol. */ + private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = { + logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}") + } + + /** Log an appropriate error if the response sent by the server is not of the expected type. */ + private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { + logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") + } + + /** + * When a connection exception is caught, return true if all masters are lost. + * Note that the heuristic used here does not take into account that masters + * can recover during the lifetime of this client. This assumption should be + * harmless because this client currently does not support retrying submission + * on failure yet (SPARK-6443). + */ + private def handleConnectionException(masterUrl: String): Boolean = { + if (!lostMasters.contains(masterUrl)) { + logWarning(s"Unable to connect to server ${masterUrl}.") + lostMasters += masterUrl + } + lostMasters.size >= masters.size + } +} + +private[spark] object RestSubmissionClient { + private val REPORT_DRIVER_STATUS_INTERVAL = 1000 + private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 + val PROTOCOL_VERSION = "v1" + + /** + * Submit an application, assuming Spark parameters are specified through the given config. + * This is abstracted to its own method for testing purposes. + */ + def run( + appResource: String, + mainClass: String, + appArgs: Array[String], + conf: SparkConf, + env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + val master = conf.getOption("spark.master").getOrElse { + throw new IllegalArgumentException("'spark.master' must be set.") + } + val sparkProperties = conf.getAll.toMap + val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } + val client = new RestSubmissionClient(master) + val submitRequest = client.constructSubmitRequest( + appResource, mainClass, appArgs, sparkProperties, environmentVariables) + client.createSubmission(submitRequest) + } + + def main(args: Array[String]): Unit = { + if (args.size < 2) { + sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") + sys.exit(1) + } + val appResource = args(0) + val mainClass = args(1) + val appArgs = args.slice(2, args.size) + val conf = new SparkConf + run(appResource, mainClass, appArgs, conf) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala new file mode 100644 index 000000000000..2e78d03e5c0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + */ +private[spark] abstract class RestSubmissionServer( + val host: String, + val requestedPort: Int, + val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet + protected val killRequestServlet: KillRequestServlet + protected val statusRequestServlet: StatusRequestServlet + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" + protected lazy val contextToServlet = Map[String, RestServlet]( + s"$baseContext/create/*" -> submitRequestServlet, + s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/status/*" -> statusRequestServlet, + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object RestSubmissionServer { + val PROTOCOL_VERSION = RestSubmissionClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class RestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. + */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse +} + +/** + * A servlet for handling status requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class StatusRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse +} + +/** + * A servlet for handling submit requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class SubmitRequestServlet extends RestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + protected def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends RestServlet { + private val serverVersion = RestSubmissionServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." + val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala new file mode 100644 index 000000000000..d5b9bcab1423 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.File +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * This is intended to be embedded in the standalone Master and used in cluster mode only. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + * + * @param host the address this server should bind to + * @param requestedPort the port this server will attempt to bind to + * @param masterConf the conf used by the Master + * @param masterEndpoint reference to the Master endpoint to which requests can be sent + * @param masterUrl the URL of the Master new drivers will attempt to connect to + */ +private[deploy] class StandaloneRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + masterEndpoint: RpcEndpointRef, + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) + protected override val killRequestServlet = + new StandaloneKillRequestServlet(masterEndpoint, masterConf) + protected override val statusRequestServlet = + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) +} + +/** + * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) + extends KillRequestServlet { + + protected def handleKill(submissionId: String): KillSubmissionResponse = { + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) + val k = new KillSubmissionResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.submissionId = submissionId + k.success = response.success + k + } +} + +/** + * A servlet for handling status requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) + extends StatusRequestServlet { + + protected def handleStatus(submissionId: String): SubmissionStatusResponse = { + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) + val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } + val d = new SubmissionStatusResponse + d.serverSparkVersion = sparkVersion + d.submissionId = submissionId + d.success = response.found + d.driverState = response.state.map(_.toString).orNull + d.workerId = response.workerId.orNull + d.workerHostPort = response.workerHostPort.orNull + d.message = message.orNull + d + } +} + +/** + * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StandaloneSubmitRequestServlet( + masterEndpoint: RpcEndpointRef, + masterUrl: String, + conf: SparkConf) + extends SubmitRequestServlet { + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that takes into account memory, java options, + * classpath and other settings to launch the driver. This does not currently consider + * fields used by python applications since python is not supported in standalone + * cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + + // Construct driver description + val conf = new SparkConf(false) + .setAll(sparkProperties) + .set("spark.master", masterUrl) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper + environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + new DriverDescription( + appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) + } + + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. + */ + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala new file mode 100644 index 000000000000..b97921ec934a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * An exception thrown in the REST application submission protocol. + */ +private[rest] class SubmitRestProtocolException(message: String, cause: Throwable = null) + extends Exception(message, cause) + +/** + * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]]. + */ +private[rest] class SubmitRestMissingFieldException(message: String) + extends SubmitRestProtocolException(message) + +/** + * An exception thrown if the REST client cannot reach the REST server. + */ +private[deploy] class SubmitRestConnectionException(message: String, cause: Throwable) + extends SubmitRestProtocolException(message, cause) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala new file mode 100644 index 000000000000..ef5a7e35ad56 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import com.fasterxml.jackson.annotation._ +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.util.Utils + +/** + * An abstract message exchanged in the REST application submission protocol. + * + * This message is intended to be serialized to and deserialized from JSON in the exchange. + * Each message can either be a request or a response and consists of three common fields: + * (1) the action, which fully specifies the type of the message + * (2) the Spark version of the client / server + * (3) an optional message + */ +@JsonInclude(Include.NON_NULL) +@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) +@JsonPropertyOrder(alphabetic = true) +private[rest] abstract class SubmitRestProtocolMessage { + @JsonIgnore + val messageType = Utils.getFormattedClassName(this) + + val action: String = messageType + var message: String = null + + // For JSON deserialization + private def setAction(a: String): Unit = { } + + /** + * Serialize the message to JSON. + * This also ensures that the message is valid and its fields are in the expected format. + */ + def toJson: String = { + validate() + SubmitRestProtocolMessage.mapper.writeValueAsString(this) + } + + /** + * Assert the validity of the message. + * If the validation fails, throw a [[SubmitRestProtocolException]]. + */ + final def validate(): Unit = { + try { + doValidate() + } catch { + case e: Exception => + throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e) + } + } + + /** Assert the validity of the message */ + protected def doValidate(): Unit = { + if (action == null) { + throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType") + } + } + + /** Assert that the specified field is set in this message. */ + protected def assertFieldIsSet[T](value: T, name: String): Unit = { + if (value == null) { + throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.") + } + } + + /** + * Assert a condition when validating this message. + * If the assertion fails, throw a [[SubmitRestProtocolException]]. + */ + protected def assert(condition: Boolean, failMessage: String): Unit = { + if (!condition) { throw new SubmitRestProtocolException(failMessage) } + } +} + +/** + * Helper methods to process serialized [[SubmitRestProtocolMessage]]s. + */ +private[spark] object SubmitRestProtocolMessage { + private val packagePrefix = this.getClass.getPackage.getName + private val mapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .enable(SerializationFeature.INDENT_OUTPUT) + .registerModule(DefaultScalaModule) + + /** + * Parse the value of the action field from the given JSON. + * If the action field is not found, throw a [[SubmitRestMissingFieldException]]. + */ + def parseAction(json: String): String = { + val value: Option[String] = parse(json) match { + case JObject(fields) => + fields.collectFirst { case ("action", v) => v }.collect { case JString(s) => s } + case _ => None + } + value.getOrElse { + throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") + } + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method first parses the action from the JSON and uses it to infer the message type. + * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in + * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. + */ + def fromJson(json: String): SubmitRestProtocolMessage = { + val className = parseAction(json) + val clazz = Utils.classForName(packagePrefix + "." + className) + .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) + fromJson(json, clazz) + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method determines the type of the message from the class provided instead of + * inferring it from the action field. This is useful for deserializing JSON that + * represents custom user-defined messages. + */ + def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { + mapper.readValue(json, clazz) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala new file mode 100644 index 000000000000..0d50a768942e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.util.Try + +import org.apache.spark.util.Utils + +/** + * An abstract request sent from the client in the REST application submission protocol. + */ +private[rest] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + var clientSparkVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(clientSparkVersion, "clientSparkVersion") + } +} + +/** + * A request to launch a new application in the REST application submission protocol. + */ +private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { + var appResource: String = null + var mainClass: String = null + var appArgs: Array[String] = null + var sparkProperties: Map[String, String] = null + var environmentVariables: Map[String, String] = null + + protected override def doValidate(): Unit = { + super.doValidate() + assert(sparkProperties != null, "No Spark properties set!") + assertFieldIsSet(appResource, "appResource") + assertPropertyIsSet("spark.app.name") + assertPropertyIsBoolean("spark.driver.supervise") + assertPropertyIsNumeric("spark.driver.cores") + assertPropertyIsNumeric("spark.cores.max") + assertPropertyIsMemory("spark.driver.memory") + assertPropertyIsMemory("spark.executor.memory") + } + + private def assertPropertyIsSet(key: String): Unit = + assertFieldIsSet(sparkProperties.getOrElse(key, null), key) + + private def assertPropertyIsBoolean(key: String): Unit = + assertProperty[Boolean](key, "boolean", _.toBoolean) + + private def assertPropertyIsNumeric(key: String): Unit = + assertProperty[Double](key, "numeric", _.toDouble) + + private def assertPropertyIsMemory(key: String): Unit = + assertProperty[Int](key, "memory", Utils.memoryStringToMb) + + /** Assert that a Spark property can be converted to a certain type. */ + private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = { + sparkProperties.get(key).foreach { value => + Try(convert(value)).getOrElse { + throw new SubmitRestProtocolException( + s"Property '$key' expected $valueType value: actual was '$value'.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala new file mode 100644 index 000000000000..0e226ee294ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean + +/** + * An abstract response sent from the server in the REST application submission protocol. + */ +private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + var serverSparkVersion: String = null + var success: Boolean = null + var unknownFields: Array[String] = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(serverSparkVersion, "serverSparkVersion") + } +} + +/** + * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. + */ +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a kill request in the REST application submission protocol. + */ +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a status request in the REST application submission protocol. + */ +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + var driverState: String = null + var workerId: String = null + var workerHostPort: String = null + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * An error response message used in the REST application submission protocol. + */ +private[rest] class ErrorResponse extends SubmitRestProtocolResponse { + // The highest protocol version that the server knows about + // This is set when the client specifies an unknown version + var highestProtocolVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(message, "message") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala new file mode 100644 index 000000000000..868cc35d06ef --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest.mesos + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.atomic.AtomicLong +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * All requests are forwarded to + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * This is intended to be used in Mesos cluster mode only. + * For more details about the REST submission please refer to [[RestSubmissionServer]] javadocs. + */ +private[spark] class MesosRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + scheduler: MesosClusterScheduler) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new MesosSubmitRequestServlet(scheduler, masterConf) + protected override val killRequestServlet = + new MesosKillRequestServlet(scheduler, masterConf) + protected override val statusRequestServlet = + new MesosStatusRequestServlet(scheduler, masterConf) +} + +private[mesos] class MesosSubmitRequestServlet( + scheduler: MesosClusterScheduler, + conf: SparkConf) + extends SubmitRequestServlet { + + private val DEFAULT_SUPERVISE = false + private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb + private val DEFAULT_CORES = 1.0 + + private val nextDriverNumber = new AtomicLong(0) + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def newDriverId(submitDate: Date): String = { + "driver-%s-%04d".format( + createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) + } + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that launches a mesos framework for the job. + * This does not currently consider fields used by python applications since python + * is not supported in mesos cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) + + // Construct driver description + val conf = new SparkConf(false).setAll(sparkProperties) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) + val submitDate = new Date() + val submissionId = newDriverId(submitDate) + + new MesosDriverDescription( + name, appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, + command, request.sparkProperties, submissionId, submitDate) + } + + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val s = scheduler.submitDriver(driverDescription) + s.serverSparkVersion = sparkVersion + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + s.unknownFields = unknownFields + } + s + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} + +private[mesos] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends KillRequestServlet { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = scheduler.killDriver(submissionId) + k.serverSparkVersion = sparkVersion + k + } +} + +private[mesos] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends StatusRequestServlet { + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val d = scheduler.getDriverStatus(submissionId) + d.serverSparkVersion = sparkVersion + d + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 28e9662db5da..ce02ee203a4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -18,18 +18,20 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} -import java.lang.System._ +import scala.collection.JavaConverters._ import scala.collection.Map import org.apache.spark.Logging +import org.apache.spark.SecurityManager import org.apache.spark.deploy.Command +import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils /** ** Utilities for running commands with the spark classpath. */ -private[spark] +private[deploy] object CommandUtils extends Logging { /** @@ -38,12 +40,14 @@ object CommandUtils extends Logging { */ def buildProcessBuilder( command: Command, + securityMgr: SecurityManager, memory: Int, sparkHome: String, substituteArguments: String => String, classPaths: Seq[String] = Seq[String](), env: Map[String, String] = sys.env): ProcessBuilder = { - val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) + val localCommand = buildLocalCommand( + command, securityMgr, substituteArguments, classPaths, env) val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) val builder = new ProcessBuilder(commandSeq: _*) val environment = builder.environment() @@ -54,12 +58,10 @@ object CommandUtils extends Logging { } private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { - val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") - // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows - Seq(runner) ++ buildJavaOpts(command, memory, sparkHome) ++ Seq(command.mainClass) ++ - command.arguments + val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() + cmd.asScala ++ Seq(command.mainClass) ++ command.arguments } /** @@ -69,6 +71,7 @@ object CommandUtils extends Logging { */ private def buildLocalCommand( command: Command, + securityMgr: SecurityManager, substituteArguments: String => String, classPath: Seq[String] = Seq[String](), env: Map[String, String]): Command = { @@ -76,48 +79,26 @@ object CommandUtils extends Logging { val libraryPathEntries = command.libraryPathEntries val cmdLibraryPath = command.environment.get(libraryPathName) - val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + var newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) } else { command.environment } + // set auth secret to env variable if needed + if (securityMgr.isAuthenticationEnabled) { + newEnvironment += (SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey) + } + Command( command.mainClass, command.arguments.map(substituteArguments), newEnvironment, command.classPathEntries ++ classPath, Seq[String](), // library path already captured in environment variable - command.javaOpts) - } - - /** - * Attention: this must always be aligned with the environment variables in the run scripts and - * the way the JAVA_OPTS are assembled there. - */ - private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { - val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") - - // Exists for backwards compatibility with older Spark versions - val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString) - .getOrElse(Nil) - if (workerLocalOpts.length > 0) { - logWarning("SPARK_JAVA_OPTS was set on the worker. It is deprecated in Spark 1.0.") - logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.") - } - - // Figure out our classpath with the external compute-classpath script - val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" - val classPath = Utils.executeAndGetOutput( - Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment = command.environment) - val userClassPath = command.classPathEntries ++ Seq(classPath) - - val javaVersion = System.getProperty("java.version") - val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None - Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts + // filter out auth secret from java options + command.javaOpts.filterNot(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 28cab36c7b9e..89159ff5e2b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,65 +19,74 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ -import scala.collection.Map +import scala.collection.JavaConverters._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.{Command, DriverDescription, SparkHadoopUtil} +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.{Utils, Clock, SystemClock} /** * Manages the execution of one driver, including automatically restarting the driver on failure. * This is currently only used in standalone cluster deploy mode. */ -private[spark] class DriverRunner( - val conf: SparkConf, +private[deploy] class DriverRunner( + conf: SparkConf, val driverId: String, val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, - val workerUrl: String) + val worker: RpcEndpointRef, + val workerUrl: String, + val securityManager: SecurityManager) extends Logging { - @volatile var process: Option[Process] = None - @volatile var killed = false + @volatile private var process: Option[Process] = None + @volatile private var killed = false // Populated once finished - var finalState: Option[DriverState] = None - var finalException: Option[Exception] = None - var finalExitCode: Option[Int] = None + private[worker] var finalState: Option[DriverState] = None + private[worker] var finalException: Option[Exception] = None + private var finalExitCode: Option[Int] = None // Decoupled for testing - private[deploy] def setClock(_clock: Clock) = clock = _clock - private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper - private var clock = new Clock { - def currentTimeMillis(): Long = System.currentTimeMillis() + def setClock(_clock: Clock): Unit = { + clock = _clock } + + def setSleeper(_sleeper: Sleeper): Unit = { + sleeper = _sleeper + } + + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) } /** Starts a thread to run and manage the driver. */ - def start() = { + private[worker] def start() = { new Thread("DriverRunner for " + driverId) { override def run() { try { val driverDir = createWorkingDirectory() val localJarFilename = downloadUserJar(driverDir) - // Make sure user application jar is on the classpath + def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl + case "{{USER_JAR}}" => localJarFilename + case other => other + } + // TODO: If we add ability to submit multiple jars they should also be added here - val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename)) + val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager, + driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { @@ -98,25 +107,19 @@ private[spark] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } /** Terminate this driver (or prevent it from ever starting if not yet started) */ - def kill() { + private[worker] def kill() { synchronized { process.foreach(p => p.destroy()) killed = true } } - /** Replace variables in a command argument passed to us */ - private def substituteVariables(argument: String): String = argument match { - case "{{WORKER_URL}}" => workerUrl - case other => other - } - /** * Creates the working directory for this driver. * Will throw an exception if there are errors preparing the directory. @@ -134,12 +137,9 @@ private[spark] class DriverRunner( * Will throw an exception if there are errors downloading the jar. */ private def downloadUserJar(driverDir: File): String = { - val jarPath = new Path(driverDesc.jarUrl) val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val jarFileSystem = jarPath.getFileSystem(hadoopConf) - val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) val jarFileName = jarPath.getName val localJarFile = new File(driverDir, jarFileName) @@ -147,7 +147,14 @@ private[spark] class DriverRunner( if (!localJarFile.exists()) { // May already exist if running multiple workers on one node logInfo(s"Copying user jar $jarPath to $destPath") - FileUtil.copy(jarFileSystem, jarPath, destPath, false, hadoopConf) + Utils.fetchFile( + driverDesc.jarUrl, + driverDir, + conf, + securityManager, + hadoopConf, + System.currentTimeMillis(), + useCache = false) } if (!localJarFile.exists()) { // Verify copy succeeded @@ -159,22 +166,22 @@ private[spark] class DriverRunner( private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { builder.directory(baseDir) - def initialize(process: Process) = { + def initialize(process: Process): Unit = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(baseDir, "stderr") - val header = "Launch Command: %s\n%s\n\n".format( - builder.command.mkString("\"", "\" \"", "\""), "=" * 40) + val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"") + val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40) Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) } - private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, - supervise: Boolean) { + def runCommandWithRetry( + command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = { // Time to wait between submission retries. var waitSeconds = 1 // A run of this many seconds resets the exponential back-off. @@ -191,9 +198,9 @@ private[spark] class DriverRunner( initialize(process.get) } - val processStart = clock.currentTimeMillis() + val processStart = clock.getTimeMillis() val exitCode = process.get.waitFor() - if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { waitSeconds = 1 } @@ -209,10 +216,6 @@ private[spark] class DriverRunner( } } -private[deploy] trait Clock { - def currentTimeMillis(): Long -} - private[deploy] trait Sleeper { def sleep(seconds: Int) } @@ -224,8 +227,8 @@ private[deploy] trait ProcessBuilderLike { } private[deploy] object ProcessBuilderLike { - def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { - def start() = processBuilder.start() - def command = processBuilder.command() + def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { + override def start(): Process = processBuilder.start() + override def command: Seq[String] = processBuilder.command().asScala } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 05e242e6df70..6799f78ec0c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,32 +17,52 @@ package org.apache.spark.deploy.worker -import akka.actor._ +import java.io.File import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. + * This is used in standalone cluster mode only. */ object DriverWrapper { def main(args: Array[String]) { args.toList match { - case workerUrl :: mainClass :: extraArgs => + /* + * IMPORTANT: Spark 1.3 provides a stable application submission gateway that is both + * backward and forward compatible across future Spark versions. Because this gateway + * uses this class to launch the driver, the ordering and semantics of the arguments + * here must also remain consistent across versions. + */ + case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) + + val currentLoader = Thread.currentThread.getContextClassLoader + val userJarUrl = new File(userJar).toURI().toURL() + val loader = + if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader) + } else { + new MutableURLClassLoader(Array(userJarUrl), currentLoader) + } + Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(args(1)) + val clazz = Utils.classForName(mainClass) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() + rpcEnv.shutdown() case _ => - System.err.println("Usage: DriverWrapper [options]") + // scalastyle:off println + System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index bc9f78b9e5c7..3aef0515cbf6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,60 +19,59 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. * This is currently only used in standalone mode. */ -private[spark] class ExecutorRunner( +private[deploy] class ExecutorRunner( val appId: String, val execId: Int, val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, + val webUiPort: Int, + val publicAddress: String, val sparkHome: File, val executorDir: File, val workerUrl: String, - val conf: SparkConf, + conf: SparkConf, val appLocalDirs: Seq[String], - var state: ExecutorState.Value) + @volatile var state: ExecutorState.Value) extends Logging { - val fullId = appId + "/" + execId - var workerThread: Thread = null - var process: Process = null - var stdoutAppender: FileAppender = null - var stderrAppender: FileAppender = null + private val fullId = appId + "/" + execId + private var workerThread: Thread = null + private var process: Process = null + private var stdoutAppender: FileAppender = null + private var stderrAppender: FileAppender = null // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. - var shutdownHook: Thread = null + private var shutdownHook: AnyRef = null - def start() { + private[worker] def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { override def run() { fetchAndRunExecutor() } } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { - override def run() { - killProcess(Some("Worker shutting down")) - } - } - Runtime.getRuntime.addShutdownHook(shutdownHook) + shutdownHook = ShutdownHookManager.addShutdownHook { () => + killProcess(Some("Worker shutting down")) } } /** @@ -84,32 +83,35 @@ private[spark] class ExecutorRunner( var exitCode: Option[Int] = None if (process != null) { logInfo("Killing process!") - process.destroy() - process.waitFor() if (stdoutAppender != null) { stdoutAppender.stop() } if (stderrAppender != null) { stderrAppender.stop() } + process.destroy() exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ - def kill() { + private[worker] def kill() { if (workerThread != null) { // the workerThread will kill the child process when interrupted workerThread.interrupt() workerThread = null state = ExecutorState.KILLED - Runtime.getRuntime.removeShutdownHook(shutdownHook) + try { + ShutdownHookManager.removeShutdownHook(shutdownHook) + } catch { + case e: IllegalStateException => None + } } } /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ - def substituteVariables(argument: String): String = argument match { + private[worker] def substituteVariables(argument: String): String = argument match { case "{{WORKER_URL}}" => workerUrl case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host @@ -121,22 +123,30 @@ private[spark] class ExecutorRunner( /** * Download and run the executor described in our ApplicationDescription */ - def fetchAndRunExecutor() { + private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) + val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") + logInfo(s"Launch command: $formattedCommand") builder.directory(executorDir) - builder.environment.put("SPARK_LOCAL_DIRS", appLocalDirs.mkString(",")) + builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0") + + // Add webUI log urls + val baseUrl = + s"http://$publicAddress:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType=" + builder.environment.put("SPARK_LOG_URL_STDERR", s"${baseUrl}stderr") + builder.environment.put("SPARK_LOG_URL_STDOUT", s"${baseUrl}stdout") + process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) + formattedCommand, "=" * 40) // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") @@ -151,7 +161,7 @@ private[spark] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala deleted file mode 100644 index b9798963bab0..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.worker - -import org.apache.spark.{Logging, SparkConf, SecurityManager} -import org.apache.spark.network.TransportContext -import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.SaslRpcHandler -import org.apache.spark.network.server.TransportServer -import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler - -/** - * Provides a server from which Executors can read shuffle files (rather than reading directly from - * each other), to provide uninterrupted access to the files in the face of executors being turned - * off or killed. - * - * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. - */ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) - extends Logging { - - private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) - private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) - private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) - private val blockHandler = new ExternalShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = { - val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler - new TransportContext(transportConf, handler) - } - - private var server: TransportServer = _ - - /** Starts the external shuffle service if the user has configured us to. */ - def startIfEnabled() { - if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) - } - } - - def stop() { - if (enabled && server != null) { - server.close() - server = null - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index b20f5c0c8289..770927c80f7a 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,124 +21,145 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} - -/** - * @param masterAkkaUrls Each url should be a valid akka url. - */ -private[spark] class Worker( - host: String, - port: Int, +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} + +private[deploy] class Worker( + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) - def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. + private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + + // For worker and executor IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 + private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 // Model retries to connect to the master, after Hadoop's model. // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) // Afterwards, the next 10 attempts are between 30 and 90 seconds. // A bit of randomness is introduced so that not all of the workers attempt to reconnect at // the same time. - val INITIAL_REGISTRATION_RETRIES = 6 - val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 - val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 - val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { + private val INITIAL_REGISTRATION_RETRIES = 6 + private val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 + private val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 + private val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) - val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) + private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 + private val CLEANUP_INTERVAL_MILLIS = + conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) - - val testing: Boolean = sys.props.contains("spark.testing") - var master: ActorSelection = null - var masterAddress: Address = null - var activeMasterUrl: String = "" - var activeMasterWebUiUrl : String = "" - val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile var registered = false - @volatile var connected = false - val workerId = generateWorkerId() - val sparkHome = + private val APP_DATA_RETENTION_SECONDS = + conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + + private val testing: Boolean = sys.props.contains("spark.testing") + private var master: Option[RpcEndpointRef] = None + private var activeMasterUrl: String = "" + private[worker] var activeMasterWebUiUrl : String = "" + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false + private val workerId = generateWorkerId() + private val sparkHome = if (testing) { assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") new File(sys.props("spark.test.home")) } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } + var workDir: File = null - val executors = new HashMap[String, ExecutorRunner] - val finishedExecutors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new LinkedHashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] - val finishedDrivers = new HashMap[String, DriverRunner] + val executors = new HashMap[String, ExecutorRunner] + val finishedDrivers = new LinkedHashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors", + WorkerWebUI.DEFAULT_RETAINED_EXECUTORS) + val retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers", + WorkerWebUI.DEFAULT_RETAINED_DRIVERS) + // The shuffle service is not actually started unless configured. - val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) - val publicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") + private val publicAddress = { + val envVar = conf.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host } - var webUi: WorkerWebUI = null + private var webUi: WorkerWebUI = null - var coresUsed = 0 - var memoryUsed = 0 - var connectionAttemptCount = 0 + private var connectionAttemptCount = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) - val workerSource = new WorkerSource(this) + private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) + private val workerSource = new WorkerSource(this) - var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) + + var coresUsed = 0 + var memoryUsed = 0 def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed - def createWorkDir() { + private def createWorkDir() { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() @@ -156,14 +177,13 @@ private[spark] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -175,38 +195,45 @@ private[spark] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } /** * Re-register with the master because a network failure or a master failure has occurred. * If the re-registration attempt threshold is exceeded, the worker exits with error. - * Note that for thread-safety this should only be called from the actor. + * Note that for thread-safety this should only be called from the rpcEndpoint. */ private def reregisterWithMaster(): Unit = { Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -229,21 +256,48 @@ private[spark] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -252,41 +306,68 @@ private[spark] class Worker( } } - def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + + private def registerWithMaster() { + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => - // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker + // rpcEndpoint. + // Copy ids so that it can be used in the cleanup thread. + val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -296,30 +377,27 @@ private[spark] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) - - case Heartbeat => - logInfo(s"Received heartbeat from driver ${sender.path}") + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -345,11 +423,13 @@ private[spark] class Worker( } // Create local dirs for the executor. These are passed to the executor via the - // SPARK_LOCAL_DIRS environment variable, and deleted by the Worker when the + // SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the // application finishes. val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir).getAbsolutePath() + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + appDir.getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs @@ -362,16 +442,18 @@ private[spark] class Worker( self, workerId, host, + webUi.boundPort, + publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -379,32 +461,14 @@ private[spark] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - executors.get(fullId) match { - case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - executors -= fullId - finishedExecutors(fullId) = executor - coresUsed -= executor.cores - memoryUsed -= executor.memory - case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - } - maybeCleanupApplication(appId) - } + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + handleExecutorStateChanged(executorStateChanged) case KillExecutor(masterUrl, appId, execId) => if (masterUrl != activeMasterUrl) { @@ -429,7 +493,8 @@ private[spark] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl) + workerUri, + securityMgr) drivers(driverId) = driver driver.start() @@ -447,36 +512,10 @@ private[spark] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { - state match { - case DriverState.ERROR => - logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") - case DriverState.FAILED => - logWarning(s"Driver $driverId exited with failure") - case DriverState.FINISHED => - logInfo(s"Driver $driverId exited successfully") - case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") - case _ => - logDebug(s"Driver $driverId changed state to $state") - } - master ! DriverStateChanged(driverId, state, exception) - val driver = drivers.remove(driverId).get - finishedDrivers(driverId) = driver - memoryUsed -= driver.driverDesc.mem - coresUsed -= driver.driverDesc.cores + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + handleDriverStateChanged(driverStateChanged) } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -485,6 +524,21 @@ private[spark] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -501,35 +555,127 @@ private[spark] class Worker( Utils.deleteRecursively(new File(dir)) } } + shuffleService.applicationRemoved(id) + } + } + + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") } } - def generateWorkerId(): String = { + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() webUi.stop() metricsSystem.stop() } + + private def trimFinishedExecutorsIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedExecutors.size > retainedExecutors) { + finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach { + case (executorId, _) => finishedExecutors.remove(executorId) + } + } + } + + private def trimFinishedDriversIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedDrivers.size > retainedDrivers) { + finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach { + case (driverId, _) => finishedDrivers.remove(driverId) + } + } + } + + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { + val driverId = driverStateChanged.driverId + val exception = driverStateChanged.exception + val state = driverStateChanged.state + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FAILED => + logWarning(s"Driver $driverId exited with failure") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + case _ => + logDebug(s"Driver $driverId changed state to $state") + } + sendToMaster(driverStateChanged) + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + trimFinishedDriversIfNecessary() + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + + private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): + Unit = { + sendToMaster(executorStateChanged) + val state = executorStateChanged.state + if (ExecutorState.isFinished(state)) { + val appId = executorStateChanged.appId + val fullId = appId + "/" + executorStateChanged.execId + val message = executorStateChanged.message + val exitStatus = executorStateChanged.exitStatus + executors.get(fullId) match { + case Some(executor) => + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + executors -= fullId + finishedExecutors(fullId) = executor + trimFinishedExecutorsIfNecessary() + coresUsed -= executor.cores + memoryUsed -= executor.memory + case None => + logInfo("Unknown Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + } + maybeCleanupApplication(appId) + } + } } -private[spark] object Worker extends Logging { +private[deploy] object Worker extends Logging { + val SYSTEM_NAME = "sparkWorker" + val ENDPOINT_NAME = "Worker" + def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -537,22 +683,20 @@ private[spark] object Worker extends Logging { memory: Int, masterUrls: Array[String], workDir: String, - workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workerNumber: Option[Int] = None, + conf: SparkConf = new SparkConf): RpcEnv = { - // The LocalSparkCluster runs multiple local sparkWorkerX actor systems - val conf = new SparkConf - val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") - val actorName = "Worker" + // The LocalSparkCluster runs multiple local sparkWorkerX RPC Environments + val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("") val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory, + masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr)) + rpcEnv } - private[spark] def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { + def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r val result = cmd.javaOpts.collectFirst { case pattern(_result) => _result.toBoolean @@ -560,7 +704,7 @@ private[spark] object Worker extends Logging { result.getOrElse(false) } - private[spark] def maybeUpdateSSLSettings(cmd: Command, conf: SparkConf): Command = { + def maybeUpdateSSLSettings(cmd: Command, conf: SparkConf): Command = { val prefix = "spark.ssl." val useNLC = "spark.ssl.useNodeLocalConf" if (isUseLocalNodeSSLConfig(cmd)) { @@ -573,5 +717,4 @@ private[spark] object Worker extends Logging { cmd } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 019cd70f2a22..5181142c5f80 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkConf /** * Command-line parser for the worker. */ -private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { +private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 0 var webUiPort = 8081 @@ -63,7 +63,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { checkWorkerMemory() - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) host = value @@ -105,7 +105,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { if (masters != null) { // Two positional arguments were given printUsageAndExit(1) } - masters = value.stripPrefix("spark://").split(",").map("spark://" + _) + masters = Utils.parseStandaloneMasterUrls(value) parse(tail) case Nil => @@ -121,6 +121,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. */ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -147,6 +149,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { val ibmVendor = System.getProperty("java.vendor").contains("IBM") var totalMb = 0 try { + // scalastyle:off classforname val bean = ManagementFactory.getOperatingSystemMXBean() if (ibmVendor) { val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") @@ -157,14 +160,17 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt } + // scalastyle:on classforname } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) + math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) } def checkWorkerMemory(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index df1e01b23b93..b36023bc40c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -21,7 +21,7 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source -private[spark] class WorkerSource(val worker: Worker) extends Source { +private[worker] class WorkerSource(val worker: Worker) extends Source { override val sourceName = "worker" override val metricRegistry = new MetricRegistry() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b61..735c4f092715 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,58 +17,62 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends RpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + } } // Used to avoid shutting down JVM during tests + // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. In the unit + // test, the user should call `setTesting(true)` so that `exitNonZero` will set `isShutDown` to + // true rather than calling `System.exit`. The user can check `isShutDown` to know if + // `exitNonZero` is called. private[deploy] var isShutDown = false private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false - // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + // Lets filter events only from the worker's rpc system + private val expectedAddress = RpcAddress.fromURIString(workerUrl) + private def isWorker(address: RpcAddress) = expectedAddress == address - def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) + private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) + + override def receive: PartialFunction[Any, Unit] = { + case e => logWarning(s"Received unexpected message: $e") + } - override def receiveWithLogging = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { logInfo(s"Successfully connected to $workerUrl") + } + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // This log message will never be seen + logError(s"Lost connection to worker rpc endpoint $workerUrl. Exiting.") + exitNonZero() + } + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") logError(s"Error was: $cause") exitNonZero() - - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - // This log message will never be seen - logError(s"Lost connection to worker actor $workerUrl. Exiting.") - exitNonZero() - - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index ecb358c39981..5a1d06eb87db 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.worker.ui +import java.io.File +import java.net.URI import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -26,9 +28,10 @@ import org.apache.spark.util.Utils import org.apache.spark.Logging import org.apache.spark.util.logging.RollingFileAppender -private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { +private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +132,18 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + + // Verify that the normalized path of the log directory is in the working directory + val normalizedUri = new URI(logDirectory).normalize() + val normalizedLogDir = new File(normalizedUri.getPath) + if (!Utils.isInDirectory(workDir, normalizedLogDir)) { + return ("Error: invalid log directory " + logDirectory, 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") @@ -144,7 +159,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w offset } } - val endIndex = math.min(startIndex + totalLength, totalLength) + val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b90503280..fd905feb97e9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -31,20 +29,16 @@ import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - val workerActor = parent.worker.self - val worker = parent.worker - val timeout = parent.timeout +private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors @@ -134,7 +128,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { def driverRow(driver: DriverRunner): Seq[Node] = { {driver.driverId} - {driver.driverDesc.command.arguments(1)} + {driver.driverDesc.command.arguments(2)} {driver.finalState.getOrElse(DriverState.RUNNING)} {driver.driverDesc.cores.toString} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 7ac81a2d87ef..709a27233598 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -25,12 +25,12 @@ import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone worker. */ -private[spark] +private[worker] class WorkerWebUI( val worker: Worker, val workDir: File, @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - val timeout = AkkaUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) initialize() @@ -53,6 +53,8 @@ class WorkerWebUI( } } -private[spark] object WorkerWebUI { +private[worker] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR + val DEFAULT_RETAINED_DRIVERS = 1000 + val DEFAULT_RETAINED_EXECUTORS = 1000 } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index bc72c8970319..fcd76ec52742 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,47 +17,73 @@ package org.apache.spark.executor +import java.net.URL import java.nio.ByteBuffer -import scala.concurrent.Await +import org.apache.hadoop.conf.Configuration -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} +import scala.collection.mutable +import scala.util.{Failure, Success} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, + userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None - override def preStart() { + // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need + // to be changed so that we don't share the serializer instance across threads + private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() + + override def onStart() { logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + // This is a very fast action so we can use "ThreadUtils.sameThread" + driver = Some(ref) + ref.ask[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + }(ThreadUtils.sameThread).onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } + }(ThreadUtils.sameThread) } - override def receiveWithLogging = { + def extractLogUrls: Map[String, String] = { + val prefix = "SPARK_LOG_URL_" + sys.env.filterKeys(_.startsWith(prefix)) + .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + } + + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) - executor = new Executor(executorId, hostname, env, isLocal = false) + executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -68,7 +94,6 @@ private[spark] class CoarseGrainedExecutorBackend( logError("Received LaunchTask command but executor was null") System.exit(1) } else { - val ser = env.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, @@ -83,23 +108,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - if (x.remoteAddress == driver.anchorPath.address) { - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) - } else { - logWarning(s"Received irrelevant DisassociatedEvent $x") - } - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -111,7 +141,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname: String, cores: Int, appId: String, - workerUrl: Option[String]) { + workerUrl: Option[String], + userClassPath: Seq[URL]) { SignalLogger.register(log) @@ -122,16 +153,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -145,6 +174,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } + if (driverConf.contains("spark.yarn.credentials.file")) { + logInfo("Will periodically update credentials from: " + + driverConf.get("spark.yarn.credentials.file")) + SparkHadoopUtil.get.startExecutorDelegationTokenRenewer(driverConf) + } + val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) @@ -152,34 +187,86 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() + SparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() } } def main(args: Array[String]) { - args.length match { - case x if x < 5 => - System.err.println( + var driverUrl: String = null + var executorId: String = null + var hostname: String = null + var cores: Int = 0 + var appId: String = null + var workerUrl: Option[String] = None + val userClassPath = new mutable.ListBuffer[URL]() + + var argv = args.toList + while (!argv.isEmpty) { + argv match { + case ("--driver-url") :: value :: tail => + driverUrl = value + argv = tail + case ("--executor-id") :: value :: tail => + executorId = value + argv = tail + case ("--hostname") :: value :: tail => + hostname = value + argv = tail + case ("--cores") :: value :: tail => + cores = value.toInt + argv = tail + case ("--app-id") :: value :: tail => + appId = value + argv = tail + case ("--worker-url") :: value :: tail => // Worker url is used in spark standalone mode to enforce fate-sharing with worker - "Usage: CoarseGrainedExecutorBackend " + - " [] ") - System.exit(1) + workerUrl = Some(value) + argv = tail + case ("--user-class-path") :: value :: tail => + userClassPath += new URL(value) + argv = tail + case Nil => + case tail => + // scalastyle:off println + System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println + printUsageAndExit() + } + } - // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) - // and CoarseMesosSchedulerBackend (for mesos mode). - case 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), None) - case x if x > 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5))) + if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || + appId == null) { + printUsageAndExit() } + + run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) } + + private def printUsageAndExit() = { + // scalastyle:off println + System.err.println( + """ + |"Usage: CoarseGrainedExecutorBackend [options] + | + | Options are: + | --driver-url + | --executor-id + | --hostname + | --cores + | --app-id + | --worker-url + | --user-class-path + |""".stripMargin) + // scalastyle:on println + System.exit(1) + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala new file mode 100644 index 000000000000..f47d7ef511da --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{TaskCommitDenied, TaskEndReason} + +/** + * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. + */ +private[spark] class CommitDeniedException( + msg: String, + jobID: Int, + splitID: Int, + attemptID: Int) + extends Exception(msg) { + + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) +} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 312bb3a1daaa..c3491bb8b1cf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,35 +17,38 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory +import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent._ +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} +import org.apache.spark.unsafe.memory.TaskMemoryManager +import org.apache.spark.util._ /** - * Spark executor used with Mesos, YARN, and the standalone scheduler. - * In coarse-grained mode, an existing actor system is provided. + * Spark executor, backed by a threadpool to run tasks. + * + * This can be used with Mesos, YARN, and the standalone scheduler. + * An internal RPC interface (at the moment Akka) is used for communication with the driver, + * except in the case of Mesos fine-grained mode. */ private[spark] class Executor( executorId: String, executorHostname: String, env: SparkEnv, + userClassPath: Seq[URL] = Nil, isLocal: Boolean = false) - extends Logging -{ + extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -58,8 +61,6 @@ private[spark] class Executor( private val conf = env.conf - @volatile private var isStopped = false - // No ip or host:port - just hostname Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -75,16 +76,21 @@ private[spark] class Executor( Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) } - val executorSource = new ExecutorSource(this, executorId) + // Start worker thread pool + private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val executorSource = new ExecutorSource(threadPool, executorId) if (!isLocal) { env.metricsSystem.registerSource(executorSource) env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) + + // Whether to load classes in user jars before those in Spark jars + private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -92,7 +98,7 @@ private[spark] class Executor( private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) // Set the classloader for serializer - env.serializer.setDefaultClassLoader(urlClassLoader) + env.serializer.setDefaultClassLoader(replClassLoader) // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. @@ -101,12 +107,12 @@ private[spark] class Executor( // Limit of bytes for total size of results (default is 1GB) private val maxResultSize = Utils.getMaxResultSize(conf) - // Start worker thread pool - val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") - // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + // Executor for the heartbeat task. + private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + startDriverHeartbeater() def launchTask( @@ -114,31 +120,35 @@ private[spark] class Executor( taskId: Long, attemptNumber: Int, taskName: String, - serializedTask: ByteBuffer) { + serializedTask: ByteBuffer): Unit = { val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean) { + def killTask(taskId: Long, interruptThread: Boolean): Unit = { val tr = runningTasks.get(taskId) if (tr != null) { tr.kill(interruptThread) } } - def stop() { + def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) - isStopped = true + env.rpcEnv.stop(executorEndpoint) + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() if (!isLocal) { env.stop() } } - private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + /** Returns the total amount of time this JVM process has spent in garbage collection. */ + private def computeTotalGcTime(): Long = { + ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum + } class TaskRunner( execBackend: ExecutorBackend, @@ -148,12 +158,19 @@ private[spark] class Executor( serializedTask: ByteBuffer) extends Runnable { + /** Whether this task has been killed. */ @volatile private var killed = false - @volatile var task: Task[Any] = _ - @volatile var attemptedTask: Option[Task[Any]] = None + + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ - def kill(interruptThread: Boolean) { + /** + * The task to run. This will be set in run() by deserializing the task binary coming + * from the driver. Once it is set, it will never be changed. + */ + @volatile var task: Task[Any] = _ + + def kill(interruptThread: Boolean): Unit = { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") killed = true if (task != null) { @@ -161,19 +178,21 @@ private[spark] class Executor( } } - override def run() { + override def run(): Unit = { + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var taskStart: Long = 0 - startGCTime = gcTime + startGCTime = computeTotalGcTime() try { val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. @@ -185,13 +204,30 @@ private[spark] class Executor( throw new TaskKilledException } - attemptedTask = Some(task) logDebug("Task " + taskId + "'s epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + var threwException = true + val (value, accumUpdates) = try { + val res = task.run( + taskAttemptId = taskId, + attemptNumber = attemptNumber, + metricsSystem = env.metricsSystem) + threwException = false + res + } finally { + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } + } val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. @@ -205,20 +241,23 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.setExecutorDeserializeTime(taskStart - deserializeStartTime) - m.setExecutorRunTime(taskFinish - taskStart) - m.setJvmGCTime(gcTime - startGCTime) + // Deserialization happens in two parts: first, we deserialize a Task object, which + // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. + m.setExecutorDeserializeTime( + (taskStart - deserializeStartTime) + task.executorDeserializeTime) + // We need to subtract Task.run()'s deserialization time to avoid double-counting + m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) + m.updateAccumulators() } - val accumUpdates = Accumulators.values - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val serializedResult = { + val serializedResult: ByteBuffer = { if (maxResultSize > 0 && resultSize > maxResultSize) { logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + @@ -240,44 +279,50 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => { + case ffe: FetchFailedException => val reason = ffe.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - case _: TaskKilledException | _: InterruptedException if task.killed => { + case _: TaskKilledException | _: InterruptedException if task.killed => logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - } - case t: Throwable => { + case cDE: CommitDeniedException => + val reason = cDE.toTaskEndReason + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + + case t: Throwable => // Attempt to exit cleanly by informing the driver of our failure. // If anything goes wrong (or this was a fatal exception), we will delegate to // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.setExecutorRunTime(serviceTime) - m.setJvmGCTime(gcTime - startGCTime) + val metrics: Option[TaskMetrics] = Option(task).flatMap { task => + task.metrics.map { m => + m.setExecutorRunTime(System.currentTimeMillis() - taskStart) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) + m.updateAccumulators() + m + } } - val reason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { SparkUncaughtExceptionHandler.uncaughtException(t) } - } + } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() - // Release memory used by this thread for accumulators - Accumulators.clear() runningTasks.remove(taskId) } } @@ -288,17 +333,23 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): MutableURLClassLoader = { + // Bootstrap the list of jars with the user class path. + val now = System.currentTimeMillis() + userClassPath.foreach { url => + currentJars(url.getPath().split("/").last) = now + } + val currentLoader = Utils.getContextOrSparkClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. - val urls = currentJars.keySet.map { uri => + val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL - }.toArray - val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) - userClassPathFirst match { - case true => new ChildExecutorURLClassLoader(urls, currentLoader) - case false => new ExecutorURLClassLoader(urls, currentLoader) + } + if (userClassPathFirst) { + new ChildFirstURLClassLoader(urls, currentLoader) + } else { + new MutableURLClassLoader(urls, currentLoader) } } @@ -310,14 +361,13 @@ private[spark] class Executor( val classUri = conf.get("spark.repl.class.uri", null) if (classUri != null) { logInfo("Using REPL class URI: " + classUri) - val userClassPathFirst: java.lang.Boolean = - conf.getBoolean("spark.files.userClassPathFirst", false) try { - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + val _userClassPathFirst: java.lang.Boolean = userClassPathFirst + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, userClassPathFirst) + constructor.newInstance(conf, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -340,82 +390,87 @@ private[spark] class Executor( for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentFiles(name) = timestamp } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, - env.securityManager, hadoopConf, timestamp, useCache = !isLocal) - currentJars(name) = timestamp - // Add it to our class loader + for ((name, timestamp) <- newJars) { val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + val currentTimeStamp = currentJars.get(name) + .orElse(currentJars.get(localName)) + .getOrElse(-1L) + if (currentTimeStamp < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + // Fetch file with useCache mode, close cache for local mode. + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) + currentJars(name) = timestamp + // Add it to our class loader + val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL + if (!urlClassLoader.getURLs().contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } } - def startDriverHeartbeater() { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) - - val t = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) - - while (!isStopped) { - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() - val curGCTime = gcTime - - for (taskRunner <- runningTasks.values()) { - if (taskRunner.attemptedTask.nonEmpty) { - Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics() - metrics.updateInputMetrics() - metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - - if (isLocal) { - // JobProgressListener will hold an reference of it during - // onExecutorMetricsUpdate(), then JobProgressListener can not see - // the changes of metrics any more, so make a deep copy of it - val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) - tasksMetrics += ((taskRunner.taskId, copiedMetrics)) - } else { - // It will be copied by serialization - tasksMetrics += ((taskRunner.taskId, metrics)) - } - } - } - } - - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) - try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) - if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") - env.blockManager.reregister() - } - } catch { - case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t) + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + + /** Reports heartbeat and metrics for active tasks to the driver. */ + private def reportHeartBeat(): Unit = { + // list of (task id, metrics) to send back to the driver + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + val curGCTime = computeTotalGcTime() + + for (taskRunner <- runningTasks.values().asScala) { + if (taskRunner.task != null) { + taskRunner.task.metrics.foreach { metrics => + metrics.updateShuffleReadMetrics() + metrics.updateInputMetrics() + metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + metrics.updateAccumulators() + + if (isLocal) { + // JobProgressListener will hold an reference of it during + // onExecutorMetricsUpdate(), then JobProgressListener can not see + // the changes of metrics any more, so make a deep copy of it + val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) + tasksMetrics += ((taskRunner.taskId, copiedMetrics)) + } else { + // It will be copied by serialization + tasksMetrics += ((taskRunner.taskId, metrics)) } - - Thread.sleep(interval) } } } - t.setDaemon(true) - t.setName("Driver Heartbeater") - t.start() + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + try { + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) + if (response.reregisterBlockManager) { + logInfo("Told to re-register on heartbeat") + env.blockManager.reregister() + } + } catch { + case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + } + } + + /** + * Schedules a task to report heartbeat and partial metrics for active tasks to driver. + */ + private def startDriverHeartbeater(): Unit = { + val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") + + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) + } + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala deleted file mode 100644 index 41925f7e97e8..000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} - -/** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * Actor that runs inside of executors to enable driver -> executor RPC. - */ -private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { - - override def receiveWithLogging = { - case TriggerThreadDump => - sender ! Utils.getThreadDump() - } - -} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala new file mode 100644 index 000000000000..cf362f846473 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) + } + +} + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index 52862ae0ca5e..ea36fb60bd54 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -33,11 +33,11 @@ object ExecutorExitCode { /** DiskStore failed to create a local temporary directory after many attempts. */ val DISK_STORE_FAILED_TO_CREATE_DIR = 53 - /** TachyonStore failed to initialize after many attempts. */ - val TACHYON_STORE_FAILED_TO_INITIALIZE = 54 + /** ExternalBlockStore failed to initialize after many attempts. */ + val EXTERNAL_BLOCK_STORE_FAILED_TO_INITIALIZE = 54 - /** TachyonStore failed to create a local temporary directory after many attempts. */ - val TACHYON_STORE_FAILED_TO_CREATE_DIR = 55 + /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ + val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 def explainExitCode(exitCode: Int): String = { exitCode match { @@ -46,9 +46,11 @@ object ExecutorExitCode { case OOM => "OutOfMemoryError" case DISK_STORE_FAILED_TO_CREATE_DIR => "Failed to create local directory (bad spark.local.dir?)" - case TACHYON_STORE_FAILED_TO_INITIALIZE => "TachyonStore failed to initialize." - case TACHYON_STORE_FAILED_TO_CREATE_DIR => - "TachyonStore failed to create a local temporary directory." + // TODO: replace external block store with concrete implementation name + case EXTERNAL_BLOCK_STORE_FAILED_TO_INITIALIZE => "ExternalBlockStore failed to initialize." + // TODO: replace external block store with concrete implementation name + case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => + "ExternalBlockStore failed to create a local temporary directory." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index c4d73622c472..d16f4a1fc4e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -17,16 +17,20 @@ package org.apache.spark.executor -import scala.collection.JavaConversions._ +import java.util.concurrent.ThreadPoolExecutor + +import scala.collection.JavaConverters._ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem import org.apache.spark.metrics.source.Source -private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source { +private[spark] +class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { + private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption + FileSystem.getAllStatistics.asScala.find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { @@ -41,23 +45,23 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getActiveCount() + override def getValue: Int = threadPool.getActiveCount() }) // Gauge for executor thread pool's approximate total number of tasks that have been completed metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] { - override def getValue: Long = executor.threadPool.getCompletedTaskCount() + override def getValue: Long = threadPool.getCompletedTaskCount() }) // Gauge for executor thread pool's current number of threads metricRegistry.register(MetricRegistry.name("threadpool", "currentPool_size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getPoolSize() + override def getValue: Int = threadPool.getPoolSize() }) // Gauge got executor thread pool's largest number of threads that have ever simultaneously // been in th pool metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getMaximumPoolSize() + override def getValue: Int = threadPool.getMaximumPoolSize() }) // Gauge for file system stats of this executor diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala deleted file mode 100644 index 218ed7b5d2d3..000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.{URLClassLoader, URL} - -import org.apache.spark.util.ParentClassLoader - -/** - * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. - * We also make changes so user classes can come before the default classes. - */ - -private[spark] trait MutableURLClassLoader extends ClassLoader { - def addURL(url: URL) - def getURLs: Array[URL] -} - -private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends MutableURLClassLoader { - - private object userClassLoader extends URLClassLoader(urls, null){ - override def addURL(url: URL) { - super.addURL(url) - } - override def findClass(name: String): Class[_] = { - super.findClass(name) - } - } - - private val parentClassLoader = new ParentClassLoader(parent) - - override def findClass(name: String): Class[_] = { - try { - userClassLoader.findClass(name) - } catch { - case e: ClassNotFoundException => { - parentClassLoader.loadClass(name) - } - } - } - - def addURL(url: URL) { - userClassLoader.addURL(url) - } - - def getURLs() = { - userClassLoader.getURLs() - } -} - -private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) with MutableURLClassLoader { - - override def addURL(url: URL) { - super.addURL(url) - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index cfd672e1d8a9..0474fd2ccc12 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} @@ -28,7 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} +import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData import org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -55,7 +55,7 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. - val cpusPerTask = executorInfo.getResourcesList + val cpusPerTask = executorInfo.getResourcesList.asScala .find(_.getName == "cpus") .map(_.getScalar.getValue.toInt) .getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 97912c68c598..42207a955359 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,14 +17,15 @@ package org.apache.spark.executor -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark.executor.DataReadMethod.DataReadMethod +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -44,29 +45,29 @@ class TaskMetrics extends Serializable { * Host's name the task runs on */ private var _hostname: String = _ - def hostname = _hostname + def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ - def executorDeserializeTime = _executorDeserializeTime + def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ - def executorRunTime = _executorRunTime + def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ private var _resultSize: Long = _ - def resultSize = _resultSize + def resultSize: Long = _resultSize private[spark] def setResultSize(value: Long) = _resultSize = value @@ -74,31 +75,31 @@ class TaskMetrics extends Serializable { * Amount of time the JVM spent in garbage collection while executing this task */ private var _jvmGCTime: Long = _ - def jvmGCTime = _jvmGCTime + def jvmGCTime: Long = _jvmGCTime private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value /** * Amount of time spent serializing the task result */ private var _resultSerializationTime: Long = _ - def resultSerializationTime = _resultSerializationTime + def resultSerializationTime: Long = _resultSerializationTime private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value /** * The number of in-memory bytes spilled by this task */ private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value + def memoryBytesSpilled: Long = _memoryBytesSpilled + private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value + private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value /** * The number of on-disk bytes spilled by this task */ private var _diskBytesSpilled: Long = _ - def diskBytesSpilled = _diskBytesSpilled - def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value + def diskBytesSpilled: Long = _diskBytesSpilled + private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -106,7 +107,7 @@ class TaskMetrics extends Serializable { */ private var _inputMetrics: Option[InputMetrics] = None - def inputMetrics = _inputMetrics + def inputMetrics: Option[InputMetrics] = _inputMetrics /** * This should only be used when recreating TaskMetrics, not when updating input metrics in @@ -128,7 +129,7 @@ class TaskMetrics extends Serializable { */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None - def shuffleReadMetrics = _shuffleReadMetrics + def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** * This should only be used when recreating TaskMetrics, not when updating read metrics in @@ -177,41 +178,78 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): - InputMetrics =synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { + synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } } } /** * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. */ - private[spark] def updateShuffleReadMetrics() = synchronized { - val merged = new ShuffleReadMetrics() - for (depMetrics <- depsShuffleReadMetrics) { - merged.incFetchWaitTime(depMetrics.fetchWaitTime) - merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) - merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) - merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + private[spark] def updateShuffleReadMetrics(): Unit = synchronized { + if (!depsShuffleReadMetrics.isEmpty) { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.incFetchWaitTime(depMetrics.fetchWaitTime) + merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) + merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) + merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + merged.incLocalBytesRead(depMetrics.localBytesRead) + merged.incRecordsRead(depMetrics.recordsRead) + } + _shuffleReadMetrics = Some(merged) } - _shuffleReadMetrics = Some(merged) } - private[spark] def updateInputMetrics() = synchronized { + private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + // Get the hostname from cached data, since hostname is the order of number of nodes in + // cluster, so using cached hostname will decrease the object number and alleviate the GC + // overhead. + _hostname = TaskMetrics.getCachedHostName(_hostname) + } + + private var _accumulatorUpdates: Map[Long, Any] = Map.empty + @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + + private[spark] def updateAccumulators(): Unit = synchronized { + _accumulatorUpdates = _accumulatorsUpdater() + } + + /** + * Return the latest updates of accumulators in this task. + */ + def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates + + private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { + _accumulatorsUpdater = accumulatorsUpdater + } } private[spark] object TaskMetrics { + private val hostNameCache = new ConcurrentHashMap[String, String]() + def empty: TaskMetrics = new TaskMetrics + + def getCachedHostName(host: String): String = { + val canonicalHost = hostNameCache.putIfAbsent(host, host) + if (canonicalHost != null) canonicalHost else host + } } /** @@ -242,27 +280,31 @@ object DataWriteMethod extends Enumeration with Serializable { @DeveloperApi case class InputMetrics(readMethod: DataReadMethod.Value) { - private val _bytesRead: AtomicLong = new AtomicLong() + /** + * This is volatile so that it is visible to the updater thread. + */ + @volatile @transient var bytesReadCallback: Option[() => Long] = None /** * Total bytes read. */ - def bytesRead: Long = _bytesRead.get() - @volatile @transient var bytesReadCallback: Option[() => Long] = None + private var _bytesRead: Long = _ + def bytesRead: Long = _bytesRead + def incBytesRead(bytes: Long): Unit = _bytesRead += bytes /** - * Adds additional bytes read for this read method. + * Total records read. */ - def addBytesRead(bytes: Long) = { - _bytesRead.addAndGet(bytes) - } + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. */ def updateBytesRead() { bytesReadCallback.foreach { c => - _bytesRead.set(c()) + _bytesRead = c() } } @@ -285,8 +327,15 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { * Total bytes written */ private var _bytesWritten: Long = _ - def bytesWritten = _bytesWritten - private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + def bytesWritten: Long = _bytesWritten + private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value + + /** + * Total records written + */ + private var _recordsWritten: Long = 0L + def recordsWritten: Long = _recordsWritten + private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value } /** @@ -299,18 +348,17 @@ class ShuffleReadMetrics extends Serializable { * Number of remote blocks fetched in this shuffle by this task */ private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched = _remoteBlocksFetched + def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def defRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + /** * Number of local blocks fetched in this shuffle by this task */ private var _localBlocksFetched: Int = _ - def localBlocksFetched = _localBlocksFetched + def localBlocksFetched: Int = _localBlocksFetched private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def defLocalBlocksFetched(value: Int) = _localBlocksFetched -= value - + private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value /** * Time the task spent waiting for remote shuffle blocks. This only includes the time @@ -318,22 +366,42 @@ class ShuffleReadMetrics extends Serializable { * still not finished processing block A, it is not considered to be blocking on block B. */ private var _fetchWaitTime: Long = _ - def fetchWaitTime = _fetchWaitTime + def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ private var _remoteBytesRead: Long = _ - def remoteBytesRead = _remoteBytesRead + def remoteBytesRead: Long = _remoteBytesRead private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + private var _localBytesRead: Long = _ + def localBytesRead: Long = _localBytesRead + private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). + */ + def totalBytesRead: Long = _remoteBytesRead + _localBytesRead + /** * Number of blocks fetched in this shuffle by this task (remote or local) */ - def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched + + /** + * Total number of records read from the shuffle by this task + */ + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + private[spark] def incRecordsRead(value: Long) = _recordsRead += value + private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } /** @@ -346,17 +414,24 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for the shuffle by this task */ @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten = _shuffleBytesWritten + def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime= _shuffleWriteTime + def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** + * Total number of records written to the shuffle by this task + */ + @volatile private var _shuffleRecordsWritten: Long = _ + def shuffleRecordsWritten: Long = _shuffleRecordsWritten + private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value + private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value + private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa..532850dd5771 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 593a62b3e3b3..a5ad47293f1c 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.input import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -44,12 +44,9 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum - - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } @@ -73,16 +70,16 @@ private[spark] abstract class StreamBasedRecordReader[T]( private var key = "" private var value: T = null.asInstanceOf[T] - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} - override def close() = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: T = value - override def nextKeyValue = { + override def nextKeyValue: Boolean = { if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) @@ -119,7 +116,8 @@ private[spark] class StreamRecordReader( * The format for the PortableDataStream files */ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { - override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) + : CombineFileRecordReader[String, PortableDataStream] = { new CombineFileRecordReader[String, PortableDataStream]( split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) } @@ -204,7 +202,7 @@ class PortableDataStream( /** * Close the file (if it is currently open) */ - def close() = { + def close(): Unit = { if (isOpen) { try { fileIn.close() diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index aaef7c74eea3..1ba34a11414a 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -17,7 +17,7 @@ package org.apache.spark.input -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit @@ -52,10 +52,8 @@ private[spark] class WholeTextFileInputFormat * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index f856890d279f..607d5a321efc 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,7 +17,7 @@ package org.apache.spark.io -import java.io.{InputStream, OutputStream} +import java.io.{IOException, InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} @@ -26,7 +26,6 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils -import org.apache.spark.Logging /** * :: DeveloperApi :: @@ -53,15 +52,18 @@ private[spark] object CompressionCodec { "lzf" -> classOf[LZFCompressionCodec].getName, "snappy" -> classOf[SnappyCompressionCodec].getName) + def getCodecName(conf: SparkConf): String = { + conf.get(configKey, DEFAULT_COMPRESSION_CODEC) + } + def createCodec(conf: SparkConf): CompressionCodec = { - createCodec(conf, conf.get(configKey, DEFAULT_COMPRESSION_CODEC)) + createCodec(conf, getCodecName(conf)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) val codec = try { - val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) - .getConstructor(classOf[SparkConf]) + val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { case e: ClassNotFoundException => None @@ -71,6 +73,20 @@ private[spark] object CompressionCodec { s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) } + /** + * Return the short version of the given codec name. + * If it is already a short name, just return it. + */ + def getShortName(codecName: String): String = { + if (shortCompressionCodecNames.contains(codecName)) { + codecName + } else { + shortCompressionCodecNames + .collectFirst { case (k, v) if v == codecName => k } + .getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") } + } + } + val FALLBACK_COMPRESSION_CODEC = "lzf" val DEFAULT_COMPRESSION_CODEC = "snappy" val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq @@ -80,7 +96,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -90,7 +106,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -120,7 +136,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -136,9 +152,54 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) - new SnappyOutputStream(s, blockSize) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt + new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize)) } override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } + +/** + * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version + * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. + */ +private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends OutputStream { + + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte]): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b, off, len) + } + + override def flush(): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.flush() + } + + override def close(): Unit = { + if (!closed) { + closed = true + os.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/launcher/SparkSubmitArgumentsParser.scala b/core/src/main/scala/org/apache/spark/launcher/SparkSubmitArgumentsParser.scala new file mode 100644 index 000000000000..a83501253105 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/SparkSubmitArgumentsParser.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +/** + * This class makes SparkSubmitOptionParser visible for Spark code outside of the `launcher` + * package, since Java doesn't have a feature similar to `private[spark]`, and we don't want + * that class to be public. + */ +private[spark] abstract class SparkSubmitArgumentsParser extends SparkSubmitOptionParser diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala new file mode 100644 index 000000000000..0c096656f923 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.io.File +import java.util.{HashMap => JHashMap, List => JList, Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.spark.deploy.Command + +/** + * This class is used by CommandUtils. It uses some package-private APIs in SparkLauncher, and since + * Java doesn't have a feature similar to `private[spark]`, and we don't want that class to be + * public, needs to live in the same package as the rest of the library. + */ +private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, command: Command) + extends AbstractCommandBuilder { + + childEnv.putAll(command.environment.asJava) + childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sparkHome) + + override def buildCommand(env: JMap[String, String]): JList[String] = { + val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) + cmd.add(s"-Xms${memoryMb}M") + cmd.add(s"-Xmx${memoryMb}M") + command.javaOpts.foreach(cmd.add) + addPermGenSizeOpt(cmd) + addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) + cmd + } + + def buildCommand(): JList[String] = buildCommand(new JHashMap[String, String]()) + +} diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 21b782edd2a9..f405b732e472 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -17,9 +17,17 @@ package org.apache.spark.mapred +import java.io.IOException import java.lang.reflect.Modifier -import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.util.{Utils => SparkUtils} private[spark] trait SparkHadoopMapRedUtil { @@ -52,16 +60,99 @@ trait SparkHadoopMapRedUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + SparkUtils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + SparkUtils.classForName(second) + } + } +} + +object SparkHadoopMapRedUtil extends Logging { + /** + * Commits a task output. Before committing the task output, we need to know whether some other + * task attempt might be racing to commit the same output partition. Therefore, coordinate with + * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for + * details). + * + * Output commit coordinator is only contacted when the following two configurations are both set + * to `true`: + * + * - `spark.speculation` + * - `spark.hadoop.outputCommitCoordination.enabled` + */ + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + jobId: Int, + splitId: Int, + attemptId: Int): Unit = { + + val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) + + // Called after we have decided to commit + def performCommit(): Unit = { + try { + committer.commitTask(mrTaskContext) + logInfo(s"$mrTaskAttemptID: Committed") + } catch { + case cause: IOException => + logError(s"Error committing the output of task: $mrTaskAttemptID", cause) + committer.abortTask(mrTaskContext) + throw cause + } } + + // First, check whether the task's output has already been committed by some other attempt + if (committer.needsTaskCommit(mrTaskContext)) { + val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + + if (canCommit) { + performCommit() + } else { + val message = + s"$mrTaskAttemptID: Not committed because the driver did not authorize commit" + logInfo(message) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + committer.abortTask(mrTaskContext) + throw new CommitDeniedException(message, jobId, splitId, attemptId) + } + } else { + // Speculation is disabled or a user has chosen to manually bypass the commit coordination + performCommit() + } + } else { + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") + } + } + + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + sparkTaskContext: TaskContext): Unit = { + commitTask( + committer, + mrTaskContext, + sparkTaskContext.stageId(), + sparkTaskContext.partitionId(), + sparkTaskContext.attemptNumber()) } } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 3340673f9115..943ebcb7bd0a 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} +import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { @@ -45,8 +46,8 @@ trait SparkHadoopMapReduceUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") + attemptId: Int): TaskAttemptID = { + val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap // (not available in YARN) @@ -57,10 +58,10 @@ trait SparkHadoopMapReduceUtil { } catch { case exc: NoSuchMethodException => { // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if(isMap) "MAP" else "REDUCE") + taskTypeClass, if (isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), @@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + Utils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + Utils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 1b7a5d1f1980..dd2d325d8703 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -20,20 +20,21 @@ package org.apache.spark.metrics import java.io.{FileInputStream, InputStream} import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { +private[spark] class MetricsConfig(conf: SparkConf) extends Logging { - val DEFAULT_PREFIX = "*" - val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r - val METRICS_CONF = "metrics.properties" + private val DEFAULT_PREFIX = "*" + private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r + private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties" - val properties = new Properties() - var propertyCategories: mutable.HashMap[String, Properties] = null + private[metrics] val properties = new Properties() + private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null private def setDefaultProperties(prop: Properties) { prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet") @@ -46,44 +47,32 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi // Add default properties in case there's no properties file setDefaultProperties(properties) - // If spark.metrics.conf is not set, try to get file in class path - var is: InputStream = null - try { - is = configFile match { - case Some(f) => new FileInputStream(f) - case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF) - } + loadPropertiesFromFile(conf.getOption("spark.metrics.conf")) - if (is != null) { - properties.load(is) - } - } catch { - case e: Exception => logError("Error loading configure file", e) - } finally { - if (is != null) is.close() + // Also look for the properties in provided Spark configuration + val prefix = "spark.metrics.conf." + conf.getAll.foreach { + case (k, v) if k.startsWith(prefix) => + properties.setProperty(k.substring(prefix.length()), v) + case _ => } propertyCategories = subProperties(properties, INSTANCE_REGEX) if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) + val defaultProperty = propertyCategories(DEFAULT_PREFIX).asScala + for((inst, prop) <- propertyCategories if (inst != DEFAULT_PREFIX); + (k, v) <- defaultProperty if (prop.get(k) == null)) { + prop.put(k, v) } } } def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1).isDefined) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) + prop.asScala.foreach { kv => + if (regex.findPrefixOf(kv._1.toString).isDefined) { + val regex(prefix, suffix) = kv._1.toString + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2.toString) } } subProperties @@ -95,5 +84,31 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) } } -} + /** + * Loads configuration from a config file. If no config file is provided, try to get file + * in class path. + */ + private[this] def loadPropertiesFromFile(path: Option[String]): Unit = { + var is: InputStream = null + try { + is = path match { + case Some(f) => new FileInputStream(f) + case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME) + } + + if (is != null) { + properties.load(is) + } + } catch { + case e: Exception => + val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME) + logError(s"Error loading configuration file $file", e) + } finally { + if (is != null) { + is.close() + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 83e8eb71260e..4517f465ebd3 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,9 +20,12 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit +import org.apache.spark.util.Utils + import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} @@ -69,8 +72,7 @@ private[spark] class MetricsSystem private ( securityMgr: SecurityManager) extends Logging { - private[this] val confFile = conf.get("spark.metrics.conf", null) - private[this] val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val metricsConfig = new MetricsConfig(conf) private val sinks = new mutable.ArrayBuffer[Sink] private val sources = new mutable.ArrayBuffer[Source] @@ -84,7 +86,7 @@ private[spark] class MetricsSystem private ( /** * Get any UI handlers used by this metrics system; can only be called after start(). */ - def getServletHandlers = { + def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") metricsServlet.map(_.getHandlers).getOrElse(Array()) } @@ -140,6 +142,9 @@ private[spark] class MetricsSystem private ( } else { defaultName } } + def getSourcesByName(sourceName: String): Seq[Source] = + sources.filter(_.sourceName == sourceName) + def registerSource(source: Source) { sources += source try { @@ -166,7 +171,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Class.forName(classPath).newInstance() + val source = Utils.classForName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) @@ -182,7 +187,7 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Class.forName(classPath) + val sink = Utils.classForName(classPath) .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { @@ -191,7 +196,10 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e) + case e: Exception => { + logError("Sink class " + classPath + " cannot be instantialized") + throw e + } } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 2f65bc8b4660..0c2e212a3307 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -30,8 +30,12 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils._ -private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class MetricsServlet( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -45,10 +49,12 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[ServletContextHandler]( - createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) - ) + def getHandlers: Array[ServletContextHandler] = { + Array[ServletContextHandler]( + createServletHandler(servletPath, + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + ) + } def getMetricsSnapshot(request: HttpServletRequest): String = { mapper.writeValueAsString(registry) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 0d83d8c425ca..9fad4e7deacb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { - def start: Unit - def stop: Unit + def start(): Unit + def stop(): Unit def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala new file mode 100644 index 000000000000..11dfcfe2f04e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import com.codahale.metrics.{Slf4jReporter, MetricRegistry} + +import org.apache.spark.SecurityManager +import org.apache.spark.metrics.MetricsSystem + +private[spark] class Slf4jSink( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SLF4J_DEFAULT_PERIOD = 10 + val SLF4J_DEFAULT_UNIT = "SECONDS" + + val SLF4J_KEY_PERIOD = "period" + val SLF4J_KEY_UNIT = "unit" + + val pollPeriod = Option(property.getProperty(SLF4J_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => SLF4J_DEFAULT_PERIOD + } + + val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) + } + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val reporter: Slf4jReporter = Slf4jReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .build() + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } + + override def report() { + reporter.report() + } +} + diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99e..670e68366332 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. */ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b089da8596e2..7c170a742fb6 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager @@ -55,7 +55,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator) + val streamId = streamManager.registerStream(blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 3f0950dae1f2..ff8aae9ebe9f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,14 +17,14 @@ package org.apache.spark.network.netty -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} -import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock @@ -49,22 +49,32 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { - val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - if (!authEnabled) { - (nettyRpcHandler, None) - } else { - (new SaslRpcHandler(nettyRpcHandler, securityManager), - Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) - } + val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + var serverBootstrap: Option[TransportServerBootstrap] = None + var clientBootstrap: Option[TransportClientBootstrap] = None + if (authEnabled) { + serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager, + securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(bootstrap.toList) - server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0)) + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) + server = createServer(serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) } + /** Creates and binds the TransportServer, possibly trying multiple ports. */ + private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { + def startService(port: Int): (TransportServer, Int) = { + val server = transportContext.createServer(port, bootstraps.asJava) + (server, server.getPort) + } + + val portToTry = conf.getInt("spark.blockManager.port", 0) + Utils.startServiceOnPort(portToTry, startService, conf, getClass.getName)._1 + } + override def fetchBlocks( host: String, port: Int, diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index b573f1a8a5fc..79cb0640c867 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -57,16 +57,6 @@ private[nio] class BlockMessage() { } def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ typ = buffer.getInt() val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -110,7 +100,7 @@ private[nio] class BlockMessage() { def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data - def getLevel: StorageLevel = level + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -138,24 +128,12 @@ private[nio] class BlockMessage() { buffers += data } - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index a1a2c00ed154..f1c9ea8b64ca 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -32,27 +32,17 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int): BlockMessage = blockMessages(i) - def iterator = blockMessages.iterator + def iterator: Iterator[BlockMessage] = blockMessages.iterator - def length = blockMessages.length + def length: Int = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis val newBlockMessages = new ArrayBuffer[BlockMessage]() val buffer = bufferMessage.buffers(0) buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ while (buffer.remaining() > 0) { val size = buffer.getInt() logDebug("Creating block message of size " + size + " bytes") @@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } } -private[nio] object BlockMessageArray { +private[nio] object BlockMessageArray extends Logging { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -114,8 +92,8 @@ private[nio] object BlockMessageArray { val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear + val buffer = ByteBuffer.allocate(100) + buffer.clear() BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, StorageLevel.MEMORY_ONLY_SER)) } else { @@ -123,10 +101,10 @@ private[nio] object BlockMessageArray { } } val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") + logDebug("Block message array created") val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") + logDebug("Converted to buffer message") val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) @@ -138,10 +116,11 @@ private[nio] object BlockMessageArray { }) newBuffer.flip val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) + logDebug("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") + logDebug("Converted back to block message array") + // scalastyle:off println newBlockMessageArray.foreach(blockMessage => { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { @@ -154,6 +133,7 @@ private[nio] object BlockMessageArray { } } }) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index 3b245c5c7a4f..9a9e22b0c236 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -31,9 +31,9 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val initialSize = currentSize() var gotChunkForSendingOnce = false - def size = initialSize + def size: Int = initialSize - def currentSize() = { + def currentSize(): Int = { if (buffers == null || buffers.isEmpty) { 0 } else { @@ -100,11 +100,11 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: buffers.foreach(_.flip) } - def hasAckId() = (ackId != 0) + def hasAckId(): Boolean = ackId != 0 - def isCompletelyReceived() = !buffers(0).hasRemaining + def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - override def toString = { + override def toString: String = { if (hasAckId) { "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index c2d9578be7eb..8d9ebadaf79d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,7 +23,7 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -101,9 +101,11 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, socketRemoteConnectionManagerId } - def key() = channel.keyFor(selector) + def key(): SelectionKey = channel.keyFor(selector) - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + def getRemoteAddress(): InetSocketAddress = { + channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + } // Returns whether we have to register for further reads or not. def read(): Boolean = { @@ -143,7 +145,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, } def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks foreach { + onExceptionCallbacks.asScala.foreach { callback => try { callback(this, e) @@ -179,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { @@ -280,7 +282,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, /* channel.socket.setSendBufferSize(256 * 1024) */ - override def getRemoteAddress() = address + override def getRemoteAddress(): InetSocketAddress = address val DEFAULT_INTEREST = SelectionKey.OP_READ @@ -324,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index 764dc5e5503e..b3b281ff465f 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -18,7 +18,9 @@ package org.apache.spark.network.nio private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString: String = { + connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + } } private[nio] object ConnectionId { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index ee22c6656e69..914391879038 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -36,7 +36,7 @@ import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import scala.util.Try import scala.util.control.NonFatal @@ -79,17 +79,18 @@ private[nio] class ConnectionManager( private val selector = SelectorProvider.provider.openSelector() private val ackTimeoutMonitor = - new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) + new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = - conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120)) + conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", + conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. - // + // // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is + // + // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" // parameter is necessary. private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) @@ -101,7 +102,7 @@ private[nio] class ConnectionManager( handlerThreadCount, conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) { + ThreadUtils.namedThreadFactory("handle-message-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -116,7 +117,7 @@ private[nio] class ConnectionManager( ioThreadCount, conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) { + ThreadUtils.namedThreadFactory("handle-read-write-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -133,7 +134,7 @@ private[nio] class ConnectionManager( connectThreadCount, conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) { + ThreadUtils.namedThreadFactory("handle-connect-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -159,7 +160,7 @@ private[nio] class ConnectionManager( private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("Connection manager future execution context")) + ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) @volatile private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null @@ -187,8 +188,9 @@ private[nio] class ConnectionManager( private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + @volatile private var isActive = true private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() + override def run(): Unit = ConnectionManager.this.run() } selectorThread.setDaemon(true) // start this thread last, since it invokes run(), which accesses members above @@ -341,7 +343,7 @@ private[nio] class ConnectionManager( def run() { try { - while(!selectorThread.isInterrupted) { + while (isActive) { while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) @@ -397,7 +399,7 @@ private[nio] class ConnectionManager( } catch { // Explicitly only dealing with CancelledKeyException here since other exceptions // should be dealt with differently. - case e: CancelledKeyException => { + case e: CancelledKeyException => // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() @@ -419,8 +421,11 @@ private[nio] class ConnectionManager( } } } - } - 0 + 0 + + case e: ClosedSelectorException => + logDebug("Failed select() as selector is closed.", e) + return } if (selectedKeysCount == 0) { @@ -630,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } @@ -651,7 +655,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -795,7 +799,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() @@ -987,10 +991,11 @@ private[nio] class ConnectionManager( } def stop() { + isActive = false ackTimeoutMonitor.stop() + selector.close() selectorThread.interrupt() selectorThread.join() - selector.close() val connections = connectionsByKey.values connections.foreach(_.close()) if (connectionsByKey.size != 0) { @@ -1011,7 +1016,9 @@ private[spark] object ConnectionManager { val conf = new SparkConf val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // scalastyle:off println println("Received [" + msg + "] from [" + id + "]") + // scalastyle:on println None }) @@ -1028,6 +1035,7 @@ private[spark] object ConnectionManager { System.gc() } + // scalastyle:off println def testSequentialSending(manager: ConnectionManager) { println("--------------------------") println("Sequential Sending") @@ -1145,4 +1153,5 @@ private[spark] object ConnectionManager { println() } } + // scalastyle:on println } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index cbb37ec5ced1..1cd13d887c6f 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -26,7 +26,7 @@ private[nio] case class ConnectionManagerId(host: String, port: Int) { Utils.checkHost(host) assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) + def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index fb4a979b824c..85d2fe2bf9c2 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -42,7 +42,9 @@ private[nio] abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString: String = { + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + } } @@ -51,7 +53,7 @@ private[nio] object Message { var lastId = 1 - def getNewId() = synchronized { + def getNewId(): Int = synchronized { lastId += 1 if (lastId == 0) { lastId += 1 diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index 278c5ac356ef..a4568e849fa1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining + val size: Int = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { + lazy val buffers: ArrayBuffer[ByteBuffer] = { val ab = new ArrayBuffer[ByteBuffer]() ab += header.buffer if (buffer != null) { @@ -35,7 +35,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { ab } - override def toString = { + override def toString: String = { "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index 6e20f291c5ce..7b3da4bb9d5e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -50,8 +50,10 @@ private[nio] class MessageChunkHeader( flip.asInstanceOf[ByteBuffer] } - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + override def toString: String = { + "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 747a2088a725..232c552f9865 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -75,7 +75,7 @@ private[nio] class SecurityMessage extends Logging { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - connectionId = idBuilder.toString() + connectionId = idBuilder.toString() val tokenLength = buffer.getInt() token = new Array[Byte](tokenLength) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 5ad73c3d27f4..8ae76c5f72f2 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -27,8 +27,7 @@ package org.apache * contains operations available only on RDDs of Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can * be saved as SequenceFiles. These operations are automatically available on any RDD of the right - * type (e.g. RDD[(Int, Int)] through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * type (e.g. RDD[(Int, Int)] through implicit conversions. * * Java programmers should reference the [[org.apache.spark.api.java]] package * for Spark programming APIs in Java. @@ -44,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.3.0-SNAPSHOT" + val SPARK_VERSION = "1.5.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 3ef3cc219dec..5afce75680f9 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag @@ -32,12 +32,12 @@ import org.apache.spark.util.collection.OpenHashMap * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { + extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OpenHashMap[T,Long]() // Sum of counts for each key + var sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) @@ -48,9 +48,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf if (outputsMerged == totalOutputs) { val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => - result(key) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -64,9 +64,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf val stdev = math.sqrt(variance) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(key) = new BoundedDouble(mean, confidence, low, high) + result.put(key, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala index af26c3d59ac0..a16404068480 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub while (iter.hasNext) { val entry = iter.next() val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -72,9 +72,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub val confFactor = studentTCacher.get(counter.count) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala index 442fb86227d8..54a1beab3514 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl while (iter.hasNext) { val entry = iter.next() val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -80,9 +80,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl val confFactor = studentTCacher.get(counter.count) val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index cadd0c7ed19b..53c4b32c95ab 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -99,7 +99,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { case None => "(partial: " + initialValue + ")" } } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) + def getFinalValueInternal(): Option[T] = PartialResult.this.getFinalValueInternal().map(f) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac06..ca1eb1f4e4a9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,8 +19,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.util.ThreadUtils + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.ExecutionContext import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -33,7 +35,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Returns a future for counting the number of elements in the RDD. */ - def countAsync(): FutureAction[Long] = { + def countAsync(): FutureAction[Long] = self.withScope { val totalCount = new AtomicLong self.context.submitJob( self, @@ -45,7 +47,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -53,19 +55,21 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Returns a future for retrieving all elements of this RDD. */ - def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + def collectAsync(): FutureAction[Seq[T]] = self.withScope { + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } /** * Returns a future for retrieving the first num elements of the RDD. */ - def takeAsync(num: Int): FutureAction[Seq[T]] = { + def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { val f = new ComplexFutureAction[Seq[T]] f.run { + // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which + // is a cached thread pool. val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 @@ -81,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } @@ -101,7 +105,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi partsScanned += numPartsToTry } results.toSeq - } + }(AsyncRDDActions.futureExecutionContext) f } @@ -109,17 +113,22 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi /** * Applies a function f to all elements of this RDD. */ - def foreachAsync(f: T => Unit): FutureAction[Unit] = { + def foreachAsync(f: T => Unit): FutureAction[Unit] = self.withScope { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } /** * Applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = self.withScope { + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } + +private object AsyncRDDActions { + val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128)) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc..922030263756 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -31,12 +31,12 @@ private[spark] class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { - @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) @volatile private var _isValid = true override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } @@ -54,7 +54,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPreferredLocations(split: Partition): Seq[String] = { assertValid() - locations_(split.asInstanceOf[BlockRDDPartition].blockId) + _locations(split.asInstanceOf[BlockRDDPartition].blockId) } /** @@ -79,14 +79,14 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds /** Check if this BlockRDD is valid. If not valid, exception is thrown. */ private[spark] def assertValid() { - if (!_isValid) { + if (!isValid) { throw new SparkException( "Attempted to use %s after its blocks have been removed!".format(toString)) } } protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = { - locations_ + _locations } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 1cbd684224b7..c1d697178757 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct } - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 1c13e2c37284..72fe215dae73 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -17,154 +17,31 @@ package org.apache.spark.rdd -import java.io.IOException - import scala.reflect.ClassTag -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path - -import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Partition, SparkContext, TaskContext} -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} +/** + * An RDD partition used to recover checkpointed data. + */ +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** - * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + * An RDD that recovers checkpointed data from storage. */ -private[spark] -class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) +private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) - - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) - - override def getPartitions: Array[Partition] = { - val cpath = new Path(checkpointPath) - val numPartitions = - // listStatus can throw exception if path does not exist. - if (fs.exists(cpath)) { - val dirContents = fs.listStatus(cpath).map(_.getPath) - val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size - if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } - numPart - } else 0 - - Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) - } - - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - - override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath, - CheckpointRDD.splitIdToFile(split.index))) - val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") - } - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, broadcastedConf, context) - } - - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } -} - -private[spark] object CheckpointRDD extends Logging { - def splitIdToFile(splitId: Int): String = { - "part-%05d".format(splitId) - } - - def writeToFile[T: ClassTag]( - path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], - blockSize: Int = -1 - )(ctx: TaskContext, iterator: Iterator[T]) { - val env = SparkEnv.get - val outputDir = new Path(path) - val fs = outputDir.getFileSystem(broadcastedConf.value.value) - - val finalOutputName = splitIdToFile(ctx.partitionId) - val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = - new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber) - - if (fs.exists(tempOutputPath)) { - throw new IOException("Checkpoint failed: temporary path " + - tempOutputPath + " already exists") - } - val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - - val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) - } else { - // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) - } - val serializer = env.serializer.newInstance() - val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() - - if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.exists(finalOutputPath)) { - logInfo("Deleting tempOutputPath " + tempOutputPath) - fs.delete(tempOutputPath, false) - throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptNumber + " and final output path does not exist") - } else { - // Some other copy of this task must've finished before us and renamed it - logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") - fs.delete(tempOutputPath, false) - } - } - } - - def readFromFile[T]( - path: Path, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], - context: TaskContext - ): Iterator[T] = { - val env = SparkEnv.get - val fs = path.getFileSystem(broadcastedConf.value.value) - val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) - val serializer = env.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => deserializeStream.close()) - - deserializeStream.asIterator.asInstanceOf[Iterator[T]] - } + // CheckpointRDD should not be checkpointed again + override def doCheckpoint(): Unit = { } + override def checkpoint(): Unit = { } + override def localCheckpoint(): this.type = this - // Test whether CheckpointRDD generate expected number of partitions despite - // each split file having multiple blocks. This needs to be run on a - // cluster (mesos or standalone) using HDFS. - def main(args: Array[String]) { - import org.apache.spark._ + // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the + // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods. + // scalastyle:off + protected override def getPartitions: Array[Partition] = ??? + override def compute(p: Partition, tc: TaskContext): Iterator[T] = ??? + // scalastyle:on - val Array(cluster, hdfsPath) = args - val env = SparkEnv.get - val sc = new SparkContext(cluster, "CheckpointRDD Test") - val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) - val path = new Path(hdfsPath, "temp") - val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) - val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) - sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) - val cpRDD = new CheckpointRDD[Int](sc, path.toString) - assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") - assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") - fs.delete(path, true) - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f..9c617fc719cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,21 +23,21 @@ import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} -import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleHandle - -private[spark] sealed trait CoGroupSplitDep extends Serializable +/** The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. */ private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, + @transient rdd: RDD[_], + @transient splitIndex: Int, var split: Partition - ) extends CoGroupSplitDep { + ) extends Serializable { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -47,9 +47,16 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep - -private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +/** + * Stores information about the narrow dependencies used by a CoGroupedRdd. + * + * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one + * dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing + * the partition for that dependency) at the corresponding index. The size of + * narrowDeps should always be equal to the number of parents. + */ +private[spark] class CoGroupPartition( + idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -86,28 +93,29 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: Product2[K, _]] => + rdds.map { rdd: RDD[_] => if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner]( + rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer) } } } override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = dependencies.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + for ((dep, depNum) <- dependencies.zipWithIndex) dep match { + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked => + val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent - val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(handle) => + case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle val it = SparkEnv.get.shuffleManager - .getReader(handle, split.index, split.index + 1, context) + .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) .read() rddIterators += ((it, depNum)) } @@ -159,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index b073eba8a157..90d9735cb3f6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition( * the preferred location of each new partition overlaps with as many preferred locations of its * parent partitions * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the coalesced RDD + * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive) * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ private[spark] class CoalescedRDD[T: ClassTag]( @@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag]( balanceSlack: Double = 0.10) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies + require(maxPartitions > 0 || maxPartitions == prev.partitions.length, + s"Number of partitions ($maxPartitions) must be positive.") + override def getPartitions: Array[Partition] = { val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) @@ -166,7 +169,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD @@ -186,7 +189,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: override val isEmpty = !it.hasNext // initializes/resets to start iterating from the beginning - def resetIterator() = { + def resetIterator(): Iterator[(String, Partition)] = { val iterators = (0 to 2).map( x => prev.partitions.iterator.flatMap(p => { if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None @@ -196,10 +199,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: } // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } + override def hasNext: Boolean = { !isEmpty } // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { + override def next(): (String, Partition) = { if (it.hasNext) { it.next() } else { @@ -237,7 +240,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: val rotIt = new LocationIterator(prev) // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { + if (!rotIt.hasNext) { (1 to targetLen).foreach(x => groupArr += PartitionGroup()) return } @@ -310,11 +313,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def throwBalls() { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { + for ((p, i) <- prev.partitions.zipWithIndex) { groupArr(i).arr += p } } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { + for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } @@ -343,7 +346,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: private case class PartitionGroup(prefLoc: Option[String] = None) { var arr = mutable.ArrayBuffer[Partition]() - def size = arr.size + def size: Int = arr.size } private object PartitionGroup { diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index e66f83bb34e3..926bce6f15a2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -30,47 +30,59 @@ import org.apache.spark.util.StatCounter */ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Add up the elements in this RDD. */ - def sum(): Double = { - self.reduce(_ + _) + def sum(): Double = self.withScope { + self.fold(0.0)(_ + _) } /** * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and * count of the RDD's elements in one operation. */ - def stats(): StatCounter = { + def stats(): StatCounter = self.withScope { self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) } /** Compute the mean of this RDD's elements. */ - def mean(): Double = stats().mean + def mean(): Double = self.withScope { + stats().mean + } /** Compute the variance of this RDD's elements. */ - def variance(): Double = stats().variance + def variance(): Double = self.withScope { + stats().variance + } /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = stats().stdev + def stdev(): Double = self.withScope { + stats().stdev + } /** * Compute the sample standard deviation of this RDD's elements (which corrects for bias in * estimating the standard deviation by dividing by N-1 instead of N). */ - def sampleStdev(): Double = stats().sampleStdev + def sampleStdev(): Double = self.withScope { + stats().sampleStdev + } /** * Compute the sample variance of this RDD's elements (which corrects for bias in * estimating the variance by dividing by N-1 instead of N). */ - def sampleVariance(): Double = stats().sampleVariance + def sampleVariance(): Double = self.withScope { + stats().sampleVariance + } /** * :: Experimental :: * Approximate operation to return the mean within a timeout. */ @Experimental - def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + def meanApprox( + timeout: Long, + confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -79,9 +91,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * Approximate operation to return the sum within a timeout. */ @Experimental - def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + def sumApprox( + timeout: Long, + confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -93,7 +107,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope { // Scala's built-in range has issues. See #SI-8782 def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { val span = max - min @@ -140,7 +154,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * the maximum value of the last position and all NaN entries will be counted * in that bucket. */ - def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { + def histogram( + buckets: Array[Double], + evenBuckets: Boolean = false): Array[Long] = self.withScope { if (buckets.length < 2) { throw new IllegalArgumentException("buckets array must have at least two elements") } @@ -191,29 +207,34 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } - self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + if (self.partitions.length == 0) { + new Array[Long](buckets.length - 1) + } else { + // reduce() requires a non-empty RDD. This works because the mapPartitions will make + // non-empty partitions out of empty ones. But it doesn't handle the no-partitions case, + // which is below + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 89adddcf0ac3..e1f8719eead0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -99,8 +99,8 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - sc: SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + @transient sc: SparkContext, + broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -108,6 +108,10 @@ class HadoopRDD[K, V]( minPartitions: Int) extends RDD[(K, V)](sc, Nil) with Logging { + if (initLocalJobConfFuncOpt.isDefined) { + sc.clean(initLocalJobConfFuncOpt.get) + } + def this( sc: SparkContext, conf: JobConf, @@ -117,8 +121,8 @@ class HadoopRDD[K, V]( minPartitions: Int) = { this( sc, - sc.broadcast(new SerializableWritable(conf)) - .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + sc.broadcast(new SerializableConfiguration(conf)) + .asInstanceOf[Broadcast[SerializableConfiguration]], None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, @@ -215,8 +219,7 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -240,14 +243,16 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - override def getNext() = { + override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { case eof: EOFException => finished = true } - + if (!finished) { + inputMetrics.incRecordsRead(1) + } (key, value) } @@ -261,7 +266,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) @@ -269,7 +274,7 @@ class HadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } @@ -335,11 +340,11 @@ private[spark] object HadoopRDD extends Logging { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - def putCachedMetadata(key: String, value: Any) = + private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) /** Add Hadoop configuration specific to a single partition and attempt. */ @@ -369,7 +374,7 @@ private[spark] object HadoopRDD extends Logging { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[HadoopPartition] val inputSplit = partition.inputSplit.value f(inputSplit, firstParent[T].iterator(split, context)) @@ -378,11 +383,11 @@ private[spark] object HadoopRDD extends Logging { private[spark] class SplitInfoReflections { val inputSplitWithLocationInfo = - Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") val isInMemory = splitLocationInfo.getMethod("isInMemory") val getLocation = splitLocationInfo.getMethod("getLocation") } diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 642a12c1edf6..0c28f045e46e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.sql.{Connection, ResultSet} +import java.sql.{PreparedStatement, Connection, ResultSet} import scala.reflect.ClassTag @@ -28,8 +28,9 @@ import org.apache.spark.util.NextIterator import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx + override def index: Int = idx } + // TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** * An RDD that executes an SQL query on a JDBC connection and reads results. @@ -62,15 +63,16 @@ class JdbcRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { // bounds are inclusive, hence the + 1 here and - 1 on end - val length = 1 + upperBound - lowerBound + val length = BigInt(1) + upperBound - lowerBound (0 until numPartitions).map(i => { - val start = lowerBound + ((i * length) / numPartitions).toLong - val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 - new JdbcPartition(i, start, end) + val start = lowerBound + ((i * length) / numPartitions) + val end = lowerBound + (((i + 1) * length) / numPartitions) - 1 + new JdbcPartition(i, start.toLong, end.toLong) }).toArray } - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] + { context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() @@ -88,7 +90,7 @@ class JdbcRDD[T: ClassTag]( stmt.setLong(2, part.upper) val rs = stmt.executeQuery() - override def getNext: T = { + override def getNext(): T = { if (rs.next()) { mapRow(rs) } else { @@ -99,21 +101,21 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) { + if (null != rs) { rs.close() } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) { + if (null != stmt) { stmt.close() } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! conn.isClosed()) { + if (null != conn) { conn.close() } logInfo("closed connection") diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala new file mode 100644 index 000000000000..daa5779d688c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext} +import org.apache.spark.storage.RDDBlockId + +/** + * A dummy CheckpointRDD that exists to provide informative error messages during failures. + * + * This is simply a placeholder because the original checkpointed RDD is expected to be + * fully cached. Only if an executor fails or if the user explicitly unpersists the original + * RDD will Spark ever attempt to compute this CheckpointRDD. When this happens, however, + * we must provide an informative error message. + * + * @param sc the active SparkContext + * @param rddId the ID of the checkpointed RDD + * @param numPartitions the number of partitions in the checkpointed RDD + */ +private[spark] class LocalCheckpointRDD[T: ClassTag]( + @transient sc: SparkContext, + rddId: Int, + numPartitions: Int) + extends CheckpointRDD[T](sc) { + + def this(rdd: RDD[T]) { + this(rdd.context, rdd.id, rdd.partitions.size) + } + + protected override def getPartitions: Array[Partition] = { + (0 until numPartitions).toArray.map { i => new CheckpointRDDPartition(i) } + } + + /** + * Throw an exception indicating that the relevant block is not found. + * + * This should only be called if the original RDD is explicitly unpersisted or if an + * executor is lost. Under normal circumstances, however, the original RDD (our child) + * is expected to be fully cached and so all partitions should already be computed and + * available in the block storage. + */ + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + throw new SparkException( + s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " + + s"that originally checkpointed this partition is no longer alive, or the original RDD is " + + s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " + + s"instead, which is slower than local checkpointing but more fault-tolerant.") + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala new file mode 100644 index 000000000000..d6fad896845f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.Utils + +/** + * An implementation of checkpointing implemented on top of Spark's caching layer. + * + * Local checkpointing trades off fault tolerance for performance by skipping the expensive + * step of saving the RDD data to a reliable and fault-tolerant storage. Instead, the data + * is written to the local, ephemeral block storage that lives in each executor. This is useful + * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). + */ +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) + extends RDDCheckpointData[T](rdd) with Logging { + + /** + * Ensure the RDD is fully cached so the partitions can be recovered later. + */ + protected override def doCheckpoint(): CheckpointRDD[T] = { + val level = rdd.getStorageLevel + + // Assume storage level uses disk; otherwise memory eviction may cause data loss + assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing") + + // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we + // must cache any missing partitions. TODO: avoid running another job here (SPARK-8582). + val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator) + val missingPartitionIndices = rdd.partitions.map(_.index).filter { i => + !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i)) + } + if (missingPartitionIndices.nonEmpty) { + rdd.sparkContext.runJob(rdd, action, missingPartitionIndices) + } + + new LocalCheckpointRDD[T](rdd) + } + +} + +private[spark] object LocalRDDCheckpointData { + + val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK + + /** + * Transform the specified storage level to one that uses disk. + * + * This guarantees that the RDD can be recomputed multiple times correctly as long as + * executors do not fail. Otherwise, if the RDD is cached in memory only, for instance, + * the checkpoint data will be lost if the relevant block is evicted from memory. + * + * This method is idempotent. + */ + def transformStorageLevel(level: StorageLevel): StorageLevel = { + // If this RDD is to be cached off-heap, fail fast since we cannot provide any + // correctness guarantees about subsequent computations after the first one + if (level.useOffHeap) { + throw new SparkException("Local checkpointing is not compatible with off-heap caching.") + } + + StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4883fb828814..4312d3a41775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -21,6 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +/** + * An RDD that applies the provided function to every partition of the parent RDD. + */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) @@ -31,6 +34,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala new file mode 100644 index 000000000000..b475bd8d79f8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, Partitioner, TaskContext} + +/** + * An RDD that applies a user provided function to every partition of the parent RDD, and + * additionally allows the user to prepare each partition before computing the parent partition. + */ +private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag]( + prev: RDD[T], + preparePartition: () => M, + executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner: Option[Partitioner] = { + if (preservesPartitioning) firstParent[T].partitioner else None + } + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + /** + * Prepare a partition before computing it from its parent. + */ + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + val preparedArgument = preparePartition() + val parentIterator = firstParent[T].iterator(partition, context) + executePartition(context, partition.index, preparedArgument, parentIterator) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 44b9ffd2a53f..6a9c004d65cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -74,7 +74,7 @@ class NewHadoopRDD[K, V]( with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobTrackerId: String = { @@ -128,7 +128,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -141,6 +141,12 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -151,29 +157,36 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - + if (!finished) { + inputMetrics.incRecordsRead(1) + } (reader.getCurrentKey, reader.getCurrentValue) } private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + // Close reader and release it + reader.close() + reader = null + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } @@ -194,7 +207,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) @@ -236,7 +249,7 @@ private[spark] object NewHadoopRDD { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[NewHadoopPartition] val inputSplit = partition.serializableHadoopSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 144f679a5946..d71bb6300090 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,8 +56,8 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) - : RDD[(K, V)] = + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) + : RDD[(K, V)] = self.withScope { val part = new RangePartitioner(numPartitions, self, ascending) new ShuffledRDD[K, V, V](self, part) @@ -71,8 +71,31 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * This is more efficient than calling `repartition` and then sorting within each partition * because it can push the sorting down into the shuffle machinery. */ - def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = { + def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = self.withScope { new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) } + /** + * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`. + * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be + * performed efficiently by only scanning the partitions that might contain matching elements. + * Otherwise, a standard `filter` is applied to all partitions. + */ + def filterByRange(lower: K, upper: K): RDD[P] = self.withScope { + + def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) + + val rddToFilter: RDD[P] = self.partitioner match { + case Some(rp: RangePartitioner[K, V]) => { + val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { + case (l, u) => Math.min(l, u) to Math.max(l, u) + } + PartitionPruningRDD.create(self, partitionIndicies.contains) + } + case _ => + self + } + rddToFilter.filter { case (k, v) => inRange(k) } + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 49b88a90ab5a..4e5f2e8a5d46 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -22,19 +22,19 @@ import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} import scala.collection.{Map, mutable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.DynamicVariable import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter} + RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner @@ -44,7 +44,7 @@ import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -75,7 +75,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = { + serializer: Serializer = null): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -108,7 +108,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, - numPartitions: Int): RDD[(K, C)] = { + numPartitions: Int): RDD[(K, C)] = self.withScope { combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) } @@ -122,7 +122,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * instead of creating a new U. */ def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U, - combOp: (U, U) => U): RDD[(K, U)] = { + combOp: (U, U) => U): RDD[(K, U)] = self.withScope { // Serialize the zero value to a byte array so that we can get a new clone of it on each key val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue) val zeroArray = new Array[Byte](zeroBuffer.limit) @@ -131,7 +131,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) lazy val cachedSerializer = SparkEnv.get.serializer.newInstance() val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) - combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner) + // We will clean the combiner closure later in `combineByKey` + val cleanedSeqOp = self.context.clean(seqOp) + combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner) } /** @@ -144,7 +146,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * instead of creating a new U. */ def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U, - combOp: (U, U) => U): RDD[(K, U)] = { + combOp: (U, U) => U): RDD[(K, U)] = self.withScope { aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp) } @@ -158,7 +160,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * instead of creating a new U. */ def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U, - combOp: (U, U) => U): RDD[(K, U)] = { + combOp: (U, U) => U): RDD[(K, U)] = self.withScope { aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp) } @@ -167,7 +169,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * may be added to the result an arbitrary number of times, and must not change the result * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ - def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { + def foldByKey( + zeroValue: V, + partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = self.withScope { // Serialize the zero value to a byte array so that we can get a new clone of it on each key val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue) val zeroArray = new Array[Byte](zeroBuffer.limit) @@ -177,7 +181,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) lazy val cachedSerializer = SparkEnv.get.serializer.newInstance() val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) - combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) + val cleanedFunc = self.context.clean(func) + combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner) } /** @@ -185,7 +190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * may be added to the result an arbitrary number of times, and must not change the result * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ - def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { + def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = self.withScope { foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) } @@ -194,7 +199,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * may be added to the result an arbitrary number of times, and must not change the result * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ - def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { + def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = self.withScope { foldByKey(zeroValue, defaultPartitioner(self))(func) } @@ -213,7 +218,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) */ def sampleByKey(withReplacement: Boolean, fractions: Map[K, Double], - seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = self.withScope { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") @@ -242,9 +247,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @return RDD containing the sampled subset */ @Experimental - def sampleByKeyExact(withReplacement: Boolean, + def sampleByKeyExact( + withReplacement: Boolean, fractions: Map[K, Double], - seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = self.withScope { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") @@ -261,7 +267,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * the merging locally on each mapper before sending results to a reducer, similarly to a * "combiner" in MapReduce. */ - def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { + def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { combineByKey[V]((v: V) => v, func, func, partitioner) } @@ -270,7 +276,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * the merging locally on each mapper before sending results to a reducer, similarly to a * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = self.withScope { reduceByKey(new HashPartitioner(numPartitions), func) } @@ -280,7 +286,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ * parallelism level. */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope { reduceByKey(defaultPartitioner(self), func) } @@ -289,7 +295,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * immediately to the master as a Map. This will also perform the merging locally on each mapper * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ - def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { + val cleanedF = self.sparkContext.clean(func) if (keyClass.isArray) { throw new SparkException("reduceByKeyLocally() does not support array keys") @@ -299,27 +306,29 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val map = new JHashMap[K, V] iter.foreach { pair => val old = map.get(pair._1) - map.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + map.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } Iterator(map) } : Iterator[JHashMap[K, V]] val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { - m2.foreach { pair => + m2.asScala.foreach { pair => val old = m1.get(pair._1) - m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] - self.mapPartitions(reducePartition).reduce(mergeMaps) + self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } /** Alias for reduceByKeyLocally */ @deprecated("Use reduceByKeyLocally", "1.0.0") - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) + def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = self.withScope { + reduceByKeyLocally(func) + } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. * * Note that this method should only be used if the resulting map is expected to be small, as @@ -327,7 +336,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. */ - def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap + def countByKey(): Map[K, Long] = self.withScope { + self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap + } /** * :: Experimental :: @@ -336,7 +347,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[Map[K, BoundedDouble]] = { + : PartialResult[Map[K, BoundedDouble]] = self.withScope { self.map(_._1).countByValueApprox(timeout, confidence) } @@ -360,7 +371,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param partitioner Partitioner to use for the resulting RDD. */ @Experimental - def countApproxDistinctByKey(p: Int, sp: Int, partitioner: Partitioner): RDD[(K, Long)] = { + def countApproxDistinctByKey( + p: Int, + sp: Int, + partitioner: Partitioner): RDD[(K, Long)] = self.withScope { require(p >= 4, s"p ($p) must be >= 4") require(sp <= 32, s"sp ($sp) must be <= 32") require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)") @@ -392,7 +406,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * It must be greater than 0.000017. * @param partitioner partitioner of the resulting RDD */ - def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = { + def countApproxDistinctByKey( + relativeSD: Double, + partitioner: Partitioner): RDD[(K, Long)] = self.withScope { require(relativeSD > 0.000017, s"accuracy ($relativeSD) must be greater than 0.000017") val p = math.ceil(2.0 * math.log(1.054 / relativeSD) / math.log(2)).toInt assert(p <= 32) @@ -410,7 +426,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * It must be greater than 0.000017. * @param numPartitions number of partitions of the resulting RDD */ - def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): RDD[(K, Long)] = { + def countApproxDistinctByKey( + relativeSD: Double, + numPartitions: Int): RDD[(K, Long)] = self.withScope { countApproxDistinctByKey(relativeSD, new HashPartitioner(numPartitions)) } @@ -424,7 +442,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. */ - def countApproxDistinctByKey(relativeSD: Double = 0.05): RDD[(K, Long)] = { + def countApproxDistinctByKey(relativeSD: Double = 0.05): RDD[(K, Long)] = self.withScope { countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) } @@ -441,7 +459,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ - def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = { + def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. @@ -449,7 +467,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 val bufs = combineByKey[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -465,14 +483,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ - def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = { + def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { groupByKey(new HashPartitioner(numPartitions)) } /** * Return a copy of the RDD partitioned using the specified partitioner. */ - def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { + def partitionBy(partitioner: Partitioner): RDD[(K, V)] = self.withScope { if (keyClass.isArray && partitioner.isInstanceOf[HashPartitioner]) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -488,7 +506,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ - def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { + def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope { this.cogroup(other, partitioner).flatMapValues( pair => for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w) ) @@ -500,7 +518,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to * partition the output RDD. */ - def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { + def leftOuterJoin[W]( + other: RDD[(K, W)], + partitioner: Partitioner): RDD[(K, (V, Option[W]))] = self.withScope { this.cogroup(other, partitioner).flatMapValues { pair => if (pair._2.isEmpty) { pair._1.iterator.map(v => (v, None)) @@ -517,7 +537,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * partition the output RDD. */ def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], W))] = { + : RDD[(K, (Option[V], W))] = self.withScope { this.cogroup(other, partitioner).flatMapValues { pair => if (pair._1.isEmpty) { pair._2.iterator.map(w => (None, w)) @@ -536,7 +556,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * in `this` have key k. Uses the given Partitioner to partition the output RDD. */ def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], Option[W]))] = { + : RDD[(K, (Option[V], Option[W]))] = self.withScope { this.cogroup(other, partitioner).flatMapValues { case (vs, Seq()) => vs.iterator.map(v => (Some(v), None)) case (Seq(), ws) => ws.iterator.map(w => (None, Some(w))) @@ -549,7 +569,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * existing partitioner/parallelism level. */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { + : RDD[(K, C)] = self.withScope { combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } @@ -563,7 +583,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ - def groupByKey(): RDD[(K, Iterable[V])] = { + def groupByKey(): RDD[(K, Iterable[V])] = self.withScope { groupByKey(defaultPartitioner(self)) } @@ -572,7 +592,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Performs a hash join across the cluster. */ - def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { + def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope { join(other, defaultPartitioner(self, other)) } @@ -581,7 +601,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Performs a hash join across the cluster. */ - def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { + def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = self.withScope { join(other, new HashPartitioner(numPartitions)) } @@ -591,7 +611,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output * using the existing partitioner/parallelism level. */ - def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { + def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = self.withScope { leftOuterJoin(other, defaultPartitioner(self, other)) } @@ -601,7 +621,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output * into `numPartitions` partitions. */ - def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { + def leftOuterJoin[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (V, Option[W]))] = self.withScope { leftOuterJoin(other, new HashPartitioner(numPartitions)) } @@ -611,7 +633,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting * RDD using the existing partitioner/parallelism level. */ - def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { + def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = self.withScope { rightOuterJoin(other, defaultPartitioner(self, other)) } @@ -621,7 +643,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting * RDD into the given number of partitions. */ - def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { + def rightOuterJoin[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (Option[V], W))] = self.withScope { rightOuterJoin(other, new HashPartitioner(numPartitions)) } @@ -634,7 +658,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/ * parallelism level. */ - def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = self.withScope { fullOuterJoin(other, defaultPartitioner(self, other)) } @@ -646,7 +670,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions. */ - def fullOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = self.withScope { fullOuterJoin(other, new HashPartitioner(numPartitions)) } @@ -656,7 +682,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only * one value per key is preserved in the map returned) */ - def collectAsMap(): Map[K, V] = { + def collectAsMap(): Map[K, V] = self.withScope { val data = self.collect() val map = new mutable.HashMap[K, V] map.sizeHint(data.length) @@ -668,7 +694,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Pass each value in the key-value pair RDD through a map function without changing the keys; * this also retains the original RDD's partitioning. */ - def mapValues[U](f: V => U): RDD[(K, U)] = { + def mapValues[U](f: V => U): RDD[(K, U)] = self.withScope { val cleanF = self.context.clean(f) new MapPartitionsRDD[(K, U), (K, V)](self, (context, pid, iter) => iter.map { case (k, v) => (k, cleanF(v)) }, @@ -679,7 +705,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Pass each value in the key-value pair RDD through a flatMap function without changing the * keys; this also retains the original RDD's partitioning. */ - def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { + def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = self.withScope { val cleanF = self.context.clean(f) new MapPartitionsRDD[(K, U), (K, V)](self, (context, pid, iter) => iter.flatMap { case (k, v) => @@ -697,7 +723,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) other2: RDD[(K, W2)], other3: RDD[(K, W3)], partitioner: Partitioner) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -715,7 +741,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Iterable[V], Iterable[W]))] = { + : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -730,7 +756,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -748,7 +774,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * for that key in `this`, `other1`, `other2` and `other3`. */ def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) } @@ -756,7 +782,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { cogroup(other, defaultPartitioner(self, other)) } @@ -765,7 +791,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } @@ -773,7 +799,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { cogroup(other, new HashPartitioner(numPartitions)) } @@ -782,7 +810,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { cogroup(other1, other2, new HashPartitioner(numPartitions)) } @@ -795,24 +823,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) other2: RDD[(K, W2)], other3: RDD[(K, W3)], numPartitions: Int) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { cogroup(other1, other2, other3, new HashPartitioner(numPartitions)) } /** Alias for cogroup. */ - def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { + def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope { cogroup(other, defaultPartitioner(self, other)) } /** Alias for cogroup. */ def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = self.withScope { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } /** Alias for cogroup. */ def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) - : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = self.withScope { cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) } @@ -822,22 +850,27 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ - def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = self.withScope { + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) + } /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = + def subtractByKey[W: ClassTag]( + other: RDD[(K, W)], + numPartitions: Int): RDD[(K, V)] = self.withScope { subtractByKey(other, new HashPartitioner(numPartitions)) + } /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = self.withScope { new SubtractedRDD[K, V, W](self, other, p) + } /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): Seq[V] = { + def lookup(key: K): Seq[V] = self.withScope { self.partitioner match { case Some(p) => val index = p.getPartition(key) @@ -848,7 +881,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } buf } : Seq[V] - val res = self.context.runJob(self, process, Array(index), false) + val res = self.context.runJob(self, process, Array(index)) res(0) case None => self.filter(_._1 == key).map(_._2).collect() @@ -859,7 +892,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String)(implicit fm: ClassTag[F]): Unit = self.withScope { saveAsHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -869,7 +903,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * supplied codec. */ def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { + path: String, + codec: Class[_ <: CompressionCodec])(implicit fm: ClassTag[F]): Unit = self.withScope { val runtimeClass = fm.runtimeClass saveAsHadoopFile(path, keyClass, valueClass, runtimeClass.asInstanceOf[Class[F]], codec) } @@ -878,7 +913,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]]( + path: String)(implicit fm: ClassTag[F]): Unit = self.withScope { saveAsNewAPIHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -891,8 +927,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) - { + conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = new NewAPIHadoopJob(hadoopConf) @@ -912,7 +947,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - codec: Class[_ <: CompressionCodec]) { + codec: Class[_ <: CompressionCodec]): Unit = self.withScope { saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, new JobConf(self.context.hadoopConfiguration), Some(codec)) } @@ -927,7 +962,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(self.context.hadoopConfiguration), - codec: Option[Class[_ <: CompressionCodec]] = None) { + codec: Option[Class[_ <: CompressionCodec]] = None): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) @@ -960,14 +995,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. */ - def saveAsNewAPIHadoopDataset(conf: Configuration) { + def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = new NewAPIHadoopJob(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -976,7 +1011,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, @@ -992,9 +1027,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] - try { - var recordsWritten = 0L + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] + require(writer != null, "Unable to obtain RecordWriter") + var recordsWritten = 0L + Utils.tryWithSafeFinally { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1003,11 +1039,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) 1 } : Int @@ -1025,10 +1062,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop * MapReduce job. */ - def saveAsHadoopDataset(conf: JobConf) { + def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val wrappedConf = new SerializableWritable(hadoopConf) + val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1065,8 +1102,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() - try { - var recordsWritten = 0L + var recordsWritten = 0L + + Utils.tryWithSafeFinally { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1075,11 +1113,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close() } writer.commit() bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } self.context.runJob(self, writeToFile) @@ -1097,9 +1136,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 - && bytesWrittenCallback.isDefined) { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index f12d0cffaba3..e2394e28f8d2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -98,7 +98,7 @@ private[spark] class ParallelCollectionRDD[T: ClassTag]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = { + override def compute(s: Partition, context: TaskContext): Iterator[T] = { new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index f781a8d776f2..a00f4c1cdff9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -40,7 +40,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) } } @@ -59,8 +59,10 @@ class PartitionPruningRDD[T: ClassTag]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + } override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions @@ -74,7 +76,7 @@ object PartitionPruningRDD { * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD * when its type T is not known at compile time. */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean): PartitionPruningRDD[T] = { new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 92b0641d0fb6..9e3880714a79 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -60,6 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( var rdds: Seq[RDD[T]] ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { require(rdds.length > 0) + require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) @@ -85,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index ed79032893d3..afbe566b7656 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -23,7 +23,7 @@ import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -72,7 +72,7 @@ private[spark] class PipedRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) + val pb = new ProcessBuilder(command.asJava) // Add the environmental variables to the process. val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -81,7 +81,7 @@ private[spark] class PipedRDD[T: ClassTag]( // so the user code can access the input filename if (split.isInstanceOf[HadoopPartition]) { val hadoopSplit = split.asInstanceOf[HadoopPartition] - currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) } // When spark.worker.separated.working.directory option is turned on, each @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -131,8 +133,10 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { + TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,15 +148,16 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines + val lines = Source.fromInputStream(proc.getInputStream).getLines() new Iterator[String] { - def next() = lines.next() - def hasNext = { + def next(): String = lines.next() + def hasNext: Boolean = { if (lines.hasNext) { true } else { @@ -162,7 +167,7 @@ private[spark] class PipedRDD[T: ClassTag]( } // cleanup task working directory if used - if (workInTaskDirectory == true) { + if (workInTaskDirectory) { scala.util.control.Exception.ignoring(classOf[IOException]) { Utils.deleteRecursively(new File(taskDirectory)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97aee58bddbf..081c721f2368 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -25,11 +25,8 @@ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.{BytesWritable, NullWritable, Text} import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -57,8 +54,7 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] - * through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * through implicit. * * Internally, each RDD is characterized by five main properties: * @@ -153,23 +149,43 @@ abstract class RDD[T: ClassTag]( } /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. This can only be used to assign a new storage level if the RDD does not - * have a storage level set yet.. + * Mark this RDD for persisting using the specified level. + * + * @param newLevel the target storage level + * @param allowOverride whether to override any existing level with the new one */ - def persist(newLevel: StorageLevel): this.type = { + private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = { // TODO: Handle changes of StorageLevel - if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { + if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) { throw new UnsupportedOperationException( "Cannot change storage level of an RDD after it was already assigned a level") } - sc.persistRDD(this) - // Register the RDD with the ContextCleaner for automatic GC-based cleanup - sc.cleaner.foreach(_.registerRDDForCleanup(this)) + // If this is the first time this RDD is marked for persisting, register it + // with the SparkContext for cleanups and accounting. Do this only once. + if (storageLevel == StorageLevel.NONE) { + sc.cleaner.foreach(_.registerRDDForCleanup(this)) + sc.persistRDD(this) + } storageLevel = newLevel this } + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet. Local checkpointing is an exception. + */ + def persist(newLevel: StorageLevel): this.type = { + if (isLocallyCheckpointed) { + // This means the user previously called localCheckpoint(), which should have already + // marked this RDD for persisting. Here we should override the old storage level with + // one that is explicitly requested by the user (after adapting it to use disk). + persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true) + } else { + persist(newLevel, allowOverride = false) + } + } + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) @@ -190,7 +206,7 @@ abstract class RDD[T: ClassTag]( } /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel + def getStorageLevel: StorageLevel = storageLevel // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed @@ -198,7 +214,7 @@ abstract class RDD[T: ClassTag]( @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) /** * Get the list of dependencies of this RDD, taking into account whether the @@ -281,12 +297,20 @@ abstract class RDD[T: ClassTag]( if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) } + /** + * Execute a block of code in a scope such that all new RDDs created in this body will + * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](sc)(body) + // Transformations (return a new RDD) /** * Return a new RDD by applying a function to all elements of this RDD. */ - def map[U: ClassTag](f: T => U): RDD[U] = { + def map[U: ClassTag](f: T => U): RDD[U] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF)) } @@ -295,7 +319,7 @@ abstract class RDD[T: ClassTag]( * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. */ - def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = { + def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF)) } @@ -303,7 +327,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing only the elements that satisfy a predicate. */ - def filter(f: T => Boolean): RDD[T] = { + def filter(f: T => Boolean): RDD[T] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[T, T]( this, @@ -314,13 +338,16 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = + def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + } /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = withScope { + distinct(partitions.length) + } /** * Return a new RDD that has exactly numPartitions partitions. @@ -331,7 +358,7 @@ abstract class RDD[T: ClassTag]( * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. */ - def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = { + def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) } @@ -356,7 +383,7 @@ abstract class RDD[T: ClassTag]( * data distributed using a hash partitioner. */ def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null) - : RDD[T] = { + : RDD[T] = withScope { if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { @@ -381,10 +408,17 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param seed seed for the random number generator */ - def sample(withReplacement: Boolean, + def sample( + withReplacement: Boolean, fraction: Double, - seed: Long = Utils.random.nextLong): RDD[T] = { + seed: Long = Utils.random.nextLong): RDD[T] = withScope { require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) @@ -401,15 +435,32 @@ abstract class RDD[T: ClassTag]( * * @return split RDDs in an array */ - def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[RDD[T]] = { + def randomSplit( + weights: Array[Double], + seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope { val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T]( - this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) + randomSampleWithRange(x(0), x(1), seed) }.toArray } + /** + * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability + * range. + * @param lb lower bound to use for the Bernoulli sampler + * @param ub upper bound to use for the Bernoulli sampler + * @param seed the seed for the Random number generator + * @return A random sub-sample of the RDD without replacement. + */ + private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { + this.mapPartitionsWithIndex( { (index, partition) => + val sampler = new BernoulliCellSampler[T](lb, ub) + sampler.setSeed(seed + index) + sampler.sample(partition) + }, preservesPartitioning = true) + } + /** * Return a fixed-size sampled subset of this RDD in an array * @@ -418,10 +469,12 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - def takeSample(withReplacement: Boolean, + // TODO: rewrite this without return statements so we can wrap it in a scope + def takeSample( + withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - val numStDev = 10.0 + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") @@ -466,13 +519,21 @@ abstract class RDD[T: ClassTag]( * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). */ - def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + def union(other: RDD[T]): RDD[T] = withScope { + if (partitioner.isDefined && other.partitioner == partitioner) { + new PartitionerAwareUnionRDD(sc, Array(this, other)) + } else { + new UnionRDD(sc, Array(this, other)) + } + } /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). */ - def ++(other: RDD[T]): RDD[T] = this.union(other) + def ++(other: RDD[T]): RDD[T] = withScope { + this.union(other) + } /** * Return this RDD sorted by the given key function. @@ -480,11 +541,12 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) - (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = + numPartitions: Int = this.partitions.length) + (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = withScope { this.keyBy[K](f) .sortByKey(ascending, numPartitions) .values + } /** * Return the intersection of this RDD and another one. The output will not contain any duplicate @@ -492,7 +554,7 @@ abstract class RDD[T: ClassTag]( * * Note that this method performs a shuffle internally. */ - def intersection(other: RDD[T]): RDD[T] = { + def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } .keys @@ -506,8 +568,9 @@ abstract class RDD[T: ClassTag]( * * @param partitioner Partitioner to use for the resulting RDD */ - def intersection(other: RDD[T], partitioner: Partitioner)(implicit ord: Ordering[T] = null) - : RDD[T] = { + def intersection( + other: RDD[T], + partitioner: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null)), partitioner) .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } .keys @@ -521,16 +584,14 @@ abstract class RDD[T: ClassTag]( * * @param numPartitions How many partitions to use in the resulting RDD */ - def intersection(other: RDD[T], numPartitions: Int): RDD[T] = { - this.map(v => (v, null)).cogroup(other.map(v => (v, null)), new HashPartitioner(numPartitions)) - .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } - .keys + def intersection(other: RDD[T], numPartitions: Int): RDD[T] = withScope { + intersection(other, new HashPartitioner(numPartitions)) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ - def glom(): RDD[Array[T]] = { + def glom(): RDD[Array[T]] = withScope { new MapPartitionsRDD[Array[T], T](this, (context, pid, iter) => Iterator(iter.toArray)) } @@ -538,7 +599,9 @@ abstract class RDD[T: ClassTag]( * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of * elements (a, b) where a is in `this` and b is in `other`. */ - def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = withScope { + new CartesianRDD(sc, this, other) + } /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements @@ -549,8 +612,9 @@ abstract class RDD[T: ClassTag]( * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ - def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = + def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) + } /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements @@ -561,8 +625,11 @@ abstract class RDD[T: ClassTag]( * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ - def groupBy[K](f: T => K, numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = + def groupBy[K]( + f: T => K, + numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy(f, new HashPartitioner(numPartitions)) + } /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements @@ -574,7 +641,7 @@ abstract class RDD[T: ClassTag]( * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) - : RDD[(K, Iterable[T])] = { + : RDD[(K, Iterable[T])] = withScope { val cleanF = sc.clean(f) this.map(t => (cleanF(t), t)).groupByKey(p) } @@ -582,13 +649,16 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String): RDD[String] = new PipedRDD(this, command) + def pipe(command: String): RDD[String] = withScope { + new PipedRDD(this, command) + } /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String, env: Map[String, String]): RDD[String] = + def pipe(command: String, env: Map[String, String]): RDD[String] = withScope { new PipedRDD(this, command, env) + } /** * Return an RDD created by piping elements to a forked external process. @@ -596,7 +666,7 @@ abstract class RDD[T: ClassTag]( * * @param command command to run in forked process. * @param env environment variables to set. - * @param printPipeContext Before piping elements, this function is called as an oppotunity + * @param printPipeContext Before piping elements, this function is called as an opportunity * to pipe context data. Print line function (like out.println) will be * passed as printPipeContext's parameter. * @param printRDDElement Use this function to customize how to pipe elements. This function @@ -614,7 +684,7 @@ abstract class RDD[T: ClassTag]( env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, printRDDElement: (T, String => Unit) => Unit = null, - separateWorkingDir: Boolean = false): RDD[String] = { + separateWorkingDir: Boolean = false): RDD[String] = withScope { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null, @@ -628,9 +698,13 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitions[U: ClassTag]( - f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) - new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) + f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + val cleanedF = sc.clean(f) + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter), + preservesPartitioning) } /** @@ -641,9 +715,13 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitionsWithIndex[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) - new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + val cleanedF = sc.clean(f) + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter), + preservesPartitioning) } /** @@ -658,8 +736,9 @@ abstract class RDD[T: ClassTag]( @deprecated("use TaskContext.get", "1.2.0") def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) + preservesPartitioning: Boolean = false): RDD[U] = withScope { + val cleanF = sc.clean(f) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter) new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } @@ -669,7 +748,8 @@ abstract class RDD[T: ClassTag]( */ @deprecated("use mapPartitionsWithIndex", "0.7.0") def mapPartitionsWithSplit[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { mapPartitionsWithIndex(f, preservesPartitioning) } @@ -681,10 +761,12 @@ abstract class RDD[T: ClassTag]( @deprecated("use mapPartitionsWithIndex", "1.0.0") def mapWith[A, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => U): RDD[U] = { + (f: (T, A) => U): RDD[U] = withScope { + val cleanF = sc.clean(f) + val cleanA = sc.clean(constructA) mapPartitionsWithIndex((index, iter) => { - val a = constructA(index) - iter.map(t => f(t, a)) + val a = cleanA(index) + iter.map(t => cleanF(t, a)) }, preservesPartitioning) } @@ -696,10 +778,12 @@ abstract class RDD[T: ClassTag]( @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0") def flatMapWith[A, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => Seq[U]): RDD[U] = { + (f: (T, A) => Seq[U]): RDD[U] = withScope { + val cleanF = sc.clean(f) + val cleanA = sc.clean(constructA) mapPartitionsWithIndex((index, iter) => { - val a = constructA(index) - iter.flatMap(t => f(t, a)) + val a = cleanA(index) + iter.flatMap(t => cleanF(t, a)) }, preservesPartitioning) } @@ -709,11 +793,13 @@ abstract class RDD[T: ClassTag]( * partition with the index of that partition. */ @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0") - def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit) { + def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope { + val cleanF = sc.clean(f) + val cleanA = sc.clean(constructA) mapPartitionsWithIndex { (index, iter) => - val a = constructA(index) - iter.map(t => {f(t, a); t}) - }.foreach(_ => {}) + val a = cleanA(index) + iter.map(t => {cleanF(t, a); t}) + } } /** @@ -722,10 +808,12 @@ abstract class RDD[T: ClassTag]( * partition with the index of that partition. */ @deprecated("use mapPartitionsWithIndex and filter", "1.0.0") - def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope { + val cleanP = sc.clean(p) + val cleanA = sc.clean(constructA) mapPartitionsWithIndex((index, iter) => { - val a = constructA(index) - iter.filter(t => p(t, a)) + val a = cleanA(index) + iter.filter(t => cleanP(t, a)) }, preservesPartitioning = true) } @@ -735,16 +823,16 @@ abstract class RDD[T: ClassTag]( * partitions* and the *same number of elements in each partition* (e.g. one was made through * a map on the other). */ - def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { + def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = withScope { zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { - def hasNext = (thisIter.hasNext, otherIter.hasNext) match { + def hasNext: Boolean = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true case (false, false) => false case _ => throw new SparkException("Can only zip RDDs with " + "same number of elements in each partition") } - def next = (thisIter.next, otherIter.next) + def next(): (T, U) = (thisIter.next(), otherIter.next()) } } } @@ -757,33 +845,39 @@ abstract class RDD[T: ClassTag]( */ def zipPartitions[B: ClassTag, V: ClassTag] (rdd2: RDD[B], preservesPartitioning: Boolean) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = withScope { new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning) + } def zipPartitions[B: ClassTag, V: ClassTag] (rdd2: RDD[B]) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false) + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = withScope { + zipPartitions(rdd2, preservesPartitioning = false)(f) + } def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = withScope { new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning) + } def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C]) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = withScope { + zipPartitions(rdd2, rdd3, preservesPartitioning = false)(f) + } def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = withScope { new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning) + } def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = withScope { + zipPartitions(rdd2, rdd3, rdd4, preservesPartitioning = false)(f) + } // Actions (launch a job to return a value to the user program) @@ -791,7 +885,7 @@ abstract class RDD[T: ClassTag]( /** * Applies a function f to all elements of this RDD. */ - def foreach(f: T => Unit) { + def foreach(f: T => Unit): Unit = withScope { val cleanF = sc.clean(f) sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) } @@ -799,7 +893,7 @@ abstract class RDD[T: ClassTag]( /** * Applies a function f to each partition of this RDD. */ - def foreachPartition(f: Iterator[T] => Unit) { + def foreachPartition(f: Iterator[T] => Unit): Unit = withScope { val cleanF = sc.clean(f) sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } @@ -807,7 +901,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. */ - def collect(): Array[T] = { + def collect(): Array[T] = withScope { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) Array.concat(results: _*) } @@ -816,10 +910,14 @@ abstract class RDD[T: ClassTag]( * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. + * + * Note: this results in multiple Spark jobs, and if the input RDD is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input RDD should be cached first. */ - def toLocalIterator: Iterator[T] = { + def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head + sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } @@ -828,13 +926,16 @@ abstract class RDD[T: ClassTag]( * Return an array that contains all of the elements in this RDD. */ @deprecated("use collect", "1.0.0") - def toArray(): Array[T] = collect() + def toArray(): Array[T] = withScope { + collect() + } /** * Return an RDD that contains all matching values by applying `f`. */ - def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = { - filter(f.isDefinedAt).map(f) + def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope { + val cleanF = sc.clean(f) + filter(cleanF.isDefinedAt).map(cleanF) } /** @@ -843,25 +944,29 @@ abstract class RDD[T: ClassTag]( * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ - def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + def subtract(other: RDD[T]): RDD[T] = withScope { + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) + } /** * Return an RDD with the elements from `this` that are not in `other`. */ - def subtract(other: RDD[T], numPartitions: Int): RDD[T] = + def subtract(other: RDD[T], numPartitions: Int): RDD[T] = withScope { subtract(other, new HashPartitioner(numPartitions)) + } /** * Return an RDD with the elements from `this` that are not in `other`. */ - def subtract(other: RDD[T], p: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = { + def subtract( + other: RDD[T], + p: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = withScope { if (partitioner == Some(p)) { // Our partitioner knows how to handle T (which, since we have a partitioner, is // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + override def numPartitions: Int = p.numPartitions + override def getPartition(k: Any): Int = p.getPartition(k.asInstanceOf[(Any, _)]._1) } // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies // anyway, and when calling .keys, will not have a partitioner set, even though @@ -877,7 +982,7 @@ abstract class RDD[T: ClassTag]( * Reduces the elements of this RDD using the specified commutative and * associative binary operator. */ - def reduce(f: (T, T) => T): T = { + def reduce(f: (T, T) => T): T = withScope { val cleanF = sc.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { if (iter.hasNext) { @@ -906,7 +1011,7 @@ abstract class RDD[T: ClassTag]( * @param depth suggested depth of the tree (default: 2) * @see [[org.apache.spark.rdd.RDD#reduce]] */ - def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + def treeReduce(f: (T, T) => T, depth: Int = 2): T = withScope { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") val cleanF = context.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { @@ -934,11 +1039,18 @@ abstract class RDD[T: ClassTag]( /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. - */ - def fold(zeroValue: T)(op: (T, T) => T): T = { + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. + */ + def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanOp = sc.clean(op) @@ -956,9 +1068,9 @@ abstract class RDD[T: ClassTag]( * allowed to modify and return their first argument instead of creating a new U to avoid memory * allocation. */ - def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope { // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) val cleanCombOp = sc.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) @@ -976,26 +1088,31 @@ abstract class RDD[T: ClassTag]( def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, combOp: (U, U) => U, - depth: Int = 2): U = { + depth: Int = 2): U = withScope { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (partitions.size == 0) { - return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) - } - val cleanSeqOp = context.clean(seqOp) - val cleanCombOp = context.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size - val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) - // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { - numPartitions /= scale - val curNumPartitions = numPartitions - partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => - iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + if (partitions.length == 0) { + Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) + } else { + val cleanSeqOp = context.clean(seqOp) + val cleanCombOp = context.clean(combOp) + val aggregatePartition = + (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.length + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce + // the wall-clock time, we stop tree aggregation. + + // Don't trigger TreeAggregation when it doesn't save wall-clock time + while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { + (i, iter) => iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) } - partiallyAggregated.reduce(cleanCombOp) } /** @@ -1009,7 +1126,9 @@ abstract class RDD[T: ClassTag]( * within a timeout, even if not all tasks have finished. */ @Experimental - def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + def countApprox( + timeout: Long, + confidence: Double = 0.95): PartialResult[BoundedDouble] = withScope { val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => var result = 0L while (iter.hasNext) { @@ -1018,7 +1137,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -1030,7 +1149,7 @@ abstract class RDD[T: ClassTag]( * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. */ - def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { + def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = withScope { map(value => (value, null)).countByKey() } @@ -1041,19 +1160,18 @@ abstract class RDD[T: ClassTag]( @Experimental def countByValueApprox(timeout: Long, confidence: Double = 0.95) (implicit ord: Ordering[T] = null) - : PartialResult[Map[T, BoundedDouble]] = - { + : PartialResult[Map[T, BoundedDouble]] = withScope { if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => - val map = new OpenHashMap[T,Long] + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) => + val map = new OpenHashMap[T, Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -1075,9 +1193,9 @@ abstract class RDD[T: ClassTag]( * If `sp` equals 0, the sparse representation is skipped. */ @Experimental - def countApproxDistinct(p: Int, sp: Int): Long = { - require(p >= 4, s"p ($p) must be at least 4") - require(sp <= 32, s"sp ($sp) cannot be greater than 32") + def countApproxDistinct(p: Int, sp: Int): Long = withScope { + require(p >= 4, s"p ($p) must be >= 4") + require(sp <= 32, s"sp ($sp) must be <= 32") require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)") val zeroCounter = new HyperLogLogPlus(p, sp) aggregate(zeroCounter)( @@ -1101,9 +1219,10 @@ abstract class RDD[T: ClassTag]( * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. */ - def countApproxDistinct(relativeSD: Double = 0.05): Long = { + def countApproxDistinct(relativeSD: Double = 0.05): Long = withScope { + require(relativeSD > 0.000017, s"accuracy ($relativeSD) must be greater than 0.000017") val p = math.ceil(2.0 * math.log(1.054 / relativeSD) / math.log(2)).toInt - countApproxDistinct(p, 0) + countApproxDistinct(if (p < 4) 4 else p, 0) } /** @@ -1119,7 +1238,9 @@ abstract class RDD[T: ClassTag]( * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ - def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this) + def zipWithIndex(): RDD[(T, Long)] = withScope { + new ZippedWithIndexRDD(this) + } /** * Zips this RDD with generated unique Long ids. Items in the kth partition will get ids k, n+k, @@ -1131,8 +1252,8 @@ abstract class RDD[T: ClassTag]( * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ - def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + def zipWithUniqueId(): RDD[(T, Long)] = withScope { + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1144,49 +1265,54 @@ abstract class RDD[T: ClassTag]( * Take the first num elements of the RDD. It works by first scanning one partition, and use the * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. + * + * @note due to complications in the internal implementation, this method will raise + * an exception if called on an RDD of `Nothing` or `Null`. */ - def take(num: Int): Array[T] = { + def take(num: Int): Array[T] = withScope { if (num == 0) { - return new Array[T](0) - } - - val buf = new ArrayBuffer[T] - val totalParts = this.partitions.length - var partsScanned = 0 - while (buf.size < num && partsScanned < totalParts) { - // The number of partitions to try in this iteration. It is ok for this number to be - // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 - if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, - // interpolate the number of partitions we need to try, but overestimate it by 50%. - // We also cap the estimation in the end. - if (buf.size == 0) { - numPartsToTry = partsScanned * 4 - } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + new Array[T](0) + } else { + val buf = new ArrayBuffer[T] + val totalParts = this.partitions.length + var partsScanned = 0 + while (buf.size < num && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the previous iteration, quadruple and retry. + // Otherwise, interpolate the number of partitions we need to try, but overestimate + // it by 50%. We also cap the estimation in the end. + if (buf.size == 0) { + numPartsToTry = partsScanned * 4 + } else { + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + } } - } - val left = num - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) + val left = num - buf.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) - res.foreach(buf ++= _.take(num - buf.size)) - partsScanned += numPartsToTry - } + res.foreach(buf ++= _.take(num - buf.size)) + partsScanned += numPartsToTry + } - buf.toArray + buf.toArray + } } /** * Return the first element in this RDD. */ - def first(): T = take(1) match { - case Array(t) => t - case _ => throw new UnsupportedOperationException("empty collection") + def first(): T = withScope { + take(1) match { + case Array(t) => t + case _ => throw new UnsupportedOperationException("empty collection") + } } /** @@ -1204,7 +1330,9 @@ abstract class RDD[T: ClassTag]( * @param ord the implicit ordering for T * @return an array of top elements */ - def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse) + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope { + takeOrdered(num)(ord.reverse) + } /** * Returns the first k (smallest) elements from this RDD as defined by the specified @@ -1222,7 +1350,7 @@ abstract class RDD[T: ClassTag]( * @param ord the implicit ordering for T * @return an array of top elements */ - def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = { + def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = withScope { if (num == 0) { Array.empty } else { @@ -1232,7 +1360,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1247,24 +1375,34 @@ abstract class RDD[T: ClassTag]( * Returns the max of this RDD as defined by the implicit Ordering[T]. * @return the maximum element of the RDD * */ - def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max) + def max()(implicit ord: Ordering[T]): T = withScope { + this.reduce(ord.max) + } /** * Returns the min of this RDD as defined by the implicit Ordering[T]. * @return the minimum element of the RDD * */ - def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min) + def min()(implicit ord: Ordering[T]): T = withScope { + this.reduce(ord.min) + } /** + * @note due to complications in the internal implementation, this method will raise an + * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice + * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. + * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) * @return true if and only if the RDD contains no elements at all. Note that an RDD * may be empty even when it has at least 1 partition. */ - def isEmpty(): Boolean = partitions.length == 0 || take(1).length == 0 + def isEmpty(): Boolean = withScope { + partitions.length == 0 || take(1).length == 0 + } /** * Save this RDD as a text file, using string representations of elements. */ - def saveAsTextFile(path: String) { + def saveAsTextFile(path: String): Unit = withScope { // https://issues.apache.org/jira/browse/SPARK-2075 // // NullWritable is a `Comparable` in Hadoop 1.+, so the compiler cannot find an implicit @@ -1291,7 +1429,7 @@ abstract class RDD[T: ClassTag]( /** * Save this RDD as a compressed text file, using string representations of elements. */ - def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]): Unit = withScope { // https://issues.apache.org/jira/browse/SPARK-2075 val nullWritableClassTag = implicitly[ClassTag[NullWritable]] val textClassTag = implicitly[ClassTag[Text]] @@ -1309,7 +1447,7 @@ abstract class RDD[T: ClassTag]( /** * Save this RDD as a SequenceFile of serialized objects. */ - def saveAsObjectFile(path: String) { + def saveAsObjectFile(path: String): Unit = withScope { this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) .saveAsSequenceFile(path) @@ -1318,40 +1456,111 @@ abstract class RDD[T: ClassTag]( /** * Creates tuples of the elements in this RDD by applying `f`. */ - def keyBy[K](f: T => K): RDD[(K, T)] = { - map(x => (f(x), x)) + def keyBy[K](f: T => K): RDD[(K, T)] = withScope { + val cleanedF = sc.clean(f) + map(x => (cleanedF(x), x)) } /** A private method for tests, to look at the contents of each partition */ - private[spark] def collectPartitions(): Array[Array[T]] = { + private[spark] def collectPartitions(): Array[Array[T]] = withScope { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } /** * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent + * directory set with `SparkContext#setCheckpointDir` and all references to its parent * RDDs will be removed. This function must be called before any job has been * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint() { + def checkpoint(): Unit = RDDCheckpointData.synchronized { + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() + checkpointData = Some(new ReliableRDDCheckpointData(this)) } } /** - * Return whether this RDD has been checkpointed or not + * Mark this RDD for local checkpointing using Spark's existing caching layer. + * + * This method is for users who wish to truncate RDD lineages while skipping the expensive + * step of replicating the materialized data in a reliable distributed file system. This is + * useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX). + * + * Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed + * data is written to ephemeral local storage in the executors instead of to a reliable, + * fault-tolerant storage. The effect is that if an executor fails during the computation, + * the checkpointed data may no longer be accessible, causing an irrecoverable job failure. + * + * This is NOT safe to use with dynamic allocation, which removes executors along + * with their cached blocks. If you must use both features, you are advised to set + * `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value. + * + * The checkpoint directory set through `SparkContext#setCheckpointDir` is not used. + */ + def localCheckpoint(): this.type = RDDCheckpointData.synchronized { + if (conf.getBoolean("spark.dynamicAllocation.enabled", false) && + conf.contains("spark.dynamicAllocation.cachedExecutorIdleTimeout")) { + logWarning("Local checkpointing is NOT safe to use with dynamic allocation, " + + "which removes executors along with their cached blocks. If you must use both " + + "features, you are advised to set `spark.dynamicAllocation.cachedExecutorIdleTimeout` " + + "to a high value. E.g. If you plan to use the RDD for 1 hour, set the timeout to " + + "at least 1 hour.") + } + + // Note: At this point we do not actually know whether the user will call persist() on + // this RDD later, so we must explicitly call it here ourselves to ensure the cached + // blocks are registered for cleanup later in the SparkContext. + // + // If, however, the user has already called persist() on this RDD, then we must adapt + // the storage level he/she specified to one that is appropriate for local checkpointing + // (i.e. uses disk) to guarantee correctness. + + if (storageLevel == StorageLevel.NONE) { + persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } else { + persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true) + } + + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning( + "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") + case _ => + } + checkpointData = Some(new LocalRDDCheckpointData(this)) + this + } + + /** + * Return whether this RDD is marked for checkpointing, either reliably or locally. */ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** - * Gets the name of the file to which this RDD was checkpointed + * Return whether this RDD is marked for local checkpointing. + * Exposed for testing. + */ + private[rdd] def isLocallyCheckpointed: Boolean = { + checkpointData match { + case Some(_: LocalRDDCheckpointData[T]) => true + case _ => false + } + } + + /** + * Gets the name of the directory to which this RDD was checkpointed. + * This is not defined if the RDD is checkpointed locally. */ - def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) + def getCheckpointFile: Option[String] = { + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[T]) => reliable.getCheckpointDir + case _ => None + } + } // ======================================================================= // Other internal methods and fields @@ -1362,6 +1571,17 @@ abstract class RDD[T: ClassTag]( /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ @transient private[spark] val creationSite = sc.getCallSite() + /** + * The scope associated with the operation that created this RDD. + * + * This is more flexible than the call site and can be defined hierarchically. For more + * detail, see the documentation of {{RDDOperationScope}}. This scope is not defined if the + * user instantiates this RDD himself without using any Spark operations. + */ + @transient private[spark] val scope: Option[RDDOperationScope] = { + Option(sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)).map(RDDOperationScope.fromJson) + } + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] @@ -1369,7 +1589,7 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassTag] = { + protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } @@ -1379,7 +1599,7 @@ abstract class RDD[T: ClassTag]( } /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ - def context = sc + def context: SparkContext = sc /** * Private API for changing an RDD's ClassTag. @@ -1406,13 +1626,15 @@ abstract class RDD[T: ClassTag]( * has completed (therefore the RDD has been materialized and potentially stored in memory). * doCheckpoint() is called recursively on the parent RDDs. */ - private[spark] def doCheckpoint() { - if (!doCheckpointCalled) { - doCheckpointCalled = true - if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() - } else { - dependencies.foreach(_.rdd.doCheckpoint()) + private[spark] def doCheckpoint(): Unit = { + RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) { + if (!doCheckpointCalled) { + doCheckpointCalled = true + if (checkpointData.isDefined) { + checkpointData.get.checkpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } } } } @@ -1421,7 +1643,7 @@ abstract class RDD[T: ClassTag]( * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) * created from the checkpoint file, and forget its old dependencies and partitions. */ - private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { + private[spark] def markCheckpointed(): Unit = { clearDependencies() partitions_ = null deps = null // Forget the constructor argument for dependencies too @@ -1440,14 +1662,14 @@ abstract class RDD[T: ClassTag]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString: String = { // Get a debug description of an rdd without its children - def debugSelf (rdd: RDD[_]): Seq[String] = { + def debugSelf(rdd: RDD[_]): Seq[String] = { import Utils.bytesToString val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => - " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( + " CachedPartitions: %d; MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), - bytesToString(info.tachyonSize), bytesToString(info.diskSize))) + bytesToString(info.externalBlockStoreSize), bytesToString(info.diskSize))) s"$rdd [$persistence]" +: storageInfo } @@ -1459,22 +1681,22 @@ abstract class RDD[T: ClassTag]( case 0 => Seq.empty case 1 => val d = rdd.dependencies.head - debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]], true) case _ => val frontDeps = rdd.dependencies.take(len - 1) val frontDepStrings = frontDeps.flatMap( - d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]])) val lastDep = rdd.dependencies.last val lastDepStrings = - debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) (frontDepStrings ++ lastDepStrings) } } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1484,7 +1706,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( @@ -1497,10 +1719,11 @@ abstract class RDD[T: ClassTag]( case (desc: String, _) => s"$nextPrefix$desc" } ++ debugChildren(rdd, nextPrefix) } - def debugString(rdd: RDD[_], - prefix: String = "", - isShuffle: Boolean = true, - isLastChild: Boolean = false): Seq[String] = { + def debugString( + rdd: RDD[_], + prefix: String = "", + isShuffle: Boolean = true, + isLastChild: Boolean = false): Seq[String] = { if (isShuffle) { shuffleDebugString(rdd, prefix, isLastChild) } else { @@ -1527,7 +1750,7 @@ abstract class RDD[T: ClassTag]( */ object RDD { - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -1541,9 +1764,15 @@ object RDD { new AsyncRDDActions(rdd) } - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { - new SequenceFileRDDFunctions(rdd) + implicit def rddToSequenceFileRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], + keyWritableFactory: WritableFactory[K], + valueWritableFactory: WritableFactory[V]) + : SequenceFileRDDFunctions[K, V] = { + implicit val keyConverter = keyWritableFactory.convert + implicit val valueConverter = valueWritableFactory.convert + new SequenceFileRDDFunctions(rdd, + keyWritableFactory.writableClass(kt), valueWritableFactory.writableClass(vt)) } implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](rdd: RDD[(K, V)]) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f185797..0e43520870c0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -19,18 +19,15 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.hadoop.fs.Path - -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException} -import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} +import org.apache.spark.Partition /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + * [ Initialized --> checkpointing in progress --> checkpointed ]. */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value + val Initialized, CheckpointingInProgress, Checkpointed = Value } /** @@ -39,96 +36,76 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ -private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) - extends Logging with Serializable { +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) + extends Serializable { import CheckpointState._ // The checkpoint state of the associated RDD. - var cpState = Initialized - - // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None + protected var cpState = Initialized - // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None + // The RDD that contains our checkpointed data + private var cpRDD: Option[CheckpointRDD[T]] = None - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } + // TODO: are we sure we need to use a global lock in the following methods? - // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } + /** + * Return whether the checkpoint data for this RDD is already persisted. + */ + def isCheckpointed: Boolean = RDDCheckpointData.synchronized { + cpState == Checkpointed } - // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } - } - - // Do the checkpointing of the RDD. Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return + /** + * Materialize this RDD and persist its content. + * This is called immediately after the first action invoked on this RDD has completed. + */ + final def checkpoint(): Unit = { + // Guard against multiple threads checkpointing the same RDD by + // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { + if (cpState == Initialized) { cpState = CheckpointingInProgress } else { return } } - // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) - } + val newRDD = doCheckpoint() - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableWritable(rdd.context.hadoopConfiguration)) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) - val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != rdd.partitions.size) { - throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") - } - - // Change the dependencies and partitions of the RDD + // Update our state and truncate the RDD lineage RDDCheckpointData.synchronized { - cpFile = Some(path.toString) cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed + rdd.markCheckpointed() } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } - } - - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } + /** + * Materialize this RDD and persist its content. + * + * Subclasses should override this method to define custom checkpointing behavior. + * @return the checkpoint RDD created in the process. + */ + protected def doCheckpoint(): CheckpointRDD[T] + + /** + * Return the RDD that contains our checkpointed data. + * This is only defined if the checkpoint state is `Checkpointed`. + */ + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD } + + /** + * Return the partitions of the resulting checkpoint RDD. + * For tests only. + */ + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.map(_.partitions).getOrElse { Array.empty } } - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } - } } -// Used for synchronization +/** + * Global lock for synchronizing checkpoint operations. + */ private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala new file mode 100644 index 000000000000..44667281c106 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.concurrent.atomic.AtomicInteger + +import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude, JsonPropertyOrder} +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.{Logging, SparkContext} + +/** + * A general, named code block representing an operation that instantiates RDDs. + * + * All RDDs instantiated in the corresponding code block will store a pointer to this object. + * Examples include, but will not be limited to, existing RDD operations, such as textFile, + * reduceByKey, and treeAggregate. + * + * An operation scope may be nested in other scopes. For instance, a SQL query may enclose + * scopes associated with the public RDD APIs it uses under the hood. + * + * There is no particular relationship between an operation scope and a stage or a job. + * A scope may live inside one stage (e.g. map) or span across multiple jobs (e.g. take). + */ +@JsonInclude(Include.NON_NULL) +@JsonPropertyOrder(Array("id", "name", "parent")) +private[spark] class RDDOperationScope( + val name: String, + val parent: Option[RDDOperationScope] = None, + val id: String = RDDOperationScope.nextScopeId().toString) { + + def toJson: String = { + RDDOperationScope.jsonMapper.writeValueAsString(this) + } + + /** + * Return a list of scopes that this scope is a part of, including this scope itself. + * The result is ordered from the outermost scope (eldest ancestor) to this scope. + */ + @JsonIgnore + def getAllScopes: Seq[RDDOperationScope] = { + parent.map(_.getAllScopes).getOrElse(Seq.empty) ++ Seq(this) + } + + override def equals(other: Any): Boolean = { + other match { + case s: RDDOperationScope => + id == s.id && name == s.name && parent == s.parent + case _ => false + } + } + + override def toString: String = toJson +} + +/** + * A collection of utility methods to construct a hierarchical representation of RDD scopes. + * An RDD scope tracks the series of operations that created a given RDD. + */ +private[spark] object RDDOperationScope extends Logging { + private val jsonMapper = new ObjectMapper().registerModule(DefaultScalaModule) + private val scopeCounter = new AtomicInteger(0) + + def fromJson(s: String): RDDOperationScope = { + jsonMapper.readValue(s, classOf[RDDOperationScope]) + } + + /** Return a globally unique operation scope ID. */ + def nextScopeId(): Int = scopeCounter.getAndIncrement + + /** + * Execute the given body such that all RDDs created in this body will have the same scope. + * The name of the scope will be the first method name in the stack trace that is not the + * same as this method's. + * + * Note: Return statements are NOT allowed in body. + */ + private[spark] def withScope[T]( + sc: SparkContext, + allowNesting: Boolean = false)(body: => T): T = { + val ourMethodName = "withScope" + val callerMethodName = Thread.currentThread.getStackTrace() + .dropWhile(_.getMethodName != ourMethodName) + .find(_.getMethodName != ourMethodName) + .map(_.getMethodName) + .getOrElse { + // Log a warning just in case, but this should almost certainly never happen + logWarning("No valid method name for this RDD operation scope!") + "N/A" + } + withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body) + } + + /** + * Execute the given body such that all RDDs created in this body will have the same scope. + * + * If nesting is allowed, any subsequent calls to this method in the given body will instantiate + * child scopes that are nested within our scope. Otherwise, these calls will take no effect. + * + * Additionally, the caller of this method may optionally ignore the configurations and scopes + * set by the higher level caller. In this case, this method will ignore the parent caller's + * intention to disallow nesting, and the new scope instantiated will not have a parent. This + * is useful for scoping physical operations in Spark SQL, for instance. + * + * Note: Return statements are NOT allowed in body. + */ + private[spark] def withScope[T]( + sc: SparkContext, + name: String, + allowNesting: Boolean, + ignoreParent: Boolean)(body: => T): T = { + // Save the old scope to restore it later + val scopeKey = SparkContext.RDD_SCOPE_KEY + val noOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY + val oldScopeJson = sc.getLocalProperty(scopeKey) + val oldScope = Option(oldScopeJson).map(RDDOperationScope.fromJson) + val oldNoOverride = sc.getLocalProperty(noOverrideKey) + try { + if (ignoreParent) { + // Ignore all parent settings and scopes and start afresh with our own root scope + sc.setLocalProperty(scopeKey, new RDDOperationScope(name).toJson) + } else if (sc.getLocalProperty(noOverrideKey) == null) { + // Otherwise, set the scope only if the higher level caller allows us to do so + sc.setLocalProperty(scopeKey, new RDDOperationScope(name, oldScope).toJson) + } + // Optionally disallow the child body to override our scope + if (!allowNesting) { + sc.setLocalProperty(noOverrideKey, "true") + } + body + } finally { + // Remember to restore any state that was modified before exiting + sc.setLocalProperty(scopeKey, oldScopeJson) + sc.setLocalProperty(noOverrideKey, oldNoOverride) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala new file mode 100644 index 000000000000..35d8b0bfd18c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.IOException + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * An RDD that reads from checkpoint files previously written to reliable storage. + */ +private[spark] class ReliableCheckpointRDD[T: ClassTag]( + @transient sc: SparkContext, + val checkpointPath: String) + extends CheckpointRDD[T](sc) { + + @transient private val hadoopConf = sc.hadoopConfiguration + @transient private val cpath = new Path(checkpointPath) + @transient private val fs = cpath.getFileSystem(hadoopConf) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(hadoopConf)) + + // Fail fast if checkpoint directory does not exist + require(fs.exists(cpath), s"Checkpoint directory does not exist: $checkpointPath") + + /** + * Return the path of the checkpoint directory this RDD reads data from. + */ + override def getCheckpointFile: Option[String] = Some(checkpointPath) + + /** + * Return partitions described by the files in the checkpoint directory. + * + * Since the original RDD may belong to a prior application, there is no way to know a + * priori the number of partitions to expect. This method assumes that the original set of + * checkpoint files are fully preserved in a reliable storage across application lifespans. + */ + protected override def getPartitions: Array[Partition] = { + // listStatus can throw exception if path does not exist. + val inputFiles = fs.listStatus(cpath) + .map(_.getPath) + .filter(_.getName.startsWith("part-")) + .sortBy(_.toString) + // Fail fast if input files are invalid + inputFiles.zipWithIndex.foreach { case (path, i) => + if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) { + throw new SparkException(s"Invalid checkpoint file: $path") + } + } + Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i)) + } + + /** + * Return the locations of the checkpoint file associated with the given partition. + */ + protected override def getPreferredLocations(split: Partition): Seq[String] = { + val status = fs.getFileStatus( + new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + /** + * Read the content of the checkpoint file associated with the given partition. + */ + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)) + ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context) + } + +} + +private[spark] object ReliableCheckpointRDD extends Logging { + + /** + * Return the checkpoint file name for the given partition. + */ + private def checkpointFileName(partitionIndex: Int): String = { + "part-%05d".format(partitionIndex) + } + + /** + * Write this partition's values to a checkpoint file. + */ + def writeCheckpointFile[T: ClassTag]( + path: String, + broadcastedConf: Broadcast[SerializableConfiguration], + blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { + val env = SparkEnv.get + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(broadcastedConf.value.value) + + val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId()) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = + new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}") + + if (fs.exists(tempOutputPath)) { + throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists") + } + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = env.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + logInfo(s"Deleting tempOutputPath $tempOutputPath") + fs.delete(tempOutputPath, false) + throw new IOException("Checkpoint failed: failed to save output of task: " + + s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo(s"Final output path $finalOutputPath already exists; not overwriting it") + fs.delete(tempOutputPath, false) + } + } + } + + /** + * Read the content of the specified checkpoint file. + */ + def readCheckpointFile[T]( + path: Path, + broadcastedConf: Broadcast[SerializableConfiguration], + context: TaskContext): Iterator[T] = { + val env = SparkEnv.get + val fs = path.getFileSystem(broadcastedConf.value.value) + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + val fileInputStream = fs.open(path, bufferSize) + val serializer = env.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => deserializeStream.close()) + + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala new file mode 100644 index 000000000000..1df8eef5ff2b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.util.SerializableConfiguration + +/** + * An implementation of checkpointing that writes the RDD data to reliable storage. + * This allows drivers to be restarted on failure with previously computed state. + */ +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) + extends RDDCheckpointData[T](rdd) with Logging { + + // The directory to which the associated RDD has been checkpointed to + // This is assumed to be a non-local path that points to some reliable storage + private val cpDir: String = + ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id) + .map(_.toString) + .getOrElse { throw new SparkException("Checkpoint dir must be specified.") } + + /** + * Return the directory to which this RDD was checkpointed. + * If the RDD is not checkpointed yet, return None. + */ + def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized { + if (isCheckpointed) { + Some(cpDir.toString) + } else { + None + } + } + + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. + */ + protected override def doCheckpoint(): CheckpointRDD[T] = { + + // Create the output path for the checkpoint + val path = new Path(cpDir) + val fs = path.getFileSystem(rdd.context.hadoopConfiguration) + if (!fs.mkdirs(path)) { + throw new SparkException(s"Failed to create checkpoint path $cpDir") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = rdd.context.broadcast( + new SerializableConfiguration(rdd.context.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) + val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) + if (newRDD.partitions.length != rdd.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $rdd(${rdd.partitions.length})") + } + + // Optionally clean our checkpoint files if the reference is out of scope + if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { + rdd.context.cleaner.foreach { cleaner => + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + } + } + + logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") + + newRDD + } + +} + +private[spark] object ReliableRDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ + def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = { + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } + } + + /** Clean up the files associated with the checkpoint data for this RDD. */ + def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = { + checkpointPath(sc, rddId).foreach { path => + val fs = path.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(path)) { + fs.delete(path, true) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 2b4891695143..4b5f15dd06b8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -30,13 +30,35 @@ import org.apache.spark.Logging * through an implicit conversion. Note that this can't be part of PairRDDFunctions because * we need more implicit parameters to convert our keys and values to Writable. * - * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( - self: RDD[(K, V)]) + self: RDD[(K, V)], + _keyWritableClass: Class[_ <: Writable], + _valueWritableClass: Class[_ <: Writable]) extends Logging with Serializable { + @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") + def this(self: RDD[(K, V)]) { + this(self, null, null) + } + + private val keyWritableClass = + if (_keyWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[K]() + } else { + _keyWritableClass + } + + private val valueWritableClass = + if (_valueWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[V]() + } else { + _valueWritableClass + } + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { val c = { if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { @@ -55,6 +77,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag c.asInstanceOf[Class[_ <: Writable]] } + /** * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key * and value types. If the key or value are Writable, then we use their classes directly; @@ -62,29 +85,33 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported * file system. */ - def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { + def saveAsSequenceFile( + path: String, + codec: Option[Class[_ <: CompressionCodec]] = None): Unit = self.withScope { def anyToWritable[U <% Writable](u: U): Writable = u - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass) + // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and + // valueWritableClass at the compile time. To implement that, we need to add type parameters to + // SequenceFileRDDFunctions. however, SequenceFileRDDFunctions is a public class so it will be a + // breaking change. + val convertKey = self.keyClass != keyWritableClass + val convertValue = self.valueClass != valueWritableClass - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + - valueClass.getSimpleName + ")" ) + logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," + + valueWritableClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) + self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index d9fe6847254f..2dc47f95937c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.rdd -import scala.reflect.ClassTag - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx + override val index: Int = idx override def hashCode(): Int = idx } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala new file mode 100644 index 000000000000..fa3fecc80cb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.text.SimpleDateFormat +import java.util.Date + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.DataReadMethod +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + + +private[spark] class SqlNewHadoopPartition( + rddId: Int, + val index: Int, + @transient rawSplit: InputSplit with Writable) + extends SparkPartition { + + val serializableHadoopSplit = new SerializableWritable(rawSplit) + + override def hashCode(): Int = 41 * (41 + rddId) + index +} + +/** + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). + * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. + * 1. A shared broadcast Hadoop Configuration. + * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side + * to the shared Hadoop Configuration. + * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side + * and the executor side to the shared Hadoop Configuration. + * + * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. + */ +private[spark] class SqlNewHadoopRDD[V: ClassTag]( + @transient sc : SparkContext, + broadcastedConf: Broadcast[SerializableConfiguration], + @transient initDriverSideJobFuncOpt: Option[Job => Unit], + initLocalJobFuncOpt: Option[Job => Unit], + inputFormatClass: Class[_ <: InputFormat[Void, V]], + valueClass: Class[V]) + extends RDD[V](sc, Nil) + with SparkHadoopMapReduceUtil + with Logging { + + protected def getJob(): Job = { + val conf: Configuration = broadcastedConf.value.value + // "new Job" will make a copy of the conf. Then, it is + // safe to mutate conf properties with initLocalJobFuncOpt + // and initDriverSideJobFuncOpt. + val newJob = new Job(conf) + initLocalJobFuncOpt.map(f => f(newJob)) + newJob + } + + def getConf(isDriverSide: Boolean): Configuration = { + val job = getJob() + if (isDriverSide) { + initDriverSideJobFuncOpt.map(f => f(job)) + } + job.getConfiguration + } + + private val jobTrackerId: String = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + formatter.format(new Date()) + } + + @transient protected val jobId = new JobID(jobTrackerId, id) + + override def getPartitions: Array[SparkPartition] = { + val conf = getConf(isDriverSide = true) + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[SparkPartition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } + + override def compute( + theSplit: SparkPartition, + context: TaskContext): Iterator[V] = { + val iter = new Iterator[V] { + val split = theSplit.asInstanceOf[SqlNewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = getConf(isDriverSide = false) + + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + } + inputMetrics.setBytesReadCallback(bytesReadCallback) + + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + private[this] var reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => close()) + + private[this] var havePair = false + private[this] var finished = false + + override def hasNext: Boolean = { + if (context.isInterrupted) { + throw new TaskKilledException + } + if (!finished && !havePair) { + finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } + havePair = !finished + } + !finished + } + + override def next(): V = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + if (!finished) { + inputMetrics.incRecordsRead(1) + } + reader.getCurrentValue + } + + private def close() { + try { + if (reader != null) { + reader.close() + reader = null + + SqlNewHadoopRDD.unsetInputFileName() + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } + } + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } + } + } + iter + } + + override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { + val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value + val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e : Exception => + logDebug("Failed to use InputSplit#getLocationInfo.", e) + None + } + case None => None + } + locs.getOrElse(split.getLocations.filter(_ != "localhost")) + } + + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } +} + +private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. + */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[SparkPartition] = firstParent[T].partitions + + override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { + val partition = split.asInstanceOf[SqlNewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index ed24ea22a661..9a4fa301b06e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -76,14 +76,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -105,21 +105,27 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) + def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + dependencies(depNum) match { + case oneToOneDependency: OneToOneDependency[_] => + val dependencyPartition = partition.narrowDeps(depNum).get.split + oneToOneDependency.rdd.iterator(dependencyPartition, context) + .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(handle) => - val iter = SparkEnv.get.shuffleManager - .getReader(handle, partition.index, partition.index + 1, context) - .read() - iter.foreach(op) + case shuffleDependency: ShuffleDependency[_, _, _] => + val iter = SparkEnv.get.shuffleManager + .getReader( + shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + .read() + iter.foreach(op) + } } + // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) + integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + integrate(1, t => map.remove(t._1)) + map.asScala.iterator.map(t => t._2.iterator.map((t._1, _))).flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index aece683ff319..3986645350a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -44,7 +44,7 @@ private[spark] class UnionPartition[T: ClassTag]( var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(parentPartition) + def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition) override val index: Int = idx @@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 95b2dd954e9f..81f40ad33aa5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -32,7 +32,7 @@ private[spark] class ZippedPartitionsPartition( override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues + def partitions: Seq[Partition] = partitionValues @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -52,8 +52,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => @@ -123,7 +123,7 @@ private[spark] class ZippedPartitionsRDD3 } private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f..e277ae28d588 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { @@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L prev.context.runJob( prev, Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - allowLocal = false + 0 until n - 1 // do not need to count the last partition ).scanLeft(0L)(_ + _) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala new file mode 100644 index 000000000000..3e5b64265e91 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala new file mode 100644 index 000000000000..dfcbc51cdf61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint + + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The life-cycle of an endpoint is: + * + * constructor -> onStart -> receive* -> onStop + * + * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use + * [[ThreadSafeRpcEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. And `self` will become `null` when `onStop` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Invoked when any exception is thrown during handling messages. + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(_self) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala new file mode 100644 index 000000000000..7409ac885999 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import scala.concurrent.Future +import scala.reflect.ClassTag + +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SparkException, Logging, SparkConf} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. + */ + def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = ask[T](message, timeout) + val result = timeout.awaitResult(future) + if (result == null) { + throw new SparkException("RpcEndpoint returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + + if (attempts < maxRetries) { + Thread.sleep(retryWaitMs) + } + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 000000000000..29debe808130 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.net.URI +import java.util.concurrent.TimeoutException + +import scala.concurrent.{Awaitable, Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.{RpcUtils, Utils} + + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + + +/** + * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to + * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by + * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the + * sender, or logging them if no such sender or `NotSerializableException`. + * + * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. + */ +private[spark] abstract class RpcEnv(conf: SparkConf) { + + private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) + + /** + * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. + */ + private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Return the address that [[RpcEnv]] is listening to. + */ + def address: RpcAddress + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not + * guarantee thread-safety. + */ + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. + */ + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. + */ + def setupEndpointRefByURI(uri: String): RpcEndpointRef = { + defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + * asynchronously. + */ + def asyncSetupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { + asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * This is a blocking action. + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[shutdown()]]. + */ + def shutdown(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ + def awaitTermination(): Unit + + /** + * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of + * creating it manually because different [[RpcEnv]] may have different formats. + */ + def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. + */ + def deserialize[T](deserializationAction: () => T): T +} + + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + + +/** + * Represents a host and port. + */ +private[spark] case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort +} + + +private[spark] object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURI(uri: URI): RpcAddress = { + RpcAddress(uri.getHost, uri.getPort) + } + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + fromURI(new java.net.URI(uri)) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 000000000000..fc17542abf81 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.Future +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.event.Logging.Error +import akka.pattern.{ask => akkaAsk} +import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import akka.serialization.JavaSerializer + +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private[akka] ( + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + extends RpcEnv(conf) with Logging { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + // In some test case, ActorSystem doesn't bind to any address. + // So just use some default value since they are only some unit tests + RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) + } + + override val address: RpcAddress = defaultAddress + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + @volatile var endpointRef: AkkaRpcEndpointRef = null + // Use lazy because the Actor needs to use `endpointRef`. + // So `actorRef` should be created after assigning `endpointRef`. + lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + assert(endpointRef != null) + + override def preStart(): Unit = { + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + safelyCall(endpoint) { + endpoint.onStart() + } + } + + override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case DisassociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } + + case e: AssociationEvent => + // TODO ignore? + + case m: AkkaMessage => + logDebug(s"Received RPC message: $m") + safelyCall(endpoint) { + processMessage(endpoint, m, sender) + } + + case AkkaFailure(e) => + safelyCall(endpoint) { + throw e + } + + case message: Any => { + logWarning(s"Unknown message: $message") + } + + } + + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() + } + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + registerEndpoint(endpoint, endpointRef) + // Now actorRef can be created safely + endpointRef.init() + endpointRef + } + + private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { + val message = m.message + val needReply = m.needReply + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + _sender ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + _sender ! AkkaMessage(response, false) + } + + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + }) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](message, { message => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + } catch { + case NonFatal(e) => + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. + throw e + } + } + } + + /** + * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will + * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. + */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + import actorSystem.dispatcher + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). + // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { + AkkaUtils.address( + AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) + } + + override def shutdown(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } +} + +private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } +} + +/** + * Monitor errors reported by Akka and log them. + */ +private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { + + override def preStart(): Unit = { + context.system.eventStream.subscribe(self, classOf[Error]) + } + + override def receiveWithLogging: Actor.Receive = { + case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + } +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + @transient _actorRef: => ActorRef, + @transient conf: SparkConf, + @transient initInConstructor: Boolean = true) + extends RpcEndpointRef(conf) with Logging { + + lazy val actorRef = _actorRef + + override lazy val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override lazy val name: String = actorRef.path.name + + private[akka] def init(): Unit = { + // Initialize the lazy vals + actorRef + address + name + } + + if (initInConstructor) { + init() + } + + override def send(message: Any): Unit = { + actorRef ! AkkaMessage(message, false) + } + + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { + // The function will run in the calling thread, so it should be short and never block. + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }(ThreadUtils.sameThread).mapTo[T]. + recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" + + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() +} + +/** + * A wrapper to `message` so that the receiver knows if the sender expects a reply. + * @param message + * @param needReply if the sender expects a reply message + */ +private[akka] case class AkkaMessage(message: Any, needReply: Boolean) + +/** + * A reply with the failure error from the receiver to the sender + */ +private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index fa83372bb4d1..11d123eec43c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. */ @DeveloperApi -class AccumulableInfo ( +class AccumulableInfo private[spark] ( val id: Long, val name: String, val update: Option[String], // represents a partial update within a task - val value: String) { + val value: String, + val internal: Boolean) { override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => @@ -39,8 +40,11 @@ class AccumulableInfo ( } object AccumulableInfo { - def apply(id: Long, name: String, update: Option[String], value: String) = - new AccumulableInfo(id, name, update, value) + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal = false) + } - def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) + def apply(id: Long, name: String, value: String): AccumulableInfo = { + new AccumulableInfo(id, name, None, value, internal = false) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index b755d8fb1575..50a69379412d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.CallSite */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: Stage, + val finalStage: ResultStage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: CallSite, diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 6d39a5e3fa64..9f218c64cac2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -26,6 +26,7 @@ package org.apache.spark.scheduler private[spark] class ApplicationEventListener extends SparkListener { var appName: Option[String] = None var appId: Option[String] = None + var appAttemptId: Option[String] = None var sparkUser: Option[String] = None var startTime: Option[Long] = None var endTime: Option[Long] = None @@ -35,6 +36,7 @@ private[spark] class ApplicationEventListener extends SparkListener { override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { appName = Some(applicationStart.appName) appId = applicationStart.appId + appAttemptId = applicationStart.appAttemptId startTime = Some(applicationStart.time) sparkUser = Some(applicationStart.sparkUser) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 1cfe98673773..daf9b0f95273 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,26 +19,26 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await +import scala.collection.Map +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps -import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout +import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils} +import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** @@ -54,6 +54,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -63,7 +67,7 @@ class DAGScheduler( mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Logging { def this(sc: SparkContext, taskScheduler: TaskScheduler) = { @@ -78,13 +82,15 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] // Stages we need to run whose parents aren't done @@ -98,8 +104,14 @@ class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] - // Contains the locations that each RDD's partitions are cached on - private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] + /** + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with // every task. When we detect a node failing, we note the current epoch number and failed @@ -109,41 +121,62 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - - /** If enabled, we may run certain actions like take() and first() locally. */ - private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) - /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) private val messageScheduler = - Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) + ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - // Called by TaskScheduler to report task's starting. + // Flag to control if reduce tasks are assigned preferred locations + private val shuffleLocalityEnabled = + sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + // Number of map, reduce tasks above which we do not assign preferred locations + // based on map output sizes. We limit the size of jobs for which assign preferred locations + // as computing the top locations by size becomes expensive. + private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. + // Making this larger will focus on fewer locations where most data can be read locally, but + // may lead to more delay in scheduling if those locations are busy. + private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 + + /** + * Called by the TaskSetManager to report task's starting. + */ def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) } - // Called to report that a task has completed and results are being fetched remotely. + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ def taskGettingResult(taskInfo: TaskInfo) { eventProcessLoop.post(GettingResultEvent(taskInfo)) } - // Called by TaskScheduler to report task completions or failures. + /** + * Called by the TaskSetManager to report task completions or failures. + */ def taskEnded( task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + taskMetrics: TaskMetrics): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } @@ -158,83 +191,115 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } - // Called by TaskScheduler when an executor fails. - def executorLost(execId: String) { + /** + * Called by TaskScheduler implementation when an executor fails. + */ + def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } - // Called by TaskScheduler when a host is added - def executorAdded(execId: String, host: String) { + /** + * Called by TaskScheduler implementation when a host is added. + */ + def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or - // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String) { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } - private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { + private[scheduler] + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) - cacheLocs(rdd.id) = blockIds.map { id => - locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } + cacheLocs(rdd.id) = locs } cacheLocs(rdd.id) } - private def clearCacheLocs() { + private def clearCacheLocs(): Unit = cacheLocs.synchronized { cacheLocs.clear() } /** * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { + private def getShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + firstJobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, jobId) + registerShuffleDependencies(shuffleDep, firstJobId) // Then register current shuffleDep - val stage = - newOrUsedStage( - shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, - shuffleDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - + stage } } /** - * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation - * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage - * directly. + * Helper function to eliminate some code re-use when creating new stages. + */ + private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, firstJobId) + val id = nextStageId.getAndIncrement() + (parentStages, id) + } + + /** + * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in + * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId. + * Production of shuffle map stages should always use newOrUsedShuffleStage, not + * newShuffleMapStage directly. + */ + private def newShuffleMapStage( + rdd: RDD[_], + numTasks: Int, + shuffleDep: ShuffleDependency[_, _, _], + firstJobId: Int, + callSite: CallSite): ShuffleMapStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) + val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, + firstJobId, callSite, shuffleDep) + + stageIdToStage(id) = stage + updateJobIdStageIdMaps(firstJobId, stage) + stage + } + + /** + * Create a ResultStage associated with the provided jobId. */ - private def newStage( + private def newResultStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, - callSite: CallSite) - : Stage = - { - val parentStages = getParentStages(rdd, jobId) - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) + callSite: CallSite): ResultStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) + stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -242,40 +307,37 @@ class DAGScheduler( /** * Create a shuffle map Stage for the given RDD. The stage will also be associated with the - * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ - private def newOrUsedStage( - rdd: RDD[_], - numTasks: Int, + private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, - callSite: CallSite) - : Stage = - { - val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + firstJobId: Int): ShuffleMapStage = { + val rdd = shuffleDep.rdd + val numTasks = rdd.partitions.length + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + for (i <- 0 until locs.length) { + stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) } else { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) } stage } /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -289,7 +351,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, jobId) + parents += getShuffleMapStage(shufDep, firstJobId) case _ => waitingForVisit.push(dep.rdd) } @@ -297,26 +359,23 @@ class DAGScheduler( } } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents.toList } - // Find ancestor missing shuffle dependencies and register into shuffleToMapStage - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (!parentsWithNoMapStage.isEmpty) { + while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = - newOrUsedStage( - currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, - currentShufDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } - // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] @@ -342,7 +401,7 @@ class DAGScheduler( } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents @@ -357,11 +416,12 @@ class DAGScheduler( def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { missing += mapStage } @@ -373,7 +433,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } missing.toList @@ -383,7 +443,7 @@ class DAGScheduler( * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. */ - private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { def updateJobIdStageIdMapsList(stages: List[Stage]) { if (stages.nonEmpty) { val s = stages.head @@ -403,7 +463,7 @@ class DAGScheduler( * * @param job The job whose state to cleanup. */ - private def cleanupStateForJobAndIndependentStages(job: ActiveJob) { + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { logError("No stages registered for job " + job.jobId) @@ -463,10 +523,8 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = - { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -484,29 +542,30 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)) + jobId, rdd, func2, partitions.toArray, callSite, waiter, + SerializationUtils.clone(properties))) waiter } - def runJob[T, U: ClassTag]( + def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - { + properties: Properties): Unit = { val start = System.nanoTime - val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => { + case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - } case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. + val callerStackTrace = Thread.currentThread().getStackTrace.tail + exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) throw exception } } @@ -517,27 +576,25 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray + val partitions = (0 until rdd.partitions.length).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int) { + def cancelJob(jobId: Int): Unit = { logInfo("Asked to cancel job " + jobId) eventProcessLoop.post(JobCancelled(jobId)) } - def cancelJobGroup(groupId: String) { + def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) } @@ -545,13 +602,13 @@ class DAGScheduler( /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs() { + def cancelAllJobs(): Unit = { eventProcessLoop.post(AllJobsCancelled) } private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation(_, + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, reason = "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... @@ -577,7 +634,7 @@ class DAGScheduler( clearCacheLocs() val failedStagesCopy = failedStages.toArray failedStages.clear() - for (stage <- failedStagesCopy.sortBy(_.jobId)) { + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -597,61 +654,11 @@ class DAGScheduler( logTrace("failed: " + failedStages) val waitingStagesCopy = waitingStages.toArray waitingStages.clear() - for (stage <- waitingStagesCopy.sortBy(_.jobId)) { + for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. - */ - protected def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.jobId) { - override def run() { - runLocallyWithinThread(job) - } - }.start() - } - - // Broken out for easier testing in DAGSchedulerSuite. - protected def runLocallyWithinThread(job: ActiveJob) { - var jobResult: JobResult = JobSucceeded - try { - val rdd = job.finalStage.rdd - val split = rdd.partitions(job.partitions(0)) - val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, - attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) - try { - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - job.listener.taskSucceeded(0, result) - } finally { - taskContext.markTaskCompleted() - TaskContextHelper.unset() - } - } catch { - case e: Exception => - val exception = new SparkDriverExecutionException(e) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - case oom: OutOfMemoryError => - val exception = new SparkException("Local job aborted due to out of memory error", oom) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - } finally { - val s = job.finalStage - // clean up data structures that were populated for a local job, - // but that won't get cleaned up via the normal paths through - // completion events or stage abort - stageIdToStage -= s.id - jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult)) - } - } - /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). @@ -665,8 +672,11 @@ class DAGScheduler( private[scheduler] def handleJobGroupCancelled(groupId: String) { // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val activeInGroup = activeJobs.filter { activeJob => + Option(activeJob.properties).exists { + _.getProperty(SparkContext.SPARK_JOB_GROUP_ID) == groupId + } + } val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -680,8 +690,11 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } @@ -693,11 +706,12 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -710,16 +724,14 @@ class DAGScheduler( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) - { - var finalStage: Stage = null + properties: Properties) { + var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -729,29 +741,20 @@ class DAGScheduler( if (finalStage != null) { val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.shortForm, partitions.length, allowLocal)) + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val shouldRunLocally = - localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 - val jobSubmissionTime = clock.getTime() - if (shouldRunLocally) { - // Compute very short actions like first() or take() with no parent stages locally. - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) - runLocally(job) - } else { - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) - } + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) } submitWaitingStages() } @@ -764,7 +767,7 @@ class DAGScheduler( if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) - if (missing == Nil) { + if (missing.isEmpty) { logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage, jobId.get) } else { @@ -775,7 +778,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -786,28 +789,56 @@ class DAGScheduler( stage.pendingTasks.clear() // First figure out the indexes of partition ids to compute. - val partitionsToCompute: Seq[Int] = { - if (stage.isShuffleMap) { - (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) - } else { - val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { + stage match { + case stage: ShuffleMapStage => + val allPartitions = 0 until stage.numPartitions + val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty } + (allPartitions, filteredPartitions) + case stage: ResultStage => + val job = stage.resultOfJob.get + val allPartitions = 0 until job.numPartitions + val filteredPartitions = allPartitions.filter { id => !job.finished(id) } + (allPartitions, filteredPartitions) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null + // Create internal accumulators if the stage has no accumulators initialized. + // Reset internal accumulators only if this stage is not partially submitted + // Otherwise, we may override existing accumulator values from some tasks + if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) { + stage.resetInternalAccumulators() } + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + outputCommitCoordinator.stageStart(stage.id) + val taskIdToLocations = try { + stage match { + case s: ShuffleMapStage => + partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap + case s: ResultStage => + val job = s.resultOfJob.get + partitionsToCompute.map { id => + val p = job.partitions(id) + (id, getPreferredLocs(stage.rdd, p)) + }.toMap + } + } catch { + case NonFatal(e) => + stage.makeNewStageAttempt(partitionsToCompute.size) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + runningStages -= stage + return + } + + stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -820,55 +851,77 @@ class DAGScheduler( try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = - if (stage.isShuffleMap) { - closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() - } else { - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() - } + val taskBinaryBytes: Array[Byte] = stage match { + case stage: ShuffleMapStage => + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + case stage: ResultStage => + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage + + // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } - val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } - } else { - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = taskIdToLocations(id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, stage.internalAccumulators) + } + + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = taskIdToLocations(id) + new ResultTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, id, stage.internalAccumulators) + } } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + runningStages -= stage + return } if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) - taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stage.latestInfo.submissionTime = Some(clock.getTime()) + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - runningStages -= stage + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) + + val debugString = stage match { + case stage: ShuffleMapStage => + s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})" + case stage : ResultStage => + s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + } + logDebug(debugString) } } @@ -879,16 +932,24 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // In this instance, although the reference in Accumulators.originals is a WeakRef, + // it's guaranteed to exist since the event.accumUpdates Map exists + + val acc = Accumulators.originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => throw new NullPointerException("Non-existent reference to Accumulator") + } + // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + val value = s"${acc.value}" + stage.latestInfo.accumulables(id) = + new AccumulableInfo(id, name, None, value, acc.isInternal) event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) } } } catch { @@ -909,10 +970,13 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) + outputCommitCoordinator.taskCompleted(stageId, task.partitionId, + event.taskInfo.attempt, event.reason) + // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { - val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val attemptId = task.stageAttemptId listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) } @@ -921,23 +985,8 @@ class DAGScheduler( // Skip all the actions if the stage has been cancelled. return } - val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTime()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } + val stage = stageIdToStage(task.stageId) event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -945,7 +994,10 @@ class DAGScheduler( stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => - stage.resultOfJob match { + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -953,10 +1005,10 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - markStageAsFinished(stage) + markStageAsFinished(resultStage) cleanupStateForJobAndIndependentStages(job) listenerBus.post( - SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded)) + SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) } // taskSucceeded runs some user code that might throw an exception. Make sure @@ -965,7 +1017,7 @@ class DAGScheduler( job.listener.taskSucceeded(rt.outputId, event.result) } catch { case e: Exception => - // TODO: Perhaps we want to mark the stage as failed? + // TODO: Perhaps we want to mark the resultStage as failed? job.listener.jobFailed(new SparkDriverExecutionException(e)) } } @@ -974,57 +1026,63 @@ class DAGScheduler( } case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { - stage.addOutputLoc(smt.partitionId, status) + shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) { - markStageAsFinished(stage) + + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - if (stage.shuffleDep.isDefined) { - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, - changeEpoch = true) - } + + // We supply true to increment the epoch number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // epoch incremented to refetch them. + // TODO: Only increment the epoch number if this is not the first time + // we registered these map outputs. + mapOutputTracker.registerMapOutputs( + shuffleStage.shuffleDep.shuffleId, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head), + changeEpoch = true) + clearCacheLocs() - if (stage.outputLocs.exists(_ == Nil)) { - // Some tasks had failed; let's resubmit this stage + if (shuffleStage.outputLocs.contains(Nil)) { + // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " (" + stage.name + + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) + .map(_._2).mkString(", ")) + submitStage(shuffleStage) } else { val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waitingStages) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) + for (shuffleStage <- waitingStages) { + logInfo("Missing parents for " + shuffleStage + ": " + + getMissingParentStages(shuffleStage)) } - for (stage <- waitingStages if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage + for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) + { + newlyRunnable += shuffleStage } waitingStages --= newlyRunnable runningStages ++= newlyRunnable for { - stage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(stage) + shuffleStage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(shuffleStage) } { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage, jobId) + logInfo("Submitting " + shuffleStage + " (" + + shuffleStage.rdd + "), which is now runnable") + submitMissingTasks(shuffleStage, jobId) } } } @@ -1038,42 +1096,55 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is possible - // the fetch failure has already been handled by the scheduler. - if (runningStages.contains(failedStage)) { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + - s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage - } + if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + } else { - if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) - } - failedStages += failedStage - failedStages += mapStage - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is + // possible the fetch failure has already been handled by the scheduler. + if (runningStages.contains(failedStage)) { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, Some(failureMessage)) + } else { + logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + + s"longer running") + } - // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) + } else if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. + // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage + // Mark the map whose fetch failed as broken in the map stage + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + } } - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case commitDenied: TaskCommitDenied => + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1111,7 +1182,7 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head) mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { @@ -1158,31 +1229,55 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. return } val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq - failedStage.latestInfo.completionTime = Some(clock.getTime()) + failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") } } - /** - * Fails a job and all stages that are only used by that job, and cleans up relevant state. - */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1209,8 +1304,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) @@ -1224,19 +1318,16 @@ class DAGScheduler( if (ableToCancelStages) { job.listener.jobFailed(error) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } - /** - * Return true if one of stage's ancestors is target. - */ + /** Return true if one of stage's ancestors is target. */ private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true } val visitedRdds = new HashSet[RDD[_]] - val visitedStages = new HashSet[Stage] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting val waitingForVisit = new Stack[RDD[_]] @@ -1246,9 +1337,8 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { - visitedStages += mapStage waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => @@ -1258,47 +1348,55 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } visitedRdds.contains(target.rdd) } /** - * Synchronized method that might be called from other threads. + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. + * * @param rdd whose partitions are to be looked at * @param partition to lookup locality information for * @return list of machines that are preferred by the partition */ private[spark] - def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { getPreferredLocsInternal(rdd, partition, new HashSet) } - /** Recursive implementation for getPreferredLocs. */ + /** + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]) - : Seq[TaskLocation] = - { + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 - if (!visited.add((rdd,partition))) { + if (!visited.add((rdd, partition))) { // Nil has already been returned for previously visited partitions. return Nil } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { + if (cached.nonEmpty) { return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { + if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + + // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency // that has any placement preferences. Ideally we would choose based on transfer sizes, // but this will do for now. rdd.dependencies.foreach { @@ -1311,29 +1409,58 @@ class DAGScheduler( } case _ => } + + // If the RDD has shuffle dependencies and shuffle locality is enabled, pick locations that + // have at least REDUCER_PREF_LOCS_FRACTION of data as preferred locations + if (shuffleLocalityEnabled && rdd.partitions.length < SHUFFLE_PREF_REDUCE_THRESHOLD) { + rdd.dependencies.foreach { + case s: ShuffleDependency[_, _, _] => + if (s.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.length, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => + } + } Nil } def stop() { logInfo("Stopping DAGScheduler") + messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. */ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { - case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, - listener, properties) + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) @@ -1362,8 +1489,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() @@ -1379,7 +1506,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.sc.stop() } - override def onStop() { + override def onStop(): Unit = { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2b6f7e4205c3..f72a52e85dc1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.Map +import scala.collection.Map import scala.language.existentials import org.apache.spark._ @@ -40,7 +40,6 @@ private[scheduler] case class JobSubmitted( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties = null) @@ -74,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 12668b6c0988..6b667d5d7645 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -46,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 30075c172bdb..5a06ef02f5c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,32 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + appAttemptId : Option[String], + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = - this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) + def this(appId: String, appAttemptId : Option[String], logBaseDir: URI, sparkConf: SparkConf) = + this(appId, appAttemptId, logBaseDir, sparkConf, + SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) + private val compressionCodec = + if (shouldCompress) { + Some(CompressionCodec.createCodec(sparkConf)) + } else { + None + } + private val compressionCodecName = compressionCodec.map { c => + CompressionCodec.getShortName(c.getClass.getName) + } // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None @@ -80,13 +91,13 @@ private[spark] class EventLoggingListener( private[scheduler] val loggedEvents = new ArrayBuffer[JValue] // Visible for tests only. - private[scheduler] val logPath = getLogPath(logBaseDir, appId) + private[scheduler] val logPath = getLogPath(logBaseDir, appId, appAttemptId, compressionCodecName) /** * Creates the log file in the configured log directory. */ def start() { - if (!fileSystem.isDirectory(new Path(logBaseDir))) { + if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) { throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.") } @@ -111,25 +122,27 @@ private[spark] class EventLoggingListener( hadoopDataStream.get } - val compressionCodec = - if (shouldCompress) { - Some(CompressionCodec.createCodec(sparkConf)) - } else { - None - } - - fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) - val logStream = initEventLog(new BufferedOutputStream(dstream, outputBufferSize), - compressionCodec) - writer = Some(new PrintWriter(logStream)) + try { + val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) + val bstream = new BufferedOutputStream(cstream, outputBufferSize) - logInfo("Logging events to %s".format(logPath)) + EventLoggingListener.initEventLog(bstream) + fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) + writer = Some(new PrintWriter(bstream)) + logInfo("Logging events to %s".format(logPath)) + } catch { + case e: Exception => + dstream.close() + throw e + } } /** Log the event as JSON. */ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) @@ -140,47 +153,63 @@ private[spark] class EventLoggingListener( } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted) = - logEvent(event) - override def onTaskStart(event: SparkListenerTaskStart) = - logEvent(event) - override def onTaskGettingResult(event: SparkListenerTaskGettingResult) = - logEvent(event) - override def onTaskEnd(event: SparkListenerTaskEnd) = - logEvent(event) - override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate) = - logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + + override def onTaskStart(event: SparkListenerTaskStart): Unit = logEvent(event) + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = logEvent(event) + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event) + + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { logEvent(event, flushLogger = true) - override def onJobStart(event: SparkListenerJobStart) = - logEvent(event, flushLogger = true) - override def onJobEnd(event: SparkListenerJobEnd) = - logEvent(event, flushLogger = true) - override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded) = + } + + override def onJobStart(event: SparkListenerJobStart): Unit = logEvent(event, flushLogger = true) + + override def onJobEnd(event: SparkListenerJobEnd): Unit = logEvent(event, flushLogger = true) + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved) = + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { logEvent(event, flushLogger = true) - override def onUnpersistRDD(event: SparkListenerUnpersistRDD) = + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { logEvent(event, flushLogger = true) - override def onApplicationStart(event: SparkListenerApplicationStart) = + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { logEvent(event, flushLogger = true) - override def onApplicationEnd(event: SparkListenerApplicationEnd) = + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { logEvent(event, flushLogger = true) - override def onExecutorAdded(event: SparkListenerExecutorAdded) = + } + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { logEvent(event, flushLogger = true) - override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { logEvent(event, flushLogger = true) + } + + // No-op because logging every update would be overkill + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. */ - def stop() = { + def stop(): Unit = { writer.foreach(_.close()) val target = new Path(logPath) @@ -201,77 +230,69 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" + val SPARK_VERSION_KEY = "SPARK_VERSION" + val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) - // Marker for the end of header data in a log file. After this marker, log data, potentially - // compressed, will be found. - private val HEADER_END_MARKER = "=== LOG_HEADER_END ===" - - // To avoid corrupted files causing the heap to fill up. Value is arbitrary. - private val MAX_HEADER_LINE_LENGTH = 4096 - // A cache for compression codecs to avoid creating the same codec many times private val codecMap = new mutable.HashMap[String, CompressionCodec] /** - * Write metadata about the event log to the given stream. - * - * The header is a serialized version of a map, except it does not use Java serialization to - * avoid incompatibilities between different JDKs. It writes one map entry per line, in - * "key=value" format. - * - * The very last entry in the header is the `HEADER_END_MARKER` marker, so that the parsing code - * can know when to stop. + * Write metadata about an event log to the given stream. + * The metadata is encoded in the first line of the event log as JSON. * - * The format needs to be kept in sync with the openEventLog() method below. Also, it cannot - * change in new Spark versions without some other way of detecting the change (like some - * metadata encoded in the file name). - * - * @param logStream Raw output stream to the even log file. - * @param compressionCodec Optional compression codec to use. - * @return A stream where to write event log data. This may be a wrapper around the original - * stream (for example, when compression is enabled). + * @param logStream Raw output stream to the event log file. */ - def initEventLog( - logStream: OutputStream, - compressionCodec: Option[CompressionCodec]): OutputStream = { - val meta = mutable.HashMap(("version" -> SPARK_VERSION)) - compressionCodec.foreach { codec => - meta += ("compressionCodec" -> codec.getClass().getName()) - } - - def write(entry: String) = { - val bytes = entry.getBytes(Charsets.UTF_8) - if (bytes.length > MAX_HEADER_LINE_LENGTH) { - throw new IOException(s"Header entry too long: ${entry}") - } - logStream.write(bytes, 0, bytes.length) - } - - meta.foreach { case (k, v) => write(s"$k=$v\n") } - write(s"$HEADER_END_MARKER\n") - compressionCodec.map(_.compressedOutputStream(logStream)).getOrElse(logStream) + def initEventLog(logStream: OutputStream): Unit = { + val metadata = SparkListenerLogStart(SPARK_VERSION) + val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + logStream.write(metadataJson.getBytes(Charsets.UTF_8)) } /** * Return a file-system-safe path to the log file for the given application. * + * Note that because we currently only create a single log file for each application, + * we must encode all the information needed to parse this event log in the file name + * instead of within the file itself. Otherwise, if the file is compressed, for instance, + * we won't know which codec to use to decompress the metadata needed to open the file in + * the first place. + * + * The log file name will identify the compression codec used for the contents, if any. + * For example, app_123 for an uncompressed log, app_123.lzf for an LZF-compressed log. + * * @param logBaseDir Directory where the log file will be written. * @param appId A unique app ID. + * @param appAttemptId A unique attempt id of appId. May be the empty string. + * @param compressionCodecName Name to identify the codec used to compress the contents + * of the log, or None if compression is not enabled. * @return A path which consists of file-system-safe characters. */ - def getLogPath(logBaseDir: String, appId: String): String = { - val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase - Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") + def getLogPath( + logBaseDir: URI, + appId: String, + appAttemptId: Option[String], + compressionCodecName: Option[String] = None): String = { + val base = logBaseDir.toString.stripSuffix("/") + "/" + sanitize(appId) + val codec = compressionCodecName.map("." + _).getOrElse("") + if (appAttemptId.isDefined) { + base + "_" + sanitize(appAttemptId.get) + codec + } else { + base + codec + } + } + + private def sanitize(str: String): String = { + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase } /** - * Opens an event log file and returns an input stream to the event data. + * Opens an event log file and returns an input stream that contains the event data. * - * @return 2-tuple (event input stream, Spark version of event data) + * @return input stream that holds one JSON record per line. */ - def openEventLog(log: Path, fs: FileSystem): (InputStream, String) = { + def openEventLog(log: Path, fs: FileSystem): InputStream = { // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain // IOException when a file does not exist, so try our best to throw a proper exception. if (!fs.exists(log)) { @@ -279,52 +300,17 @@ private[spark] object EventLoggingListener extends Logging { } val in = new BufferedInputStream(fs.open(log)) - // Read a single line from the input stream without buffering. - // We cannot use BufferedReader because we must avoid reading - // beyond the end of the header, after which the content of the - // file may be compressed. - def readLine(): String = { - val bytes = new ByteArrayOutputStream() - var next = in.read() - var count = 0 - while (next != '\n') { - if (next == -1) { - throw new IOException("Unexpected end of file.") - } - bytes.write(next) - count = count + 1 - if (count > MAX_HEADER_LINE_LENGTH) { - throw new IOException("Maximum header line length exceeded.") - } - next = in.read() - } - new String(bytes.toByteArray(), Charsets.UTF_8) + + // Compression codec is encoded as an extension, e.g. app_123.lzf + // Since we sanitize the app ID to not include periods, it is safe to split on it + val logName = log.getName.stripSuffix(IN_PROGRESS) + val codecName: Option[String] = logName.split("\\.").tail.lastOption + val codec = codecName.map { c => + codecMap.getOrElseUpdate(c, CompressionCodec.createCodec(new SparkConf, c)) } - // Parse the header metadata in the form of k=v pairs - // This assumes that every line before the header end marker follows this format try { - val meta = new mutable.HashMap[String, String]() - var foundEndMarker = false - while (!foundEndMarker) { - readLine() match { - case HEADER_END_MARKER => - foundEndMarker = true - case entry => - val prop = entry.split("=", 2) - if (prop.length != 2) { - throw new IllegalArgumentException("Invalid metadata in log file.") - } - meta += (prop(0) -> prop(1)) - } - } - - val sparkVersion = meta.get("version").getOrElse( - throw new IllegalArgumentException("Missing Spark version in log metadata.")) - val codec = meta.get("compressionCodec").map { codecName => - codecMap.getOrElseUpdate(codecName, CompressionCodec.createCodec(new SparkConf, codecName)) - } - (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion) + codec.map(_.compressedInputStream(in)).getOrElse(in) } catch { case e: Exception => in.close() diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index bac37bfdaa23..0e438ab4366d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -107,7 +107,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) - for (split <- list) { + for (split <- list.asScala) { retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3bb54855bae4..f96eb8ca0ae0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -57,7 +57,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val stageIdToJobId = new HashMap[Int, Int] private val jobIdToStageIds = new HashMap[Int, Seq[Int]] private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } createLogDir() @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** @@ -169,7 +171,8 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + + " LOCAL_BYTES_READ=" + metrics.localBytesRead case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 29879b374b80..382b09422a4a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -34,7 +34,7 @@ private[spark] class JobWaiter[T]( @volatile private var _jobFinished = totalTasks == 0 - def jobFinished = _jobFinished + def jobFinished: Boolean = _jobFinished // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala new file mode 100644 index 000000000000..5d926377ce86 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import org.apache.spark._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} + +private sealed trait OutputCommitCoordinationMessage extends Serializable + +private case object StopCoordinator extends OutputCommitCoordinationMessage +private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) + +/** + * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" + * policy. + * + * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is + * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to + * commit output will be forwarded to the driver's OutputCommitCoordinator. + * + * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) + * for an extensive design discussion. + */ +private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) extends Logging { + + // Initialized by SparkEnv + var coordinatorRef: Option[RpcEndpointRef] = None + + private type StageId = Int + private type PartitionId = Long + private type TaskAttemptId = Long + + /** + * Map from active stages's id => partition id => task attempt with exclusive lock on committing + * output for that partition. + * + * Entries are added to the top-level map when stages start and are removed they finish + * (either successfully or unsuccessfully). + * + * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. + */ + private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() + private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + + /** + * Called by tasks to ask whether they can commit their output to HDFS. + * + * If a task attempt has been authorized to commit, then all other attempts to commit the same + * task will be denied. If the authorized task attempt fails (e.g. due to its executor being + * lost), then a subsequent task attempt may be authorized to commit its output. + * + * @param stage the stage number + * @param partition the partition number + * @param attempt a unique identifier for this task attempt + * @return true if this task is authorized to commit, false otherwise + */ + def canCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attempt) + coordinatorRef match { + case Some(endpointRef) => + endpointRef.askWithRetry[Boolean](msg) + case None => + logError( + "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") + false + } + } + + // Called by DAGScheduler + private[scheduler] def stageStart(stage: StageId): Unit = synchronized { + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + } + + // Called by DAGScheduler + private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + authorizedCommittersByStage.remove(stage) + } + + // Called by DAGScheduler + private[scheduler] def taskCompleted( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId, + reason: TaskEndReason): Unit = synchronized { + val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + logDebug(s"Ignoring task completion for completed stage") + return + }) + reason match { + case Success => + // The task output has been committed successfully + case denied: TaskCommitDenied => + logInfo( + s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + case otherReason => + if (authorizedCommitters.get(partition).exists(_ == attempt)) { + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } + } + } + + def stop(): Unit = synchronized { + if (isDriver) { + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None + authorizedCommittersByStage.clear() + } + } + + // Marked private[scheduler] instead of private so this can be mocked in tests + private[scheduler] def handleAskPermissionToCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = synchronized { + authorizedCommittersByStage.get(stage) match { + case Some(authorizedCommitters) => + authorizedCommitters.get(partition) match { + case Some(existingCommitter) => + logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + + s"existingCommitter = $existingCommitter") + false + case None => + logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") + authorizedCommitters(partition) = attempt + true + } + case None => + logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + false + } + } +} + +private[spark] object OutputCommitCoordinator { + + // This endpoint is used only for RPC + private[spark] class OutputCommitCoordinatorEndpoint( + override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) + extends RpcEndpoint with Logging { + + override def receive: PartialFunction[Any, Unit] = { + case StopCoordinator => + logInfo("OutputCommitCoordinator stopped!") + stop() + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + context.reply( + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 174b73221afc..5821afea9898 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -74,7 +74,7 @@ private[spark] class Pool( if (schedulableNameToSchedulable.containsKey(schedulableName)) { return schedulableNameToSchedulable.get(schedulableName) } - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched @@ -84,12 +84,12 @@ private[spark] class Pool( } override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) + schedulableQueue.asScala.foreach(_.executorLost(executorId, host)) } override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { shouldRevive |= schedulable.checkSpeculatableTasks() } shouldRevive @@ -98,7 +98,7 @@ private[spark] class Pool( override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = - schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) + schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 584f4e7789d1..c6d957b65f3f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -21,6 +21,7 @@ import java.io.{InputStream, IOException} import scala.io.Source +import com.fasterxml.jackson.core.JsonParseException import org.json4s.jackson.JsonMethods._ import org.apache.spark.Logging @@ -39,22 +40,40 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * error is thrown by this method. * * @param logData Stream containing event log data. - * @param version Spark version that generated the events. + * @param sourceName Filename (or other source identifier) from whence @logData is being read + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * encountered, log file might not finished writing) or not */ - def replay(logData: InputStream, version: String) { + def replay( + logData: InputStream, + sourceName: String, + maybeTruncated: Boolean = false): Unit = { var currentLine: String = null + var lineNumber: Int = 1 try { val lines = Source.fromInputStream(logData).getLines() - lines.foreach { line => - currentLine = line - postToAll(JsonProtocol.sparkEventFromJson(parse(line))) + while (lines.hasNext) { + currentLine = lines.next() + try { + postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) + } catch { + case jpe: JsonParseException => + // We can only ignore exception from last line of the file that might be truncated + if (!maybeTruncated || lines.hasNext) { + throw jpe + } else { + logWarning(s"Got JsonParseException from log file $sourceName" + + s" at line $lineNumber, the file might not have finished writing cleanly.") + } + } + lineNumber += 1 } } catch { case ioe: IOException => throw ioe case e: Exception => - logError("Exception in parsing Spark event log.", e) - logError("Malformed line: %s\n".format(currentLine)) + logError(s"Exception parsing Spark event log: $sourceName", e) + logError(s"Malformed line #$lineNumber: $currentLine\n") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala new file mode 100644 index 000000000000..bf81b9aca481 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite + +/** + * The ResultStage represents the final stage in a job. + */ +private[spark] class ResultStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + firstJobId: Int, + callSite: CallSite) + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + + // The active job for this result stage. Will be empty if the job has already finished + // (e.g., because the job was cancelled). + var resultOfJob: Option[ActiveJob] = None + + override def toString: String = "ResultStage " + id +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 4a9ff918afe2..c4dc080e2b22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -41,11 +41,14 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId, partition.index) with Serializable { + val outputId: Int, + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq @@ -53,9 +56,11 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. + val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) func(context, rdd.iterator(partition, context)) @@ -64,5 +69,5 @@ private[spark] class ResultTask[T, U]( // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" + override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 992c477493d8..8801a761afae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -41,4 +41,19 @@ private[spark] trait SchedulerBackend { */ def applicationId(): String = appId + /** + * Get the attempt ID for this run, if the cluster manager supports multiple + * attempts. Applications run in client mode will not have attempt IDs. + * + * @return The application attempt id, if available. + */ + def applicationAttemptId(): Option[String] = None + + /** + * Get the URLs for the driver logs. These URLs are used to display the links in the UI + * Executors tab for the driver. + * @return Map containing the log names and their respective URLs + */ + def getDriverLogUrls: Option[Map[String, String]] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 5e62c8468f00..864941d468af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -56,7 +56,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare:Int = 0 + var compare: Int = 0 if (s1Needy && !s2Needy) { return true diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala new file mode 100644 index 000000000000..48d8d8e9c4b7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.ShuffleDependency +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite + +/** + * The ShuffleMapStage represents the intermediate stages in a job. + */ +private[spark] class ShuffleMapStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + firstJobId: Int, + callSite: CallSite, + val shuffleDep: ShuffleDependency[_, _, _]) + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + + override def toString: String = "ShuffleMapStage " + id + + var numAvailableOutputs: Int = 0 + + def isAvailable: Boolean = numAvailableOutputs == numPartitions + + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + + def addOutputLoc(partition: Int, status: MapStatus): Unit = { + val prevList = outputLocs(partition) + outputLocs(partition) = status :: prevList + if (prevList == Nil) { + numAvailableOutputs += 1 + } + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 79709089c0da..f478f9982afe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,21 +33,24 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, partition.index) with Logging { + @transient private var locs: Seq[TaskLocation], + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, null, new Partition { override def index = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -56,9 +59,11 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. + val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null @@ -66,7 +71,7 @@ private[spark] class ShuffleMapTask( val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - return writer.stop(success = true).get + writer.stop(success = true).get } catch { case e: Exception => try { @@ -83,5 +88,5 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index dd28ddb31de1..896f1743332f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -110,12 +113,22 @@ case class SparkListenerExecutorMetricsUpdate( extends SparkListenerEvent @DeveloperApi -case class SparkListenerApplicationStart(appName: String, appId: Option[String], time: Long, - sparkUser: String) extends SparkListenerEvent +case class SparkListenerApplicationStart( + appName: String, + appId: Option[String], + time: Long, + sparkUser: String, + appAttemptId: Option[String], + driverLogs: Option[Map[String, String]] = None) extends SparkListenerEvent @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent +/** + * An internal class that describes the metadata of an event log. + * This event is not meant to be posted to listeners downstream. + */ +private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * :: DeveloperApi :: @@ -205,6 +218,11 @@ trait SparkListener { * Called when the driver removes an executor. */ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + + /** + * Called when the driver receives a block update info. + */ + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } } /** @@ -260,7 +278,7 @@ class StatsReportListener extends SparkListener with Logging { private[spark] object StatsReportListener extends Logging { // For profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) val probabilities = percentiles.map(_ / 100.0) val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" @@ -294,8 +312,8 @@ private[spark] object StatsReportListener extends Logging { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d: Double) = format.format(d) + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { + def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -308,7 +326,7 @@ private[spark] object StatsReportListener extends Logging { } def showBytesDistribution( - heading:String, + heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long], taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) @@ -341,7 +359,7 @@ private[spark] object StatsReportListener extends Logging { /** * Reformat a time interval in milliseconds to a prettier format for output */ - def millisToString(ms: Long) = { + def millisToString(ms: Long): String = { val (size, units) = if (ms > hours) { (ms.toDouble / hours, "hours") diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index fe8a19a2c0cb..04afde33f5aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,9 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case blockUpdated: SparkListenerBlockUpdated => + listener.onBlockUpdated(blockUpdated) + case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b8..1cf06856ffbc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,6 @@ import scala.collection.mutable.HashSet import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -35,7 +34,7 @@ import org.apache.spark.util.CallSite * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes * that each output partition is on. * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * @@ -47,97 +46,66 @@ import org.apache.spark.util.CallSite * be updated for each attempt. * */ -private[spark] class Stage( +private[spark] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], - val jobId: Int, + val firstJobId: Int, val callSite: CallSite) extends Logging { - val isShuffleMap = shuffleDep.isDefined val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - /** For stages that are the final (consists of only ResultTasks), link to the ActiveJob. */ - var resultOfJob: Option[ActiveJob] = None var pendingTasks = new HashSet[Task[_]] - private var nextAttemptId = 0 + /** The ID to use for the next new attempt for this stage. */ + private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ - var latestInfo: StageInfo = StageInfo.fromStage(this) + private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } + /** Internal accumulators shared across all tasks in this stage. */ + def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. + * Re-initialize the internal accumulators associated with this stage. + * + * This is called every time the stage is submitted, *except* when a subset of tasks + * belonging to this stage has already finished. Otherwise, reinitializing the internal + * accumulators here again will override partial values from the finished tasks. */ - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } + def resetInternalAccumulators(): Unit = { + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) } - /** Return a new attempt id, starting with 0. */ - def newAttemptId(): Int = { - val id = nextAttemptId + /** + * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized + * here, before any attempts have actually been created, because the DAGScheduler uses this + * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts + * have been created). + */ + private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) + + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ + def makeNewStageAttempt( + numPartitionsToCompute: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { + _latestInfo = StageInfo.fromStage( + this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) nextAttemptId += 1 - id } - def attemptId: Int = nextAttemptId - - override def toString = "Stage " + id - - override def hashCode(): Int = id + /** Returns the StageInfo for the most recent attempt for this stage. */ + def latestInfo: StageInfo = _latestInfo - override def equals(other: Any): Boolean = other match { + override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index c6dc3369ba5c..24796c14300b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,7 +33,9 @@ class StageInfo( val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], - val details: String) { + val parentIds: Seq[Int], + val details: String, + private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -47,6 +49,18 @@ class StageInfo( failureReason = Some(reason) completionTime = Some(System.currentTimeMillis) } + + private[spark] def getStatusString: String = { + if (completionTime.isDefined) { + if (failureReason.isDefined) { + "failed" + } else { + "succeeded" + } + } else { + "running" + } + } } private[spark] object StageInfo { @@ -57,15 +71,22 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. */ - def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { + def fromStage( + stage: Stage, + attemptId: Int, + numTasks: Option[Int] = None, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( stage.id, - stage.attemptId, + attemptId, stage.name, numTasks.getOrElse(stage.numTasks), rddInfos, - stage.details) + stage.parents.map(_.id), + stage.details, + taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 847a4912eec1..9edf9f048f9f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,9 +22,11 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -42,32 +44,71 @@ import org.apache.spark.util.Utils * @param stageId id of the stage this task belongs to * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { +private[spark] abstract class Task[T]( + val stageId: Int, + val stageAttemptId: Int, + val partitionId: Int, + internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { /** - * Called by Executor to run this task. + * The key of the Map is the accumulator id and the value of the Map is the latest accumulator + * local value. + */ + type AccumulatorUpdates = Map[Long, Any] + + /** + * Called by [[Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) - * @return the result of the task + * @return the result of the task along with updates of Accumulators. */ - final def run(taskAttemptId: Long, attemptNumber: Int): T = { - context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, - taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + final def run( + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem) + : (T, AccumulatorUpdates) = { + context = new TaskContextImpl( + stageId, + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + metricsSystem, + internalAccumulators, + runningLocally = false) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) + context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - runTask(context) + (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + } + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } + } finally { + TaskContext.unset() + } } } + private var taskMemoryManager: TaskMemoryManager = _ + + def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { + this.taskMemoryManager = taskMemoryManager + } + def runTask(context: TaskContext): T def preferredLocations: Seq[TaskLocation] = Nil @@ -87,11 +128,18 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // initialized when kill() is invoked. @volatile @transient private var _killed = false + protected var _executorDeserializeTime: Long = 0 + /** * Whether the task has been killed. */ def killed: Boolean = _killed + /** + * Returns the amount of time spent deserializing the RDD and function to be run. + */ + def executorDeserializeTime: Long = _executorDeserializeTime + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can @@ -106,7 +154,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6fa1f2c880f7..132a9ced7770 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,9 +81,11 @@ class TaskInfo( def status: String = { if (running) { - "RUNNING" - } else if (gettingResult) { - "GET RESULT" + if (gettingResult) { + "GET RESULT" + } else { + "RUNNING" + } } else if (failed) { "FAILED" } else if (successful) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 10c685f29d3a..da07ce2c6ea4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -29,23 +29,22 @@ private[spark] sealed trait TaskLocation { /** * A location that includes both a host and an executor id on that host. */ -private [spark] case class ExecutorCacheTaskLocation(override val host: String, - val executorId: String) extends TaskLocation { -} +private [spark] +case class ExecutorCacheTaskLocation(override val host: String, executorId: String) + extends TaskLocation /** * A location on a host. */ private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { - override def toString = host + override def toString: String = host } /** * A location on a host that is cached by HDFS. */ -private [spark] case class HDFSCacheTaskLocation(override val host: String) - extends TaskLocation { - override def toString = TaskLocation.inMemoryLocationTag + host +private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation { + override def toString: String = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { @@ -54,14 +53,16 @@ private[spark] object TaskLocation { // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. val inMemoryLocationTag = "hdfs_cache_" - def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) + def apply(host: String, executorId: String): TaskLocation = { + new ExecutorCacheTaskLocation(host, executorId) + } /** * Create a TaskLocation from a string returned by getPreferredLocations. * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the * location is cached. */ - def apply(str: String) = { + def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { new HostTaskLocation(str) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 1f114a0207f7..b82c7f3fa54f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.executor.TaskMetrics @@ -40,6 +41,9 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long var metrics: TaskMetrics) extends TaskResult[T] with Externalizable { + private var valueObjectDeserialized = false + private var valueObject: T = _ + def this() = this(null.asInstanceOf[ByteBuffer], null, null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { @@ -66,16 +70,33 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - accumUpdates = Map() + val _accumUpdates = mutable.Map[Long, Any]() for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() + _accumUpdates(in.readLong()) = in.readObject() } + accumUpdates = _accumUpdates } metrics = in.readObject().asInstanceOf[TaskMetrics] + valueObjectDeserialized = false } + /** + * When `value()` is called at the first time, it needs to deserialize `valueObject` from + * `valueBytes`. It may cost dozens of seconds for a large instance. So when calling `value` at + * the first time, the caller should avoid to block other threads. + * + * After the first time, `value()` is trivial and just returns the deserialized `valueObject`. + */ def value(): T = { - val resultSer = SparkEnv.get.serializer.newInstance() - resultSer.deserialize(valueBytes) + if (valueObjectDeserialized) { + valueObject + } else { + // This should not run when holding a lock because it may cost dozens of seconds for a large + // value. + val resultSer = SparkEnv.get.serializer.newInstance() + valueObject = resultSer.deserialize(valueBytes) + valueObjectDeserialized = true + valueObject + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 774f3d8cdb27..46a6f6537e2e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.concurrent.RejectedExecutionException import scala.language.existentials import scala.util.control.NonFatal @@ -25,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. @@ -34,7 +35,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( THREADS, "task-result-getter") protected val serializer = new ThreadLocal[SerializerInstance] { @@ -53,6 +54,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { return } + // deserialize "value" without holding any lock so that it won't block other threads. + // We should call it here, so that when it's called again in + // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. + directResult.value() (directResult, serializedData.limit()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { @@ -95,25 +100,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { var reason : TaskEndReason = UnknownReason - getTaskResultExecutor.execute(new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions { - try { - if (serializedData != null && serializedData.limit() > 0) { - reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + try { + getTaskResultExecutor.execute(new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions { + try { + if (serializedData != null && serializedData.limit() > 0) { + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, Utils.getSparkClassLoader) + } + } catch { + case cnd: ClassNotFoundException => + // Log an error but keep going here -- the task failed, so not catastrophic + // if we can't deserialize the reason. + val loader = Utils.getContextOrSparkClassLoader + logError( + "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) + case ex: Exception => {} } - } catch { - case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastrophic if we can't - // deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader - logError( - "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) - } - }) + }) + } catch { + case e: RejectedExecutionException if sparkEnv.isStopped => + // ignore it + } } def stop() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index f095915352b1..f25f3ed0d903 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -74,4 +74,16 @@ private[spark] trait TaskScheduler { */ def applicationId(): String = appId + /** + * Process a lost executor + */ + def executorLost(executorId: String, reason: ExecutorLossReason): Unit + + /** + * Get an application's attempt ID associated with the job. + * + * @return An application's Attempt ID + */ + def applicationAttemptId(): Option[String] + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 79f84e70df9d..1705e7f962de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{TimerTask, Timer} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -62,19 +62,22 @@ private[spark] class TaskSchedulerImpl( val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100) + val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") + + private val speculationScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation") // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000) + val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") // CPUs to request per task val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. - val activeTaskSets = new HashMap[String, TaskSetManager] + private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] - val taskIdToTaskSetId = new HashMap[Long, String] + private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -142,11 +145,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { - Utils.tryOrExit { checkSpeculatableTasks() } - } + speculationScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + checkSpeculatableTasks() + } + }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS) } } @@ -158,8 +161,18 @@ private[spark] class TaskSchedulerImpl( val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet, maxTaskFailures) - activeTaskSets(taskSet.id) = manager + val manager = createTaskSetManager(taskSet, maxTaskFailures) + val stage = taskSet.stageId + val stageTaskSets = + taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) + stageTaskSets(taskSet.stageAttemptId) = manager + val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => + ts.taskSet != taskSet && !ts.isZombie + } + if (conflictingTaskSet) { + throw new IllegalStateException(s"more than one active taskSet for stage $stage:" + + s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}") + } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) if (!isLocal && !hasReceivedTask) { @@ -173,28 +186,37 @@ private[spark] class TaskSchedulerImpl( this.cancel() } } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS) } hasReceivedTask = true } backend.reviveOffers() } + // Label as private[scheduler] to allow tests to swap in different task set managers if necessary + private[scheduler] def createTaskSetManager( + taskSet: TaskSet, + maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures) + } + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId, interruptThread) + } + tsm.abort("Stage %s cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) } - tsm.abort("Stage %s cancelled".format(stageId)) - logInfo("Stage %d was cancelled".format(stageId)) } } @@ -204,7 +226,12 @@ private[spark] class TaskSchedulerImpl( * cleaned up. */ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - activeTaskSets -= manager.taskSet.id + taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage => + taskSetsForStage -= manager.taskSet.stageAttemptId + if (taskSetsForStage.isEmpty) { + taskSetsByStageIdAndAttempt -= manager.taskSet.stageId + } + } manager.parent.removeSchedulable(manager) logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" .format(manager.taskSet.id, manager.parent.name)) @@ -225,7 +252,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK @@ -309,26 +336,24 @@ private[spark] class TaskSchedulerImpl( failedExecutor = Some(execId) } } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => + taskIdToTaskSetManager.get(tid) match { + case Some(taskSet) => if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) + taskIdToTaskSetManager.remove(tid) taskIdToExecutorId.remove(tid) } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") - .format(state, tid)) + "likely the result of receiving duplicate task finished status updates)") + .format(state, tid)) } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -353,9 +378,9 @@ private[spark] class TaskSchedulerImpl( val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { taskMetrics.flatMap { case (id, metrics) => - taskIdToTaskSetId.get(id) - .flatMap(activeTaskSets.get) - .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) + taskIdToTaskSetManager.get(id).map { taskSetMgr => + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + } } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) @@ -366,17 +391,17 @@ private[spark] class TaskSchedulerImpl( } def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]): Unit = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: TaskEndReason) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: TaskEndReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to @@ -387,9 +412,12 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.size > 0) { + if (taskSetsByStageIdAndAttempt.nonEmpty) { // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { + for { + attempts <- taskSetsByStageIdAndAttempt.values + manager <- attempts.values + } { try { manager.abort(message) } catch { @@ -400,13 +428,13 @@ private[spark] class TaskSchedulerImpl( // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) + throw new SparkException(s"Exiting due to error from cluster scheduler: $message") } } } override def stop() { + speculationScheduler.shutdown() if (backend != null) { backend.stop() } @@ -416,7 +444,7 @@ private[spark] class TaskSchedulerImpl( starvationTimer.cancel() } - override def defaultParallelism() = backend.defaultParallelism() + override def defaultParallelism(): Int = backend.defaultParallelism() // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { @@ -429,7 +457,7 @@ private[spark] class TaskSchedulerImpl( } } - def executorLost(executorId: String, reason: ExecutorLossReason) { + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = { var failedExecutor: Option[String] = None synchronized { @@ -508,6 +536,19 @@ private[spark] class TaskSchedulerImpl( override def applicationId(): String = backend.applicationId() + override def applicationAttemptId(): Option[String] = backend.applicationAttemptId() + + private[scheduler] def taskSetManagerForAttempt( + stageId: Int, + stageAttemptId: Int): Option[TaskSetManager] = { + for { + attempts <- taskSetsByStageIdAndAttempt.get(stageId) + manager <- attempts.get(stageAttemptId) + } yield { + manager + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156f5..be8526ba9b94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -26,10 +26,10 @@ import java.util.Properties private[spark] class TaskSet( val tasks: Array[Task[_]], val stageId: Int, - val attempt: Int, + val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + attempt + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 97c22fe724ab..818b95d67f6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.Arrays +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -29,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -44,14 +46,14 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for - * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * @param maxTaskFailures if any particular task fails this number of times, the entire * task set will be aborted */ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Schedulable with Logging { val conf = sched.sc.conf @@ -97,7 +99,8 @@ private[spark] class TaskSetManager( var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] - override def runningTasks = runningTasksSet.size + + override def runningTasks: Int = runningTasksSet.size // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the @@ -166,11 +169,11 @@ private[spark] class TaskSetManager( // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level - override def schedulableQueue = null + override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null - override def schedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.NONE var emittedTaskSizeWarning = false @@ -281,7 +284,7 @@ private[spark] class TaskSetManager( val failed = failedExecutors.get(taskId).get return failed.contains(execId) && - clock.getTime() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT + clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT } false @@ -292,7 +295,8 @@ private[spark] class TaskSetManager( * an attempt running on this host, in case the host is slow. In addition, the task should meet * the given locality constraint. */ - private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + // Labeled as protected to allow tests to override providing speculative tasks if necessary + protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) : Option[(Int, TaskLocality.Value)] = { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set @@ -427,7 +431,7 @@ private[spark] class TaskSetManager( : Option[TaskDescription] = { if (!isZombie) { - val curTime = clock.getTime() + val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -458,7 +462,7 @@ private[spark] class TaskSetManager( lastLaunchTime = curTime } // Serialize and return the task - val startTime = clock.getTime() + val startTime = clock.getTimeMillis() val serializedTask: ByteBuffer = try { Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) } catch { @@ -506,13 +510,64 @@ private[spark] class TaskSetManager( * Get the level we can launch tasks according to delay scheduling, based on current wait time. */ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 + // Remove the scheduled or finished tasks lazily + def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = { + var indexOffset = pendingTaskIds.size + while (indexOffset > 0) { + indexOffset -= 1 + val index = pendingTaskIds(indexOffset) + if (copiesRunning(index) == 0 && !successful(index)) { + return true + } else { + pendingTaskIds.remove(indexOffset) + } + } + false + } + // Walk through the list of tasks that can be scheduled at each location and returns true + // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have + // already been scheduled. + def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = { + val emptyKeys = new ArrayBuffer[String] + val hasTasks = pendingTasks.exists { + case (id: String, tasks: ArrayBuffer[Int]) => + if (tasksNeedToBeScheduledFrom(tasks)) { + true + } else { + emptyKeys += id + false + } + } + // The key could be executorId, host or rackId + emptyKeys.foreach(id => pendingTasks.remove(id)) + hasTasks + } + + while (currentLocalityIndex < myLocalityLevels.length - 1) { + val moreTasks = myLocalityLevels(currentLocalityIndex) match { + case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor) + case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost) + case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty + case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack) + } + if (!moreTasks) { + // This is a performance optimization: if there are no more tasks that can + // be scheduled at a particular locality level, there is no point in waiting + // for the locality wait timeout (SPARK-4939). + lastLaunchTime = curTime + logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " + + s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}") + currentLocalityIndex += 1 + } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) { + // Jump to the next locality level, and reset lastLaunchTime so that the next locality + // wait timer doesn't immediately expire + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " + + s"${localityWaits(currentLocalityIndex)}ms") + } else { + return myLocalityLevels(currentLocalityIndex) + } } myLocalityLevels(currentLocalityIndex) } @@ -533,7 +588,7 @@ private[spark] class TaskSetManager( /** * Marks the task as getting result and notifies the DAG Scheduler */ - def handleTaskGettingResult(tid: Long) = { + def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) @@ -560,11 +615,17 @@ private[spark] class TaskSetManager( /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index info.markSuccessful() removeRunningTask(tid) + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. sched.dagScheduler.taskEnded( tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { @@ -601,7 +662,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -610,6 +671,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -622,7 +684,7 @@ private[spark] class TaskSetManager( return } val key = ef.description - val now = clock.getTime() + val now = clock.getTimeMillis() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) @@ -645,35 +707,41 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). - put(info.executorId, clock.getTime()) + put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED) { + if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { + // If a task failed because its attempt to commit was denied, do not count this failure + // towards failing the stage. This is intended to prevent spurious stage failures in cases + // where many speculative tasks are launched and denied to commit. assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } @@ -717,10 +785,10 @@ private[spark] class TaskSetManager( // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because dequeueTaskFromList will skip already-running tasks. for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, @@ -766,7 +834,7 @@ private[spark] class TaskSetManager( val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() + val time = clock.getTimeMillis() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) @@ -790,15 +858,18 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - conf.get("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - conf.get("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - conf.get("spark.locality.wait.rack", defaultWait).toLong - case _ => 0L + val defaultWait = conf.get("spark.locality.wait", "3s") + val localityWaitKey = level match { + case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" + case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" + case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" + case _ => null + } + + if (localityWaitKey != null) { + conf.getTimeAsMs(localityWaitKey, defaultWait) + } else { + 0L } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5..06f5438433b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -39,7 +40,12 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + case class RegisterExecutor( + executorId: String, + executorRef: RpcEndpointRef, + hostPort: String, + cores: Int, + logUrls: Map[String, String]) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } @@ -66,18 +72,25 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) + case class AddWebUIFilter( + filterName: String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running - case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + case class RequestExecutors( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]) + extends CoarseGrainedClusterMessage case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 103a5c053c28..5730a87f960a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. @@ -57,8 +52,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds // if minRegisteredRatio has not yet been reached - val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) + val maxRegisteredWaitingTimeMs = + conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s") val createTime = System.currentTimeMillis() private val executorDataMap = new HashMap[String, ExecutorData] @@ -71,48 +66,39 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + // A map to store hostname with its possible task number running on it + protected var hostToLocalTaskCount: Map[String, Int] = Map.empty + + // The number of pending tasks which is locality required + protected var localityAwareTasks = 0 + + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { + + // If this DriverEndpoint is changed to support multiple threads, + // then this may need to be changed so that we don't share the serializer + // instance across threads + private val ser = SparkEnv.get.closureSerializer.newInstance() + override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val addressToExecutorId = new HashMap[RpcAddress, String] - // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } + private val reviveThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") - def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor + override def onStart() { + // Periodically revive offers to allow delay scheduling to work + val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post( - SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) - makeOffers() + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveIntervalMs, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -123,7 +109,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case None => // Ignoring the update since we don't know about the executor. logWarning(s"Ignored task status update ($taskId state $state) " + - "from unknown executor $sender with ID $executorId") + s"from unknown executor with ID $executorId") } } @@ -133,64 +119,100 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + // Note: some tests expect the reply to come after we put the executor in the map + context.reply(RegisteredExecutor) + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } + case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors - def makeOffers() { - launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + private def makeOffers() { + // Filter out executors under killing + val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq)) + }.toSeq + launchTasks(scheduler.resourceOffers(workOffers)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) } // Make fake resource offers on just one executor - def makeOffers(executorId: String) { - val executorData = executorDataMap(executorId) - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) + private def makeOffers(executorId: String) { + // Filter out executors under killing + if (!executorsPendingToRemove.contains(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = Seq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + launchTasks(scheduler.resourceOffers(workOffers)) + } } // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) - scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => + scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + "spark.akka.frameSize or using broadcast variables for large values." msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, AkkaUtils.reservedSizeBytes) - taskSet.abort(msg) + taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) } @@ -199,7 +221,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -211,6 +233,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { + addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId executorsPendingToRemove -= executorId } @@ -219,12 +242,16 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste scheduler.executorLost(executorId, SlaveLost(reason)) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -234,17 +261,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste properties += ((key, value)) } } + // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithRetry[Boolean](StopExecutors) } } catch { case e: Exception => @@ -255,22 +282,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithRetry[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -280,11 +306,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -296,9 +321,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } - if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { + if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTimeMs) { logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTimeMs(ms)") return true } false @@ -311,17 +336,56 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. */ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + if (numAdditionalExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of additional executor(s) " + + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") + } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors // Account for executors pending to be added or removed val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size doRequestTotalExecutors(newTotal) } + /** + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + final override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int] + ): Boolean = synchronized { + if (numExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of executor(s) " + + s"$numExecutors from the cluster manager. Please specify a positive number!") + } + + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + + numPendingExecutors = + math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) + doRequestTotalExecutors(numExecutors) + } + /** * Request executors from the cluster manager by specifying the total number desired, * including existing pending and running executors. @@ -332,36 +396,55 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste * insufficient resources to satisfy the first request. We make the assumption here that the * cluster manager will eventually fulfill all requests when resources free up. * - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. */ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false /** * Request that the cluster manager kill the specified executors. - * Return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { + killExecutors(executorIds, replace = false) + } + + /** + * Request that the cluster manager kill the specified executors. + * + * @param executorIds identifiers of executors to kill + * @param replace whether to replace the killed executors with new ones + * @return whether the kill request is acknowledged. + */ + final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") - val filteredExecutorIds = new ArrayBuffer[String] - executorIds.foreach { id => - if (executorDataMap.contains(id)) { - filteredExecutorIds += id - } else { - logWarning(s"Executor to kill $id does not exist!") - } + val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) + unknownExecutors.foreach { id => + logWarning(s"Executor to kill $id does not exist!") } - executorsPendingToRemove ++= filteredExecutorIds - doKillExecutors(filteredExecutorIds) + + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + executorsPendingToRemove ++= executorsToKill + + // If we do not wish to replace the executors we kill, sync the target number of executors + // with the cluster manager to avoid allocating new ones. When computing the new target, + // take into account executors that are pending to be added or removed. + if (!replace) { + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } + + doKillExecutors(executorsToKill) } /** * Kill the given list of executors through the cluster manager. - * Return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. */ protected def doKillExecutors(executorIds: Seq[String]): Boolean = false } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index eb52ddfb1eab..626a2b7d69ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,21 +17,22 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, - override val totalCores: Int -) extends ExecutorInfo(executorHost, totalCores) + override val totalCores: Int, + override val logUrlMap: Map[String, String] +) extends ExecutorInfo(executorHost, totalCores, logUrlMap) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala index b4738e64c939..7f218566146a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -25,8 +25,8 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi class ExecutorInfo( val executorHost: String, - val totalCores: Int -) { + val totalCores: Int, + val logUrlMap: Map[String, String]) { def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] @@ -34,12 +34,13 @@ class ExecutorInfo( case that: ExecutorInfo => (that canEqual this) && executorHost == that.executorHost && - totalCores == that.totalCores + totalCores == that.totalCores && + logUrlMap == that.logUrlMap case _ => false } override def hashCode(): Int = { - val state = Seq(executorHost, totalCores) + val state = Seq(executorHost, totalCores, logUrlMap) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 06786a59524e..0324c9dab910 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index d2e1680a5fd1..bbe51b4a09a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,43 +17,48 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.Semaphore + +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { - var client: AppClient = null - var stopping = false - var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - @volatile var appId: String = _ + private var client: AppClient = null + private var stopping = false + + @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ + @volatile private var appId: String = _ - val registrationLock = new Object() - var registrationDone = false + private val registrationBarrier = new Semaphore(0) - val maxCores = conf.getOption("spark.cores.max").map(_.toInt) - val totalExpectedCores = maxCores.getOrElse(0) + private val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + private val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() // The endpoint for executors to talk to us - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", - "{{WORKER_URL}}") + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val args = Seq( + "--driver-url", driverUrl, + "--executor-id", "{{EXECUTOR_ID}}", + "--hostname", "{{HOSTNAME}}", + "--cores", "{{CORES}}", + "--app-id", "{{APP_ID}}", + "--worker-url", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath") @@ -77,12 +82,11 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir) - - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, + command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() - waitForRegistration() } @@ -90,8 +94,10 @@ private[spark] class SparkDeploySchedulerBackend( stopping = true super.stop() client.stop() - if (shutdownCallback != null) { - shutdownCallback(this) + + val callback = shutdownCallback + if (callback != null) { + callback(this) } } @@ -112,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend( notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) - scheduler.error(reason) - // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + try { + scheduler.error(reason) + } finally { + // Ensure the application terminates, as we can no longer run jobs. + sc.stop() + } } } @@ -143,19 +152,40 @@ private[spark] class SparkDeploySchedulerBackend( super.applicationId } - private def waitForRegistration() = { - registrationLock.synchronized { - while (!registrationDone) { - registrationLock.wait() - } + /** + * Request executors from the Master by specifying the total number desired, + * including existing pending and running executors. + * + * @return whether the request is acknowledged. + */ + protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + Option(client) match { + case Some(c) => c.requestTotalExecutors(requestedTotal) + case None => + logWarning("Attempted to request executors before driver fully initialized.") + false } } - private def notifyContext() = { - registrationLock.synchronized { - registrationDone = true - registrationLock.notifyAll() + /** + * Kill the given list of executors through the Master. + * @return whether the kill request is acknowledged. + */ + protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + Option(client) match { + case Some(c) => c.killExecutors(executorIds) + case None => + logWarning("Attempted to kill executors before driver fully initialized.") + false } } + private def waitForRegistration() = { + registrationBarrier.acquire() + } + + private def notifyContext() = { + registrationBarrier.release() + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f14aaeea0a25..044f6288fabd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,14 +19,12 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, RpcUtils} import scala.util.control.NonFatal @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,28 +43,25 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +91,73 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private val askAmThreadPool = + ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + override def receive: PartialFunction[Any, Unit] = { + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) - override def receive = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithRetry[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithRetry[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit = { + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 0d1c2a916ca7..452c32d5411c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,23 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} -import java.util.Collections +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} +import com.google.common.collect.HashBiMap +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} +import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.Utils /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -46,22 +49,20 @@ import org.apache.spark.util.{Utils, AkkaUtils} private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, - master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + master: String, + securityManager: SecurityManager) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler - with Logging { + with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + + // If shuffle service is enabled, the Spark driver will register with the shuffle service. + // This is for cleaning up shuffle files reliably. + private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] @@ -69,12 +70,50 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + // Maping from slave Id to hostname + private val slaveIdToHost = new HashMap[String, String] + + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. + */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + private val pendingRemovedSlaveIds = new HashSet[String] + + // private lock object protecting mutable state above. Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + + // A client for talking to the external shuffle service, if it is a + private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { + if (shuffleServiceEnabled) { + Some(new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled())) + } else { + None + } + } + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -87,29 +126,12 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val driver = createSchedulerDriver( + master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -143,51 +165,61 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(sc.env.actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - val uri = conf.get("spark.executor.uri", null) - if (uri == null) { + + val uri = conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + + if (uri.isEmpty) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( - prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) + "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" + .format(prefixEnv, runScript) + + s" --driver-url $driverURL" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( - ("cd %s*; %s " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") - .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + s"cd $basename*; $prefixEnv " + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) + } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) } + command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue + mesosExternalShuffleClient.foreach(_.init(appId)) logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } + markRegistered() } override def disconnected(d: SchedulerDriver) {} @@ -199,15 +231,19 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(-1).build() - - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString + stateLock.synchronized { + val filters = Filters.newBuilder().setRefuseSeconds(5).build() + for (offer <- offers.asScala) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -215,107 +251,139 @@ private[spark] class CoarseMesosSchedulerBackend( val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId + taskIdToSlaveId.put(taskId, slaveId) slaveIdsWithExecutors += slaveId coresByTaskId(taskId) = cpusToUse - val task = MesosTaskInfo.newBuilder() + // Gather cpu resources from the available resources and use them in the task. + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) - .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) - .build() + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + } + + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task), filters) + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } - - /** Check whether a Mesos task state represents a finished task */ - private def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState - logInfo("Mesos task " + taskId + " is now " + state) - synchronized { - if (isFinished(state)) { - val slaveId = taskIdToSlaveId(taskId) + logInfo(s"Mesos task $taskId is now $state") + val slaveId: String = status.getSlaveId.getValue + stateLock.synchronized { + // If the shuffle service is enabled, have the driver register with each one of the + // shuffle services. This allows the shuffle services to clean up state associated with + // this application when the driver exits. There is currently not a great way to detect + // this through Mesos, since the shuffle services are set up independently. + if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && + slaveIdToHost.contains(slaveId) && + shuffleServiceEnabled) { + assume(mesosExternalShuffleClient.isDefined, + "External shuffle client was not instantiated even though shuffle service is enabled.") + // TODO: Remove this and allow the MesosExternalShuffleService to detect + // framework termination when new Mesos Framework HTTP API is available. + val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) + val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + + s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get + .registerDriverWithShuffleService(hostname, externalShufflePort) + } + + if (TaskState.isFinished(TaskState.fromMesos(state))) { + val slaveId = taskIdToSlaveId.get(taskId) slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId + taskIdToSlaveId.remove(taskId) // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { + if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { - logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + + logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + executorTerminated(d, slaveId, s"Executor finished with state $state") + // In case we'd rejected everything before but have now lost a node + d.reviveOffers() } } } override def error(d: SchedulerDriver, message: String) { - logError("Mesos error: " + message) + logError(s"Mesos error: $message") scheduler.error(message) } override def stop() { super.stop() - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. + */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.containsKey(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo(s"Mesos slave lost: ${slaveId.getValue}") + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -326,4 +394,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.containsKey(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala deleted file mode 100644 index 5101ec8352e7..000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.apache.spark.SparkContext - -private[spark] object MemoryUtils { - // These defaults copied from YARN - val OVERHEAD_FRACTION = 1.07 - val OVERHEAD_MINIMUM = 384 - - def calculateTotalMemory(sc: SparkContext) = { - math.max( - sc.conf.getOption("spark.mesos.executor.memoryOverhead") - .getOrElse(OVERHEAD_MINIMUM.toString) - .toInt + sc.executorMemory, - OVERHEAD_FRACTION * sc.executorMemory - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala new file mode 100644 index 000000000000..e0c547dce6d0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConverters._ + +import org.apache.curator.framework.CuratorFramework +import org.apache.zookeeper.CreateMode +import org.apache.zookeeper.KeeperException.NoNodeException + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.util.Utils + +/** + * Persistence engine factory that is responsible for creating new persistence engines + * to store Mesos cluster mode state. + */ +private[spark] abstract class MesosClusterPersistenceEngineFactory(conf: SparkConf) { + def createEngine(path: String): MesosClusterPersistenceEngine +} + +/** + * Mesos cluster persistence engine is responsible for persisting Mesos cluster mode + * specific state, so that on failover all the state can be recovered and the scheduler + * can resume managing the drivers. + */ +private[spark] trait MesosClusterPersistenceEngine { + def persist(name: String, obj: Object): Unit + def expunge(name: String): Unit + def fetch[T](name: String): Option[T] + def fetchAll[T](): Iterable[T] +} + +/** + * Zookeeper backed persistence engine factory. + * All Zk engines created from this factory shares the same Zookeeper client, so + * all of them reuses the same connection pool. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) + extends MesosClusterPersistenceEngineFactory(conf) { + + lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + + def createEngine(path: String): MesosClusterPersistenceEngine = { + new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) + } +} + +/** + * Black hole persistence engine factory that creates black hole + * persistence engines, which stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngineFactory + extends MesosClusterPersistenceEngineFactory(null) { + def createEngine(path: String): MesosClusterPersistenceEngine = { + new BlackHoleMesosClusterPersistenceEngine + } +} + +/** + * Black hole persistence engine that stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine { + override def persist(name: String, obj: Object): Unit = {} + override def fetch[T](name: String): Option[T] = None + override def expunge(name: String): Unit = {} + override def fetchAll[T](): Iterable[T] = Iterable.empty[T] +} + +/** + * Zookeeper based Mesos cluster persistence engine, that stores cluster mode state + * into Zookeeper. Each engine object is operating under one folder in Zookeeper, but + * reuses a shared Zookeeper client. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngine( + baseDir: String, + zk: CuratorFramework, + conf: SparkConf) + extends MesosClusterPersistenceEngine with Logging { + private val WORKING_DIR = + conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir + + SparkCuratorUtil.mkdir(zk, WORKING_DIR) + + def path(name: String): String = { + WORKING_DIR + "/" + name + } + + override def expunge(name: String): Unit = { + zk.delete().forPath(path(name)) + } + + override def persist(name: String, obj: Object): Unit = { + val serialized = Utils.serialize(obj) + val zkPath = path(name) + zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized) + } + + override def fetch[T](name: String): Option[T] = { + val zkPath = path(name) + + try { + val fileData = zk.getData().forPath(zkPath) + Some(Utils.deserialize[T](fileData)) + } catch { + case e: NoNodeException => None + case e: Exception => { + logWarning("Exception while reading persisted file, deleting", e) + zk.delete().forPath(zkPath) + None + } + } + } + + override def fetchAll[T](): Iterable[T] = { + zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T]) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala new file mode 100644 index 000000000000..07da9242b992 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -0,0 +1,658 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.io.File +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, Date, List => JList} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.Protos.TaskStatus.Reason +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.{Scheduler, SchedulerDriver} + +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} + + +/** + * Tracks the current state of a Mesos Task that runs a Spark driver. + * @param driverDescription Submitted driver description from + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]] + * @param taskId Mesos TaskID generated for the task + * @param slaveId Slave ID that the task is assigned to + * @param mesosTaskStatus The last known task status update. + * @param startDate The date the task was launched + */ +private[spark] class MesosClusterSubmissionState( + val driverDescription: MesosDriverDescription, + val taskId: TaskID, + val slaveId: SlaveID, + var mesosTaskStatus: Option[TaskStatus], + var startDate: Date, + var finishDate: Option[Date]) + extends Serializable { + + def copy(): MesosClusterSubmissionState = { + new MesosClusterSubmissionState( + driverDescription, taskId, slaveId, mesosTaskStatus, startDate, finishDate) + } +} + +/** + * Tracks the retry state of a driver, which includes the next time it should be scheduled + * and necessary information to do exponential backoff. + * This class is not thread-safe, and we expect the caller to handle synchronizing state. + * @param lastFailureStatus Last Task status when it failed. + * @param retries Number of times it has been retried. + * @param nextRetry Time at which it should be retried next + * @param waitTime The amount of time driver is scheduled to wait until next retry. + */ +private[spark] class MesosClusterRetryState( + val lastFailureStatus: TaskStatus, + val retries: Int, + val nextRetry: Date, + val waitTime: Int) extends Serializable { + def copy(): MesosClusterRetryState = + new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime) +} + +/** + * The full state of the cluster scheduler, currently being used for displaying + * information on the UI. + * @param frameworkId Mesos Framework id for the cluster scheduler. + * @param masterUrl The Mesos master url + * @param queuedDrivers All drivers queued to be launched + * @param launchedDrivers All launched or running drivers + * @param finishedDrivers All terminated drivers + * @param pendingRetryDrivers All drivers pending to be retried + */ +private[spark] class MesosClusterSchedulerState( + val frameworkId: String, + val masterUrl: Option[String], + val queuedDrivers: Iterable[MesosDriverDescription], + val launchedDrivers: Iterable[MesosClusterSubmissionState], + val finishedDrivers: Iterable[MesosClusterSubmissionState], + val pendingRetryDrivers: Iterable[MesosDriverDescription]) + +/** + * The full state of a Mesos driver, that is being used to display driver information on the UI. + */ +private[spark] class MesosDriverState( + val state: String, + val description: MesosDriverDescription, + val submissionState: Option[MesosClusterSubmissionState] = None) + +/** + * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode + * as Mesos tasks in a Mesos cluster. + * All drivers are launched asynchronously by the framework, which will eventually be launched + * by one of the slaves in the cluster. The results of the driver will be stored in slave's task + * sandbox which is accessible by visiting the Mesos UI. + * This scheduler supports recovery by persisting all its state and performs task reconciliation + * on recover, which gets all the latest state for all the drivers from Mesos master. + */ +private[spark] class MesosClusterScheduler( + engineFactory: MesosClusterPersistenceEngineFactory, + conf: SparkConf) + extends Scheduler with MesosSchedulerUtils { + var frameworkUrl: String = _ + private val metricsSystem = + MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf)) + private val master = conf.get("spark.master") + private val appName = conf.get("spark.app.name") + private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) + private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) + private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val schedulerState = engineFactory.createEngine("scheduler") + private val stateLock = new ReentrantLock() + private val finishedDrivers = + new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) + private var frameworkId: String = null + // Holds all the launched drivers and current launch state, keyed by driver id. + private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() + // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. + // All drivers that are loaded after failover are added here, as we need get the latest + // state of the tasks from Mesos. + private val pendingRecover = new mutable.HashMap[String, SlaveID]() + // Stores all the submitted drivers that hasn't been launched. + private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() + // All supervised drivers that are waiting to retry after termination. + private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() + private val queuedDriversState = engineFactory.createEngine("driverQueue") + private val launchedDriversState = engineFactory.createEngine("launchedDrivers") + private val pendingRetryDriversState = engineFactory.createEngine("retryList") + // Flag to mark if the scheduler is ready to be called, which is until the scheduler + // is registered with Mesos master. + @volatile protected var ready = false + private var masterInfo: Option[MasterInfo] = None + + def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { + val c = new CreateSubmissionResponse + if (!ready) { + c.success = false + c.message = "Scheduler is not ready to take requests" + return c + } + + stateLock.synchronized { + if (isQueueFull()) { + c.success = false + c.message = "Already reached maximum submission size" + return c + } + c.submissionId = desc.submissionId + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + c.success = true + } + c + } + + def killDriver(submissionId: String): KillSubmissionResponse = { + val k = new KillSubmissionResponse + if (!ready) { + k.success = false + k.message = "Scheduler is not ready to take requests" + return k + } + k.submissionId = submissionId + stateLock.synchronized { + // We look for the requested driver in the following places: + // 1. Check if submission is running or launched. + // 2. Check if it's still queued. + // 3. Check if it's in the retry list. + // 4. Check if it has already completed. + if (launchedDrivers.contains(submissionId)) { + val task = launchedDrivers(submissionId) + mesosDriver.killTask(task.taskId) + k.success = true + k.message = "Killing running driver" + } else if (removeFromQueuedDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's still pending" + } else if (removeFromPendingRetryDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's being retried" + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + k.success = false + k.message = "Driver already terminated" + } else { + k.success = false + k.message = "Cannot find driver" + } + } + k + } + + def getDriverStatus(submissionId: String): SubmissionStatusResponse = { + val s = new SubmissionStatusResponse + if (!ready) { + s.success = false + s.message = "Scheduler is not ready to take requests" + return s + } + s.submissionId = submissionId + stateLock.synchronized { + if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "QUEUED" + } else if (launchedDrivers.contains(submissionId)) { + s.success = true + s.driverState = "RUNNING" + launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "FINISHED" + finishedDrivers + .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus + .foreach(state => s.message = state.toString) + } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { + val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .get.retryState.get.lastFailureStatus + s.success = true + s.driverState = "RETRYING" + s.message = status.toString + } else { + s.success = false + s.driverState = "NOT_FOUND" + } + } + s + } + + /** + * Gets the driver state to be displayed on the Web UI. + */ + def getDriverState(submissionId: String): Option[MesosDriverState] = { + stateLock.synchronized { + queuedDrivers.find(_.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("QUEUED", d)) + .orElse(launchedDrivers.get(submissionId) + .map(d => new MesosDriverState("RUNNING", d.driverDescription, Some(d)))) + .orElse(finishedDrivers.find(_.driverDescription.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("FINISHED", d.driverDescription, Some(d)))) + .orElse(pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("RETRYING", d))) + } + } + + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity + + /** + * Recover scheduler state that is persisted. + * We still need to do task reconciliation to be up to date of the latest task states + * as it might have changed while the scheduler is failing over. + */ + private def recoverState(): Unit = { + stateLock.synchronized { + launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => + launchedDrivers(state.taskId.getValue) = state + pendingRecover(state.taskId.getValue) = state.slaveId + } + queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) + // There is potential timing issue where a queued driver might have been launched + // but the scheduler shuts down before the queued driver was able to be removed + // from the queue. We try to mitigate this issue by walking through all queued drivers + // and remove if they're already launched. + queuedDrivers + .filter(d => launchedDrivers.contains(d.submissionId)) + .foreach(d => removeFromQueuedDrivers(d.submissionId)) + pendingRetryDriversState.fetchAll[MesosDriverDescription]() + .foreach(s => pendingRetryDrivers += s) + // TODO: Consider storing finished drivers so we can show them on the UI after + // failover. For now we clear the history on each recovery. + finishedDrivers.clear() + } + } + + /** + * Starts the cluster scheduler and wait until the scheduler is registered. + * This also marks the scheduler to be ready for requests. + */ + def start(): Unit = { + // TODO: Implement leader election to make sure only one framework running in the cluster. + val fwId = schedulerState.fetch[String]("frameworkId") + fwId.foreach { id => + frameworkId = id + } + recoverState() + metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) + metricsSystem.start() + val driver = createSchedulerDriver( + master, + MesosClusterScheduler.this, + Utils.getCurrentUserName(), + appName, + conf, + Some(frameworkUrl), + Some(true), + Some(Integer.MAX_VALUE), + fwId) + + startScheduler(driver) + ready = true + } + + def stop(): Unit = { + ready = false + metricsSystem.report() + metricsSystem.stop() + mesosDriver.stop(true) + } + + override def registered( + driver: SchedulerDriver, + newFrameworkId: FrameworkID, + masterInfo: MasterInfo): Unit = { + logInfo("Registered as framework ID " + newFrameworkId.getValue) + if (newFrameworkId.getValue != frameworkId) { + frameworkId = newFrameworkId.getValue + schedulerState.persist("frameworkId", frameworkId) + } + markRegistered() + + stateLock.synchronized { + this.masterInfo = Some(masterInfo) + if (!pendingRecover.isEmpty) { + // Start task reconciliation if we need to recover. + val statuses = pendingRecover.collect { + case (taskId, slaveId) => + val newStatus = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(slaveId) + .setState(MesosTaskState.TASK_STAGING) + .build() + launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + .getOrElse(newStatus) + } + // TODO: Page the status updates to avoid trying to reconcile + // a large amount of tasks at once. + driver.reconcileTasks(statuses.toSeq.asJava) + } + } + } + + private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = { + val appJar = CommandInfo.URI.newBuilder() + .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() + val builder = CommandInfo.newBuilder().addUris(appJar) + val entries = + (conf.getOption("spark.executor.extraLibraryPath").toList ++ + desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { + Utils.libraryPathEnvPrefix(entries) + } else { + "" + } + val envBuilder = Environment.newBuilder() + desc.command.environment.foreach { case (k, v) => + envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build()) + } + // Pass all spark properties to executor. + val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") + envBuilder.addVariables( + Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) + val cmdOptions = generateCmdOption(desc).mkString(" ") + val dockerDefined = desc.schedulerProperties.contains("spark.mesos.executor.docker.image") + val executorUri = desc.schedulerProperties.get("spark.executor.uri") + .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + val appArguments = desc.command.arguments.mkString(" ") + val (executable, jar) = if (dockerDefined) { + // Application jar is automatically downloaded in the mounted sandbox by Mesos, + // and the path to the mounted volume is stored in $MESOS_SANDBOX env variable. + ("./bin/spark-submit", s"$$MESOS_SANDBOX/${desc.jarUrl.split("/").last}") + } else if (executorUri.isDefined) { + builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) + val folderBasename = executorUri.get.split('/').last.split('.').head + val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" + val cmdJar = s"../${desc.jarUrl.split("/").last}" + (cmdExecutable, cmdJar) + } else { + val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") + .orElse(conf.getOption("spark.home")) + .orElse(Option(System.getenv("SPARK_HOME"))) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdJar = desc.jarUrl.split("/").last + (cmdExecutable, cmdJar) + } + builder.setValue(s"$executable $cmdOptions $jar $appArguments") + builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + builder.build() + } + + private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + var options = Seq( + "--name", desc.schedulerProperties("spark.app.name"), + "--class", desc.command.mainClass, + "--master", s"mesos://${conf.get("spark.master")}", + "--driver-cores", desc.cores.toString, + "--driver-memory", s"${desc.mem}M") + desc.schedulerProperties.get("spark.executor.memory").map { v => + options ++= Seq("--executor-memory", v) + } + desc.schedulerProperties.get("spark.cores.max").map { v => + options ++= Seq("--total-executor-cores", v) + } + options + } + + private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + override def toString(): String = { + s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + } + } + + /** + * This method takes all the possible candidates and attempt to schedule them with Mesos offers. + * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled + * logic on each task. + */ + private def scheduleTasks( + candidates: Seq[MesosDriverDescription], + afterLaunchCallback: (String) => Boolean, + currentOffers: List[ResourceOffer], + tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = { + for (submission <- candidates) { + val driverCpu = submission.cores + val driverMem = submission.mem + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val offerOption = currentOffers.find { o => + o.cpu >= driverCpu && o.mem >= driverMem + } + if (offerOption.isEmpty) { + logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + + s"cpu: $driverCpu, mem: $driverMem") + } else { + val offer = offerOption.get + offer.cpu -= driverCpu + offer.mem -= driverMem + val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() + val cpuResource = createResource("cpus", driverCpu) + val memResource = createResource("mem", driverMem) + val commandInfo = buildDriverCommand(submission) + val appName = submission.schedulerProperties("spark.app.name") + val taskInfo = TaskInfo.newBuilder() + .setTaskId(taskId) + .setName(s"Driver for $appName") + .setSlaveId(offer.offer.getSlaveId) + .setCommand(commandInfo) + .addResources(cpuResource) + .addResources(memResource) + submission.schedulerProperties.get("spark.mesos.executor.docker.image").foreach { image => + val container = taskInfo.getContainerBuilder() + val volumes = submission.schedulerProperties + .get("spark.mesos.executor.docker.volumes") + .map(MesosSchedulerBackendUtil.parseVolumesSpec) + val portmaps = submission.schedulerProperties + .get("spark.mesos.executor.docker.portmaps") + .map(MesosSchedulerBackendUtil.parsePortMappingsSpec) + MesosSchedulerBackendUtil.addDockerInfo( + container, image, volumes = volumes, portmaps = portmaps) + taskInfo.setContainer(container.build()) + } + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + queuedTasks += taskInfo.build() + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + submission.submissionId) + val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + None, new Date(), None) + launchedDrivers(submission.submissionId) = newState + launchedDriversState.persist(submission.submissionId, newState) + afterLaunchCallback(submission.submissionId) + } + } + } + + override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { + val currentOffers = offers.asScala.map(o => + new ResourceOffer( + o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) + ).toList + logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() + val currentTime = new Date() + + stateLock.synchronized { + // We first schedule all the supervised drivers that are ready to retry. + // This list will be empty if none of the drivers are marked as supervise. + val driversToRetry = pendingRetryDrivers.filter { d => + d.retryState.get.nextRetry.before(currentTime) + } + + scheduleTasks( + copyBuffer(driversToRetry), + removeFromPendingRetryDrivers, + currentOffers, + tasks) + + // Then we walk through the queued drivers and try to schedule them. + scheduleTasks( + copyBuffer(queuedDrivers), + removeFromQueuedDrivers, + currentOffers, + tasks) + } + tasks.foreach { case (offerId, taskInfos) => + driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) + } + offers.asScala + .filter(o => !tasks.keySet.contains(o.getId)) + .foreach(o => driver.declineOffer(o.getId)) + } + + private def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + + def getSchedulerState(): MesosClusterSchedulerState = { + stateLock.synchronized { + new MesosClusterSchedulerState( + frameworkId, + masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"), + copyBuffer(queuedDrivers), + launchedDrivers.values.map(_.copy()).toList, + finishedDrivers.map(_.copy()).toList, + copyBuffer(pendingRetryDrivers)) + } + } + + override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {} + override def disconnected(driver: SchedulerDriver): Unit = {} + override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { + logInfo(s"Framework re-registered with master ${masterInfo.getId}") + } + override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def error(driver: SchedulerDriver, error: String): Unit = { + logError("Error received: " + error) + } + + /** + * Check if the task state is a recoverable state that we can relaunch the task. + * Task state like TASK_ERROR are not relaunchable state since it wasn't able + * to be validated by Mesos. + */ + private def shouldRelaunch(state: MesosTaskState): Boolean = { + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { + val taskId = status.getTaskId.getValue + stateLock.synchronized { + if (launchedDrivers.contains(taskId)) { + if (status.getReason == Reason.REASON_RECONCILIATION && + !pendingRecover.contains(taskId)) { + // Task has already received update and no longer requires reconciliation. + return + } + val state = launchedDrivers(taskId) + // Check if the driver is supervise enabled and can be relaunched. + if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { + removeFromLaunchedDrivers(taskId) + state.finishDate = Some(new Date()) + val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState + val (retries, waitTimeSec) = retryState + .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } + .getOrElse{ (1, 1) } + val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) + + val newDriverDescription = state.driverDescription.copy( + retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) + pendingRetryDrivers += newDriverDescription + pendingRetryDriversState.persist(taskId, newDriverDescription) + } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { + removeFromLaunchedDrivers(taskId) + state.finishDate = Some(new Date()) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + state.mesosTaskStatus = Option(status) + } else { + logError(s"Unable to find driver $taskId in status update") + } + } + } + + override def frameworkMessage( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + message: Array[Byte]): Unit = {} + + override def executorLost( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + status: Int): Unit = {} + + private def removeFromQueuedDrivers(id: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + queuedDrivers.remove(index) + queuedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromLaunchedDrivers(id: String): Boolean = { + if (launchedDrivers.remove(id).isDefined) { + launchedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromPendingRetryDrivers(id: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + pendingRetryDrivers.remove(index) + pendingRetryDriversState.expunge(id) + true + } else { + false + } + } + + def getQueuedDriversSize: Int = queuedDrivers.size + def getLaunchedDriversSize: Int = launchedDrivers.size + def getPendingRetryDriversSize: Int = pendingRetryDrivers.size +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala new file mode 100644 index 000000000000..1fe94974c8e3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) + extends Source { + override def sourceName: String = "mesos_cluster" + override def metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getQueuedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("launchedDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getLaunchedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("retryDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getPendingRetryDriversSize + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c3c546be6da1..2e424054be78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -18,24 +18,21 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections +import java.util.{ArrayList => JArrayList, Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, - ExecutorInfo => MesosExecutorInfo, _} - +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils + /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks @@ -47,17 +44,10 @@ private[spark] class MesosSchedulerBackend( master: String) extends SchedulerBackend with MScheduler - with Logging { + with MesosSchedulerUtils { - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // Stores the slave ids that has launched a Mesos executor. + val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo] val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks @@ -68,37 +58,35 @@ private[spark] class MesosSchedulerBackend( // The listener bus to publish executor added/removed events. val listenerBus = sc.listenerBus + private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + classLoader = Thread.currentThread.getContextClassLoader + val driver = createSchedulerDriver( + master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createExecutorInfo(execId: String): MesosExecutorInfo = { + /** + * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * @param availableResources Available resources that is offered by Mesos + * @param execId The executor id to assign to this new executor. + * @return A tuple of the new mesos executor info and the remaining available resources. + */ + def createExecutorInfo( + availableResources: JList[Resource], + execId: String): (MesosExecutorInfo, JList[Resource]) = { val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -123,38 +111,42 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.get("spark.executor.uri", null) + val uri = sc.conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val executorBackendName = classOf[MesosExecutorBackend].getName - if (uri == null) { + if (uri.isEmpty) { val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } - val cpus = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder() - .setValue(scheduler.CPUS_PER_TASK).build()) - .build() - val memory = Resource.newBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar( - Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) - .build() - MesosExecutorInfo.newBuilder() + val builder = MesosExecutorInfo.newBuilder() + val (resourcesAfterCpu, usedCpuResources) = + partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + val (resourcesAfterMem, usedMemResources) = + partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) + + builder.addAllResources(usedCpuResources.asJava) + builder.addAllResources(usedMemResources.asJava) + + sc.conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) + + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - .addResources(cpus) - .addResources(memory) - .build() + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) + } + + (executorInfo.build(), resourcesAfterMem.asJava) } /** @@ -164,7 +156,7 @@ private[spark] class MesosSchedulerBackend( private def createExecArg(): Array[Byte] = { if (execArgs == null) { val props = new HashMap[String, String] - for ((key,value) <- sc.conf.getAll) { + for ((key, value) <- sc.conf.getAll) { props(key) = value } // Serialize the map as an array of (String, String) pairs @@ -179,18 +171,7 @@ private[spark] class MesosSchedulerBackend( inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } + markRegistered() } } @@ -208,6 +189,18 @@ private[spark] class MesosSchedulerBackend( override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { + val builder = new StringBuilder + tasks.asScala.foreach { t => + builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") + .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") + .append("Task resources: ").append(t.getResourcesList).append("\n") + .append("Executor resources: ").append(t.getExecutor.getResourcesList) + .append("---------------------------------------------\n") + } + builder.toString() + } + /** * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that @@ -216,26 +209,42 @@ private[spark] class MesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.partition { o => + val (usableOffers, unUsableOffers) = offers.asScala.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => - val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { - // If the executor doesn't exist yet, subtract CPU for executor - // TODO(pwendell): Should below just subtract "1"? - getResource(o.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK + // If the Mesos executor has not been started on this slave yet, set aside a few + // cores for the Mesos executor by offering fewer cores to the Spark executor + (getResource(o.getResourcesList, "cpus") - mesosExecutorCores).toInt } new WorkerOffer( o.getSlaveId.getValue, @@ -245,6 +254,10 @@ private[spark] class MesosSchedulerBackend( val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap + val slaveIdToResources = new HashMap[String, JList[Resource]]() + usableOffers.foreach { o => + slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList + } val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] @@ -256,11 +269,15 @@ private[spark] class MesosSchedulerBackend( .foreach { offer => offer.foreach { taskDesc => val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId slavesIdsOfAcceptedOffers += slaveId taskIdToSlaveId(taskDesc.taskId) = slaveId + val (mesosTask, remainingResources) = createMesosTask( + taskDesc, + slaveIdToResources(slaveId), + slaveId) mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) + .add(mesosTask) + slaveIdToResources(slaveId) = remainingResources } } @@ -270,8 +287,10 @@ private[spark] class MesosSchedulerBackend( mesosTasks.foreach { case (slaveId, tasks) => slaveIdToWorkerOffer.get(slaveId).foreach(o => listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, - new ExecutorInfo(o.host, o.cores))) + // TODO: Add support for log urls for Mesos + new ExecutorInfo(o.host, o.cores, Map.empty))) ) + logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}") d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -280,44 +299,32 @@ private[spark] class MesosSchedulerBackend( for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { d.declineOffer(o.getId) } - - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) - } - } - - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue } - 0 } - /** Turn a Spark TaskDescription into a Mesos task */ - def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { + /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ + def createMesosTask( + task: TaskDescription, + resources: JList[Resource], + slaveId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val cpuResource = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build()) - .build() - MesosTaskInfo.newBuilder() + val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) { + (slaveIdToExecutorInfo(slaveId), resources) + } else { + createExecutorInfo(resources, slaveId) + } + slaveIdToExecutorInfo(slaveId) = executorInfo + val (finalResources, cpuResources) = + partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK) + val taskInfo = MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(createExecutorInfo(slaveId)) + .setExecutor(executorInfo) .setName(task.name) - .addResources(cpuResource) + .addAllResources(cpuResources.asJava) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() - } - - /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST + (taskInfo, finalResources.asJava) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { @@ -325,11 +332,12 @@ private[spark] class MesosSchedulerBackend( val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { - if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + if (TaskState.isFailed(TaskState.fromMesos(status.getState)) + && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone removeExecutor(taskIdToSlaveId(tid), "Lost executor") } - if (isFinished(status.getState)) { + if (TaskState.isFinished(state)) { taskIdToSlaveId.remove(tid) } } @@ -345,13 +353,13 @@ private[spark] class MesosSchedulerBackend( } override def stop() { - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def reviveOffers() { - driver.reviveOffers() + mesosDriver.reviveOffers() } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} @@ -362,7 +370,7 @@ private[spark] class MesosSchedulerBackend( private def removeExecutor(slaveId: String, reason: String) = { synchronized { listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) - slaveIdsWithExecutors -= slaveId + slaveIdToExecutorInfo -= slaveId } } @@ -386,14 +394,14 @@ private[spark] class MesosSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - driver.killTask( + mesosDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def defaultParallelism(): Int = sc.conf.getInt("spark.default.parallelism", 8) override def applicationId(): String = Option(appId).getOrElse { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala new file mode 100644 index 000000000000..e79c543a9de2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.{ContainerInfo, Volume} +import org.apache.mesos.Protos.ContainerInfo.DockerInfo + +import org.apache.spark.{Logging, SparkConf} + +/** + * A collection of utility functions which can be used by both the + * MesosSchedulerBackend and the CoarseMesosSchedulerBackend. + */ +private[mesos] object MesosSchedulerBackendUtil extends Logging { + /** + * Parse a comma-delimited list of volume specs, each of which + * takes the form [host-dir:]container-dir[:rw|:ro]. + */ + def parseVolumesSpec(volumes: String): List[Volume] = { + volumes.split(",").map(_.split(":")).flatMap { spec => + val vol: Volume.Builder = Volume + .newBuilder() + .setMode(Volume.Mode.RW) + spec match { + case Array(container_path) => + Some(vol.setContainerPath(container_path)) + case Array(container_path, "rw") => + Some(vol.setContainerPath(container_path)) + case Array(container_path, "ro") => + Some(vol.setContainerPath(container_path) + .setMode(Volume.Mode.RO)) + case Array(host_path, container_path) => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path)) + case Array(host_path, container_path, "rw") => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path)) + case Array(host_path, container_path, "ro") => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path) + .setMode(Volume.Mode.RO)) + case spec => { + logWarning(s"Unable to parse volume specs: $volumes. " + + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") + None + } + } + } + .map { _.build() } + .toList + } + + /** + * Parse a comma-delimited list of port mapping specs, each of which + * takes the form host_port:container_port[:udp|:tcp] + * + * Note: + * the docker form is [ip:]host_port:container_port, but the DockerInfo + * message has no field for 'ip', and instead has a 'protocol' field. + * Docker itself only appears to support TCP, so this alternative form + * anticipates the expansion of the docker form to allow for a protocol + * and leaves open the chance for mesos to begin to accept an 'ip' field + */ + def parsePortMappingsSpec(portmaps: String): List[DockerInfo.PortMapping] = { + portmaps.split(",").map(_.split(":")).flatMap { spec: Array[String] => + val portmap: DockerInfo.PortMapping.Builder = DockerInfo.PortMapping + .newBuilder() + .setProtocol("tcp") + spec match { + case Array(host_port, container_port) => + Some(portmap.setHostPort(host_port.toInt) + .setContainerPort(container_port.toInt)) + case Array(host_port, container_port, protocol) => + Some(portmap.setHostPort(host_port.toInt) + .setContainerPort(container_port.toInt) + .setProtocol(protocol)) + case spec => { + logWarning(s"Unable to parse port mapping specs: $portmaps. " + + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") + None + } + } + } + .map { _.build() } + .toList + } + + /** + * Construct a DockerInfo structure and insert it into a ContainerInfo + */ + def addDockerInfo( + container: ContainerInfo.Builder, + image: String, + volumes: Option[List[Volume]] = None, + network: Option[ContainerInfo.DockerInfo.Network] = None, + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { + + val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) + + network.foreach(docker.setNetwork) + portmaps.foreach(_.foreach(docker.addPortMappings)) + container.setType(ContainerInfo.Type.DOCKER) + container.setDocker(docker.build()) + volumes.foreach(_.foreach(container.addVolumes)) + } + + /** + * Setup a docker containerizer + */ + def setupContainerBuilderDockerInfo( + imageName: String, + conf: SparkConf, + builder: ContainerInfo.Builder): Unit = { + val volumes = conf + .getOption("spark.mesos.executor.docker.volumes") + .map(parseVolumesSpec) + val portmaps = conf + .getOption("spark.mesos.executor.docker.portmaps") + .map(parsePortMappingsSpec) + addDockerInfo( + builder, + imageName, + volumes = volumes, + portmaps = portmaps) + logDebug("setupContainerDockerInfo: using docker image: " + imageName) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala new file mode 100644 index 000000000000..860c8e097b3b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.{List => JList} +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} +import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} +import org.apache.spark.util.Utils + + +/** + * Shared trait for implementing a Mesos Scheduler. This holds common state and helper + * methods and Mesos scheduler will use. + */ +private[mesos] trait MesosSchedulerUtils extends Logging { + // Lock used to wait for scheduler to be registered + private final val registerLatch = new CountDownLatch(1) + + // Driver for talking to Mesos + protected var mesosDriver: SchedulerDriver = null + + /** + * Creates a new MesosSchedulerDriver that communicates to the Mesos master. + * @param masterUrl The url to connect to Mesos master + * @param scheduler the scheduler class to receive scheduler callbacks + * @param sparkUser User to impersonate with when running tasks + * @param appName The framework name to display on the Mesos UI + * @param conf Spark configuration + * @param webuiUrl The WebUI url to link from Mesos UI + * @param checkpoint Option to checkpoint tasks for failover + * @param failoverTimeout Duration Mesos master expect scheduler to reconnect on disconnect + * @param frameworkId The id of the new framework + */ + protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { + val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) + val credBuilder = Credential.newBuilder() + webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } + checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } + failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } + frameworkId.foreach { id => + fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) + } + conf.getOption("spark.mesos.principal").foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret").foreach { secret => + credBuilder.setSecret(ByteString.copyFromUtf8(secret)) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + conf.getOption("spark.mesos.role").foreach { role => + fwInfoBuilder.setRole(role) + } + if (credBuilder.hasPrincipal) { + new MesosSchedulerDriver( + scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) + } else { + new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl) + } + } + + /** + * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. + * This driver is expected to not be running. + * This method returns only after the scheduler has registered with Mesos. + */ + def startScheduler(newDriver: SchedulerDriver): Unit = { + synchronized { + if (mesosDriver != null) { + registerLatch.await() + return + } + + new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { + setDaemon(true) + + override def run() { + mesosDriver = newDriver + try { + val ret = mesosDriver.run() + logInfo("driver.run() returned with code " + ret) + if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { + System.exit(1) + } + } catch { + case e: Exception => { + logError("driver.run() failed", e) + System.exit(1) + } + } + } + }.start() + + registerLatch.await() + } + } + + /** + * Signal that the scheduler has registered with Mesos. + */ + protected def getResource(res: JList[Resource], name: String): Double = { + // A resource can have multiple values in the offer since it can either be from + // a specific role or wildcard. + res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum + } + + protected def markRegistered(): Unit = { + registerLatch.countDown() + } + + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { + val builder = Resource.newBuilder() + .setName(name) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(amount).build()) + + role.foreach { r => builder.setRole(r) } + + builder.build() + } + + /** + * Partition the existing set of resources into two groups, those remaining to be + * scheduled and those requested to be used for a new task. + * @param resources The full list of available resources + * @param resourceName The name of the resource to take from the available resources + * @param amountToUse The amount of resources to take from the available resources + * @return The remaining resources list and the used resources list. + */ + def partitionResources( + resources: JList[Resource], + resourceName: String, + amountToUse: Double): (List[Resource], List[Resource]) = { + var remain = amountToUse + var requestedResources = new ArrayBuffer[Resource] + val remainingResources = resources.asScala.map { + case r => { + if (remain > 0 && + r.getType == Value.Type.SCALAR && + r.getScalar.getValue > 0.0 && + r.getName == resourceName) { + val usage = Math.min(remain, r.getScalar.getValue) + requestedResources += createResource(resourceName, usage, Some(r.getRole)) + remain -= usage + createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole)) + } else { + r + } + } + } + + // Filter any resource that has depleted. + val filteredResources = + remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0) + + (filteredResources.toList, requestedResources.toList) + } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.asScala.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. + * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.asScala.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + + /** + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. + */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + splitter.split(constraintsVal).asScala.toMap.mapValues(v => + if (v == null || v.isEmpty) { + Set[String]() + } else { + v.split(',').toSet + } + ) + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + + def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 05b6fa54564b..4d48fcfea44e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -17,15 +17,16 @@ package org.apache.spark.scheduler.local +import java.io.File +import java.net.URL import java.nio.ByteBuffer -import akka.actor.{Actor, ActorRef, Props} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() @@ -36,25 +37,27 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. */ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, + userClassPath: Seq[URL], scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { private var freeCores = totalCores - private val localExecutorId = SparkContext.DRIVER_IDENTIFIER - private val localExecutorHostname = "localhost" + val localExecutorId = SparkContext.DRIVER_IDENTIFIER + val localExecutorHostname = "localhost" private val executor = new Executor( - localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) + localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) - override def receiveWithLogging = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -67,9 +70,12 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopExecutor => executor.stop() + context.reply(true) } def reviveOffers() { @@ -87,35 +93,54 @@ private[spark] class LocalActor( * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ -private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + val totalCores: Int) + extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + private var localEndpoint: RpcEndpointRef = null + private val userClassPath = getUserClasspath(conf) + private val listenerBus = scheduler.sc.listenerBus + + /** + * Returns a list of URLs representing the user classpath. + * + * @param conf Spark configuration. + */ + def getUserClasspath(conf: SparkConf): Seq[URL] = { + val userClassPathStr = conf.getOption("spark.executor.extraClassPath") + userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) + } override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + val rpcEnv = SparkEnv.get.rpcEnv + val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) + localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + listenerBus.post(SparkListenerExecutorAdded( + System.currentTimeMillis, + executorEndpoint.localExecutorId, + new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) } override def stop() { - localActor ! StopExecutor + localEndpoint.ask(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } - override def defaultParallelism() = + override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala new file mode 100644 index 000000000000..62f8aae7f212 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.io.CompressionCodec + +/** + * Custom serializer used for generic Avro records. If the user registers the schemas + * ahead of time, then the schema's fingerprint will be sent with each message instead of the actual + * schema, as to reduce network IO. + * Actions like parsing or compressing schemas are computationally expensive so the serializer + * caches all previously seen values as to reduce the amount of work needed to do. + * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the + * string representation of the Avro schema, used to decrease the amount of data + * that needs to be serialized. + */ +private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) + extends KSerializer[GenericRecord] { + + /** Used to reduce the amount of effort to compress the schema */ + private val compressCache = new mutable.HashMap[Schema, Array[Byte]]() + private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]() + + /** Reuses the same datum reader/writer since the same schema will be used many times */ + private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]() + private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]() + + /** Fingerprinting is very expensive so this alleviates most of the work */ + private val fingerprintCache = new mutable.HashMap[Schema, Long]() + private val schemaCache = new mutable.HashMap[Long, Schema]() + + // GenericAvroSerializer can't take a SparkConf in the constructor b/c then it would become + // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make + // the codec lazy here just b/c in some unit tests, we use a KryoSerializer w/out having + // the SparkEnv set (note those tests would fail if they tried to serialize avro data). + private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + + /** + * Used to compress Schemas when they are being sent over the wire. + * The compression results are memoized to reduce the compression time since the + * same schema is compressed many times over + */ + def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { + val bos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(bos) + out.write(schema.toString.getBytes("UTF-8")) + out.close() + bos.toByteArray + }) + + /** + * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already + * seen values so to limit the number of times that decompression has to be done. + */ + def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { + val bis = new ByteArrayInputStream(schemaBytes.array()) + val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + new Schema.Parser().parse(new String(bytes, "UTF-8")) + }) + + /** + * Serializes a record to the given output stream. It caches a lot of the internal data as + * to not redo work + */ + def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = { + val encoder = EncoderFactory.get.binaryEncoder(output, null) + val schema = datum.getSchema + val fingerprint = fingerprintCache.getOrElseUpdate(schema, { + SchemaNormalization.parsingFingerprint64(schema) + }) + schemas.get(fingerprint) match { + case Some(_) => + output.writeBoolean(true) + output.writeLong(fingerprint) + case None => + output.writeBoolean(false) + val compressedSchema = compress(schema) + output.writeInt(compressedSchema.length) + output.writeBytes(compressedSchema) + } + + writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) + .asInstanceOf[DatumWriter[R]] + .write(datum, encoder) + encoder.flush() + } + + /** + * Deserializes generic records into their in-memory form. There is internal + * state to keep a cache of already seen schemas and datum readers. + */ + def deserializeDatum(input: KryoInput): GenericRecord = { + val schema = { + if (input.readBoolean()) { + val fingerprint = input.readLong() + schemaCache.getOrElseUpdate(fingerprint, { + schemas.get(fingerprint) match { + case Some(s) => new Schema.Parser().parse(s) + case None => + throw new SparkException( + "Error reading attempting to read avro data -- encountered an unknown " + + s"fingerprint: $fingerprint, not sure what schema to use. This could happen " + + "if you registered additional schemas after starting your spark context.") + } + }) + } else { + val length = input.readInt() + decompress(ByteBuffer.wrap(input.readBytes(length))) + } + } + val decoder = DecoderFactory.get.directBinaryDecoder(input, null) + readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema)) + .asInstanceOf[DatumReader[GenericRecord]] + .read(null, decoder) + } + + override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit = + serializeDatum(datum, output) + + override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord = + deserializeDatum(input) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 1baa0e009f3a..4a5274b46b7a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -59,10 +59,14 @@ private[spark] class JavaSerializationStream( } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { + extends DeserializationStream { + private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] @@ -120,6 +124,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) + protected def this() = this(new SparkConf()) // For deserialization only + override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d56e23ce4478..048a93850727 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,25 +17,30 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.annotation.Nullable + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} +import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} +import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer -import scala.reflect.ClassTag - /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * @@ -48,26 +53,31 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = - (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") + } + private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt + + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt + if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + + s"2048 mb, got: + $maxBufferSizeMb mb.") + } + private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt - private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val userRegistrator = conf.getOption("spark.kryo.registrator") private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) - .map { className => - try { - Class.forName(className) - } catch { - case e: Exception => - throw new SparkException("Failed to load class to register with Kryo", e) - } - } - def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + private val avroSchemas = conf.getAvroSchema + + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator @@ -88,20 +98,28 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) + kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) + try { + // scalastyle:off classforname // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. - classesToRegister.foreach { clazz => kryo.register(clazz) } + classesToRegister + .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } + // scalastyle:on classforname } catch { case e: Exception => throw new SparkException(s"Failed to register classes with Kryo", e) @@ -120,24 +138,55 @@ class KryoSerializer(conf: SparkConf) override def newInstance(): SerializerInstance = { new KryoSerializerInstance(this) } + + private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = { + // If auto-reset is disabled, then Kryo may store references to duplicate occurrences of objects + // in the stream rather than writing those objects' serialized bytes, breaking relocation. See + // https://groups.google.com/d/msg/kryo-users/6ZUSyfjjtdo/FhGG1KHDXPgJ for more details. + newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset() + } } private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) +class KryoSerializationStream( + serInstance: KryoSerializerInstance, + outStream: OutputStream) extends SerializationStream { + + private[this] var output: KryoOutput = new KryoOutput(outStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - override def flush() { output.flush() } - override def close() { output.close() } + override def flush() { + if (output == null) { + throw new IOException("Stream is closed") + } + output.flush() + } + + override def close() { + if (output != null) { + try { + output.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + output = null + } + } + } } private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - private val input = new KryoInput(inStream) +class KryoDeserializationStream( + serInstance: KryoSerializerInstance, + inStream: InputStream) extends DeserializationStream { + + private[this] var input: KryoInput = new KryoInput(inStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { try { @@ -150,44 +199,120 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } override def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() + if (input != null) { + try { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + input = null + } + } } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - private val kryo = ks.newKryo() - // Make these lazy vals to avoid creating a buffer unless we use them + /** + * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do + * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching + * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are + * not synchronized. + */ + @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + + /** + * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; + * otherwise, it allocates a new instance. + */ + private[serializer] def borrowKryo(): Kryo = { + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have been modified + // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } + } + + /** + * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo + * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later + * re-use. + */ + private[serializer] def releaseKryo(kryo: Kryo): Unit = { + if (cachedKryo == null) { + cachedKryo = kryo + } + } + + // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() private lazy val input = new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() - kryo.writeClassAndObject(output, t) + val kryo = borrowKryo() + try { + kryo.writeClassAndObject(output, t) + } catch { + case e: KryoException if e.getMessage.startsWith("Buffer overflow") => + throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + + "increase spark.kryoserializer.buffer.max value.") + } finally { + releaseKryo(kryo) + } ByteBuffer.wrap(output.toBytes) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] + val kryo = borrowKryo() + try { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + releaseKryo(kryo) + } } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj + try { + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + kryo.setClassLoader(oldClassLoader) + releaseKryo(kryo) + } } override def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) + new KryoSerializationStream(this, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) + new KryoDeserializationStream(this, s) + } + + /** + * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied + * registrator explicitly turns auto-reset off. + */ + def getAutoReset(): Boolean = { + val field = classOf[Kryo].getDeclaredField("autoReset") + field.setAccessible(true) + val kryo = borrowKryo() + try { + field.get(kryo).asInstanceOf[Boolean] + } finally { + releaseKryo(kryo) + } } } @@ -209,9 +334,17 @@ private[serializer] object KryoSerializer { classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], + classOf[RoaringBitmap], + classOf[RoaringArray], + classOf[RoaringArray.Element], + classOf[Array[RoaringArray.Element]], + classOf[ArrayContainer], + classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], + classOf[Array[Short]], + classOf[Array[Long]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) @@ -241,16 +374,15 @@ private class JavaIterableWrapperSerializer override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) : java.lang.Iterable[_] = { kryo.readClassAndObject(in) match { - case scalaIterable: Iterable[_] => - scala.collection.JavaConversions.asJavaIterable(scalaIterable) - case javaIterable: java.lang.Iterable[_] => - javaIterable + case scalaIterable: Iterable[_] => scalaIterable.asJava + case javaIterable: java.lang.Iterable[_] => javaIterable } } } private object JavaIterableWrapperSerializer extends Logging { - // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + // The class returned by JavaConverters.asJava + // (scala.collection.convert.Wrappers$IterableWrapper). val wrapperClass = scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb99257965..a1b1e1631eaf 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -17,16 +17,17 @@ package org.apache.spark.serializer -import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.io._ import java.lang.reflect.{Field, Method} import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging -private[serializer] object SerializationDebugger extends Logging { +private[spark] object SerializationDebugger extends Logging { /** * Improve the given NotSerializableException with the serialization path leading from the given @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } @@ -54,7 +62,7 @@ private[serializer] object SerializationDebugger extends Logging { * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ - def find(obj: Any): List[String] = { + private[serializer] def find(obj: Any): List[String] = { new SerializationDebugger().visit(obj, List.empty) } @@ -117,6 +125,12 @@ private[serializer] object SerializationDebugger extends Logging { return List.empty } + /** + * Visit an externalizable object. + * Since writeExternal() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutput that collects all the relevant objects for further testing. + */ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = { val fieldList = new ListObjectOutput @@ -137,17 +151,50 @@ private[serializer] object SerializationDebugger extends Logging { // An object contains multiple slots in serialization. // Get the slots and visit fields in all of them. val (finalObj, desc) = findObjectAndDescriptor(o) + + // If the object has been replaced using writeReplace(), + // then call visit() on it again to test its type again. + if (!finalObj.eq(o)) { + return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack) + } + + // Every class is associated with one or more "slots", each slot refers to the parent + // classes of this class. These slots are used by the ObjectOutputStream + // serialization code to recursively serialize the fields of an object and + // its parent classes. For example, if there are the following classes. + // + // class ParentClass(parentField: Int) + // class ChildClass(childField: Int) extends ParentClass(1) + // + // Then serializing the an object Obj of type ChildClass requires first serializing the fields + // of ParentClass (that is, parentField), and then serializing the fields of ChildClass + // (that is, childField). Correspondingly, there will be two slots related to this object: + // + // 1. ParentClass slot, which will be used to serialize parentField of Obj + // 2. ChildClass slot, which will be used to serialize childField fields of Obj + // + // The following code uses the description of each slot to find the fields in the + // corresponding object to visit. + // val slotDescs = desc.getSlotDescs var i = 0 while (i < slotDescs.length) { val slotDesc = slotDescs(i) if (slotDesc.hasWriteObjectMethod) { - // TODO: Handle classes that specify writeObject method. + // If the class type corresponding to current slot has writeObject() defined, + // then its not obvious which fields of the class will be serialized as the writeObject() + // can choose arbitrary fields for serialization. This case is handled separately. + val elem = s"writeObject data (class: ${slotDesc.getName})" + val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack) + if (childStack.nonEmpty) { + return childStack + } } else { + // Visit all the fields objects of the class corresponding to the current slot. val fields: Array[ObjectStreamField] = slotDesc.getFields val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) val numPrims = fields.length - objFieldValues.length - desc.getObjFieldValues(finalObj, objFieldValues) + slotDesc.getObjFieldValues(finalObj, objFieldValues) var j = 0 while (j < objFieldValues.length) { @@ -161,18 +208,54 @@ private[serializer] object SerializationDebugger extends Logging { } j += 1 } - } i += 1 } return List.empty } + + /** + * Visit a serializable object which has the writeObject() defined. + * Since writeObject() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutputStream that collects all the relevant fields for further testing. + * This is similar to how externalizable objects are visited. + */ + private def visitSerializableWithWriteObjectMethod( + o: Object, stack: List[String]): List[String] = { + val innerObjectsCatcher = new ListObjectOutputStream + var notSerializableFound = false + try { + innerObjectsCatcher.writeObject(o) + } catch { + case io: IOException => + notSerializableFound = true + } + + // If something was not serializable, then visit the captured objects. + // Otherwise, all the captured objects are safely serializable, so no need to visit them. + // As an optimization, just added them to the visited list. + if (notSerializableFound) { + val innerObjects = innerObjectsCatcher.outputArray + var k = 0 + while (k < innerObjects.length) { + val childStack = visit(innerObjects(k), stack) + if (childStack.nonEmpty) { + return childStack + } + k += 1 + } + } else { + visited ++= innerObjectsCatcher.outputArray + } + return List.empty + } } /** * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles * writeReplace in Serializable. It starts with the object itself, and keeps calling the - * writeReplace method until there is no more + * writeReplace method until there is no more. */ @tailrec private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { @@ -212,6 +295,31 @@ private[serializer] object SerializationDebugger extends Logging { override def writeByte(i: Int): Unit = {} } + /** An output stream that emulates /dev/null */ + private class NullOutputStream extends OutputStream { + override def write(b: Int) { } + } + + /** + * A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns + * them through `outputArray`. This works by using the [[ObjectOutputStream]]'s `replaceObject()` + * method which gets called on every object, only if replacing is enabled. So this subclass + * of [[ObjectOutputStream]] enabled replacing, and uses replaceObject to get the objects that + * are being serializabled. The serialized bytes are ignored by sending them to a + * [[NullOutputStream]], which acts like a /dev/null. + */ + private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) { + private val output = new mutable.ArrayBuffer[Any] + this.enableReplaceObject(true) + + def outputArray: Array[Any] = output.toArray + + override def replaceObject(obj: Object): Object = { + output += obj + obj + } + } + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { def getSlotDescs: Array[ObjectStreamClass] = { @@ -299,7 +407,9 @@ private[serializer] object SerializationDebugger extends Logging { /** ObjectStreamClass$ClassDataSlot.desc field */ val DescField: Field = { + // scalastyle:off classforname val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + // scalastyle:on classforname f.setAccessible(true) f } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ca6e971d227f..bd2704dc8187 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -19,11 +19,12 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Private} import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} /** @@ -63,6 +64,39 @@ abstract class Serializer { /** Creates a new [[SerializerInstance]]. */ def newInstance(): SerializerInstance + + /** + * :: Private :: + * Returns true if this serializer supports relocation of its serialized objects and false + * otherwise. This should return true if and only if reordering the bytes of serialized objects + * in serialization stream output is equivalent to having re-ordered those elements prior to + * serializing them. More specifically, the following should hold if a serializer supports + * relocation: + * + * {{{ + * serOut.open() + * position = 0 + * serOut.write(obj1) + * serOut.flush() + * position = # of bytes writen to stream so far + * obj1Bytes = output[0:position-1] + * serOut.write(obj2) + * serOut.flush() + * position2 = # of bytes written to stream so far + * obj2Bytes = output[position:position2-1] + * serIn.open([obj2bytes] concatenate [obj1bytes]) should return (obj2, obj1) + * }}} + * + * In general, this property should hold for serializers that are stateless and that do not + * write special metadata at the beginning or end of the serialization stream. + * + * This API is private to Spark; this method should not be overridden in third-party subclasses + * or called in user code and is subject to removal in future Spark releases. + * + * See SPARK-7311 for more details. + */ + @Private + private[spark] def supportsRelocationOfSerializedObjects: Boolean = false } @@ -81,8 +115,12 @@ object Serializer { /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. + * + * It is legal to create multiple serialization / deserialization streams from the same + * SerializerInstance as long as those streams are all used within the same thread. */ @DeveloperApi +@NotThreadSafe abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer @@ -101,7 +139,12 @@ abstract class SerializerInstance { */ @DeveloperApi abstract class SerializationStream { + /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream + /** Writes the object representing the key of a key-value pair. */ + def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key) + /** Writes the object representing the value of a key-value pair. */ + def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit def close(): Unit @@ -120,7 +163,12 @@ abstract class SerializationStream { */ @DeveloperApi abstract class DeserializationStream { + /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T + /** Reads the object representing the key of a key-value pair. */ + def readKey[T: ClassTag](): T = readObject[T]() + /** Reads the object representing the value of a key-value pair. */ + def readValue[T: ClassTag](): T = readObject[T]() def close(): Unit /** @@ -134,6 +182,28 @@ abstract class DeserializationStream { } catch { case eof: EOFException => finished = true + null + } + } + + override protected def close() { + DeserializationStream.this.close() + } + } + + /** + * Read the elements of this stream through an iterator over key-value pairs. This can only be + * called once, as reading each element will consume data from the input source. + */ + def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] { + override protected def getNext() = { + try { + (readKey[Any](), readValue[Any]()) + } catch { + case eof: EOFException => { + finished = true + null + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala deleted file mode 100644 index 7de2f9cbb286..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ /dev/null @@ -1,298 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import java.io.File -import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.JavaConversions._ - -import org.apache.spark.{Logging, SparkConf, SparkEnv} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup -import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} - -/** A group of writers for a ShuffleMapTask, one writer per reducer. */ -private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] - - /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ - def releaseWriters(success: Boolean) -} - -/** - * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. - */ -// Note: Changes to the format in this file should be kept in sync with -// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). -private[spark] -class FileShuffleBlockManager(conf: SparkConf) - extends ShuffleBlockManager with Logging { - - private val transportConf = SparkTransportConf.fromSparkConf(conf) - - private lazy val blockManager = SparkEnv.get.blockManager - - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 - - /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. - */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - - /** - * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. - */ - val completedMapTasks = new ConcurrentLinkedQueue[Int]() - } - - private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] - - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) - - /** - * Get a ShuffleWriterGroup for the given map task, which will register it as complete - * when the writers are closed successfully - */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, - writeMetrics: ShuffleWriteMetrics) = { - new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) - private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null - - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - val blockFile = blockManager.diskBlockManager.getFile(blockId) - // Because of previous failures, the shuffle file may already exist on this machine. - // If so, remove it. - if (blockFile.exists) { - if (blockFile.delete()) { - logInfo(s"Removed existing shuffle file $blockFile") - } else { - logWarning(s"Failed to remove existing shuffle file $blockFile") - } - } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) - } - } - - override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) - } - } - } - - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockData(blockId) - Some(segment.nioByteBuffer()) - } - - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } - } - - /** Remove all the blocks / files and metadata related to a particular shuffle. */ - def removeShuffle(shuffleId: ShuffleId): Boolean = { - // Do not change the ordering of this, if shuffleStates should be removed only - // after the corresponding shuffle blocks have been removed - val cleaned = removeShuffleBlocks(shuffleId) - shuffleStates.remove(shuffleId) - cleaned - } - - /** Remove all the blocks / files related to a particular shuffle. */ - private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { - shuffleStates.get(shuffleId) match { - case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } - } - logInfo("Deleted all files for shuffle " + shuffleId) - true - case None => - logInfo("Could not find files for shuffle " + shuffleId + " for deleting") - false - } - } - - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - - private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) - } - - override def stop() { - metadataCleaner.cancel() - } -} - -private[spark] -object FileShuffleBlockManager { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int) = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. */ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala new file mode 100644 index 000000000000..c057de9b3f4d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ + +import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup +import org.apache.spark.storage._ +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} + +/** A group of writers for a ShuffleMapTask, one writer per reducer. */ +private[spark] trait ShuffleWriterGroup { + val writers: Array[DiskBlockObjectWriter] + + /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ + def releaseWriters(success: Boolean) +} + +/** + * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file + * per reducer (this set of files is called a ShuffleFileGroup). + * + * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle + * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer + * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle + * files, it releases them for another task. + * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: + * - shuffleId: The unique id given to the entire shuffle stage. + * - bucketId: The id of the output partition (i.e., reducer id) + * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a + * time owns a particular fileId, and this id is returned to a pool when the task finishes. + * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) + * that specifies where in a given file the actual block data is located. + * + * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping + * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for + * each block stored in each file. In order to find the location of a shuffle block, we search the + * files within a ShuffleFileGroups associated with the block's reducer. + */ +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). +private[spark] class FileShuffleBlockResolver(conf: SparkConf) + extends ShuffleBlockResolver with Logging { + + private val transportConf = SparkTransportConf.fromSparkConf(conf) + + private lazy val blockManager = SparkEnv.get.blockManager + + // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. + // TODO: Remove this once the shuffle file consolidation feature is stable. + private val consolidateShuffleFiles = + conf.getBoolean("spark.shuffle.consolidateFiles", false) + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + + /** + * Contains all the state related to a particular shuffle. This includes a pool of unused + * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + */ + private class ShuffleState(val numBuckets: Int) { + val nextFileId = new AtomicInteger(0) + val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + + /** + * The mapIds of all map tasks completed on this Executor for this shuffle. + * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. + */ + val completedMapTasks = new ConcurrentLinkedQueue[Int]() + } + + private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] + + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + + /** + * Get a ShuffleWriterGroup for the given map task, which will register it as complete + * when the writers are closed successfully + */ + def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { + new ShuffleWriterGroup { + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + private val shuffleState = shuffleStates(shuffleId) + private var fileGroup: ShuffleFileGroup = null + + val openStartTime = System.nanoTime + val serializerInstance = serializer.newInstance() + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { + fileGroup = getUnusedFileGroup() + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) + blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, + writeMetrics) + } + } else { + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) + val blockFile = blockManager.diskBlockManager.getFile(blockId) + // Because of previous failures, the shuffle file may already exist on this machine. + // If so, remove it. + if (blockFile.exists) { + if (blockFile.delete()) { + logInfo(s"Removed existing shuffle file $blockFile") + } else { + logWarning(s"Failed to remove existing shuffle file $blockFile") + } + } + blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, + writeMetrics) + } + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, so should be included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) + + override def releaseWriters(success: Boolean) { + if (consolidateShuffleFiles) { + if (success) { + val offsets = writers.map(_.fileSegment().offset) + val lengths = writers.map(_.fileSegment().length) + fileGroup.recordMapOutput(mapId, offsets, lengths) + } + recycleFileGroup(fileGroup) + } else { + shuffleState.completedMapTasks.add(mapId) + } + } + + private def getUnusedFileGroup(): ShuffleFileGroup = { + val fileGroup = shuffleState.unusedFileGroups.poll() + if (fileGroup != null) fileGroup else newFileGroup() + } + + private def newFileGroup(): ShuffleFileGroup = { + val fileId = shuffleState.nextFileId.getAndIncrement() + val files = Array.tabulate[File](numBuckets) { bucketId => + val filename = physicalFileName(shuffleId, bucketId, fileId) + blockManager.diskBlockManager.getFile(filename) + } + val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) + shuffleState.allFileGroups.add(fileGroup) + fileGroup + } + + private def recycleFileGroup(group: ShuffleFileGroup) { + shuffleState.unusedFileGroups.add(group) + } + } + } + + override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + if (consolidateShuffleFiles) { + // Search all file groups associated with this shuffle. + val shuffleState = shuffleStates(blockId.shuffleId) + val iter = shuffleState.allFileGroups.iterator + while (iter.hasNext) { + val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) + if (segmentOpt.isDefined) { + val segment = segmentOpt.get + return new FileSegmentManagedBuffer( + transportConf, segment.file, segment.offset, segment.length) + } + } + throw new IllegalStateException("Failed to find shuffle block: " + blockId) + } else { + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) + } + } + + /** Remove all the blocks / files and metadata related to a particular shuffle. */ + def removeShuffle(shuffleId: ShuffleId): Boolean = { + // Do not change the ordering of this, if shuffleStates should be removed only + // after the corresponding shuffle blocks have been removed + val cleaned = removeShuffleBlocks(shuffleId) + shuffleStates.remove(shuffleId) + cleaned + } + + /** Remove all the blocks / files related to a particular shuffle. */ + private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { + shuffleStates.get(shuffleId) match { + case Some(state) => + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups.asScala; + file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks.asScala; + reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + logInfo("Deleted all files for shuffle " + shuffleId) + true + case None => + logInfo("Could not find files for shuffle " + shuffleId + " for deleting") + false + } + } + + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { + "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) + } + + private def cleanup(cleanupTime: Long) { + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) + } + + override def stop() { + metadataCleaner.cancel() + } +} + +private[spark] object FileShuffleBlockResolver { + /** + * A group of shuffle files, one per reducer. + * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. + */ + private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { + private var numBlocks: Int = 0 + + /** + * Stores the absolute index of each mapId in the files of this group. For instance, + * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. + */ + private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() + + /** + * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by + * position in the file. + * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every + * reducer. + */ + private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } + private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } + + def apply(bucketId: Int): File = files(bucketId) + + def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { + assert(offsets.length == lengths.length) + mapIdToIndex(mapId) = numBlocks + numBlocks += 1 + for (i <- 0 until offsets.length) { + blockOffsetsByReducer(i) += offsets(i) + blockLengthsByReducer(i) += lengths(i) + } + } + + /** Returns the FileSegment associated with the given map task, or None if no entry exists. */ + def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { + val file = files(reducerId) + val blockOffsets = blockOffsetsByReducer(reducerId) + val blockLengths = blockLengthsByReducer(reducerId) + val index = mapIdToIndex.getOrElse(mapId, -1) + if (index >= 0) { + val offset = blockOffsets(index) + val length = blockLengths(index) + Some(new FileSegment(file, offset, length)) + } else { + None + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala deleted file mode 100644 index b292587d3702..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import java.io._ -import java.nio.ByteBuffer - -import com.google.common.io.ByteStreams - -import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.storage._ - -/** - * Create and maintain the shuffle blocks' mapping between logic block and physical file location. - * Data of shuffle blocks from the same map task are stored in a single consolidated data file. - * The offsets of the data blocks in the data file are stored in a separate index file. - * - * We use the name of the shuffle data's shuffleBlockId with reduce ID set to 0 and add ".data" - * as the filename postfix for data file, and ".index" as the filename postfix for index file. - * - */ -// Note: Changes to the format in this file should be kept in sync with -// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). -private[spark] -class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { - - private lazy val blockManager = SparkEnv.get.blockManager - - private val transportConf = SparkTransportConf.fromSparkConf(conf) - - /** - * Mapping to a single shuffleBlockId with reduce ID 0. - * */ - def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { - ShuffleBlockId(shuffleId, mapId, 0) - } - - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) - } - - private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) - } - - /** - * Remove data file and index file that contain the output data from one map. - * */ - def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { - var file = getDataFile(shuffleId, mapId) - if (file.exists()) { - file.delete() - } - - file = getIndexFile(shuffleId, mapId) - if (file.exists()) { - file.delete() - } - } - - /** - * Write an index file with the offsets of each block, plus a final offset at the end for the - * end of the output file. This will be used by getBlockLocation to figure out where each block - * begins and ends. - * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = { - val indexFile = getIndexFile(shuffleId, mapId) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - - for (length <- lengths) { - offset += length - out.writeLong(offset) - } - } finally { - out.close() - } - } - - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).nioByteBuffer()) - } - - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - // The block is actually going to be a range of a single map output file for this map, so - // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - - val in = new DataInputStream(new FileInputStream(indexFile)) - try { - ByteStreams.skipFully(in, blockId.reduceId * 8) - val offset = in.readLong() - val nextOffset = in.readLong() - new FileSegmentManagedBuffer( - transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), - offset, - nextOffset - offset) - } finally { - in.close() - } - } - - override def stop() = {} -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala new file mode 100644 index 000000000000..d0163d326dba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io._ + +import com.google.common.io.ByteStreams + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +import IndexShuffleBlockResolver.NOOP_REDUCE_ID + +/** + * Create and maintain the shuffle blocks' mapping between logic block and physical file location. + * Data of shuffle blocks from the same map task are stored in a single consolidated data file. + * The offsets of the data blocks in the data file are stored in a separate index file. + * + * We use the name of the shuffle data's shuffleBlockId with reduce ID set to 0 and add ".data" + * as the filename postfix for data file, and ".index" as the filename postfix for index file. + * + */ +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). +private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver { + + private lazy val blockManager = SparkEnv.get.blockManager + + private val transportConf = SparkTransportConf.fromSparkConf(conf) + + def getDataFile(shuffleId: Int, mapId: Int): File = { + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + } + + private def getIndexFile(shuffleId: Int, mapId: Int): File = { + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + } + + /** + * Remove data file and index file that contain the output data from one map. + * */ + def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { + var file = getDataFile(shuffleId, mapId) + if (file.exists()) { + file.delete() + } + + file = getIndexFile(shuffleId, mapId) + if (file.exists()) { + file.delete() + } + } + + /** + * Write an index file with the offsets of each block, plus a final offset at the end for the + * end of the output file. This will be used by getBlockData to figure out where each block + * begins and ends. + * */ + def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { + val indexFile = getIndexFile(shuffleId, mapId) + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L + out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() + } + } + + override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + // The block is actually going to be a range of a single map output file for this map, so + // find out the consolidated file, then the offset within that from our index + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + + val in = new DataInputStream(new FileInputStream(indexFile)) + try { + ByteStreams.skipFully(in, blockId.reduceId * 8) + val offset = in.readLong() + val nextOffset = in.readLong() + new FileSegmentManagedBuffer( + transportConf, + getDataFile(blockId.shuffleId, blockId.mapId), + offset, + nextOffset - offset) + } finally { + in.close() + } + } + + override def stop(): Unit = {} +} + +private[spark] object IndexShuffleBlockResolver { + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. + // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort + // shuffle outputs for several reduces are glommed into a single file. + // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. + val NOOP_REDUCE_ID = 0 +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala deleted file mode 100644 index b521f0c7fc77..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import java.nio.ByteBuffer -import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.ShuffleBlockId - -private[spark] -trait ShuffleBlockManager { - type ShuffleId = Int - - /** - * Get shuffle block data managed by the local ShuffleBlockManager. - * @return Some(ByteBuffer) if block found, otherwise None. - */ - def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer - - def stop(): Unit -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala new file mode 100644 index 000000000000..4342b0d598b1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.nio.ByteBuffer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.ShuffleBlockId + +private[spark] +/** + * Implementers of this trait understand how to retrieve block data for a logical shuffle block + * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to + * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle + * implementations when shuffle data is retrieved. + */ +trait ShuffleBlockResolver { + type ShuffleId = Int + + /** + * Retrieve the data for the specified block. If the data for that block is not available, + * throws an unspecified exception. + */ + def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala index 13c7115f88af..e04c97fe6189 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala @@ -17,9 +17,12 @@ package org.apache.spark.shuffle +import org.apache.spark.annotation.DeveloperApi + /** * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks. * * @param shuffleId ID of the shuffle */ -private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} +@DeveloperApi +abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a44a8e124925..978366d1a1d1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -55,7 +55,10 @@ private[spark] trait ShuffleManager { */ def unregisterShuffle(shuffleId: Int): Boolean - def shuffleBlockManager: ShuffleBlockManager + /** + * Return a resolver capable of retrieving shuffle block data based on block coordinates. + */ + def shuffleBlockResolver: ShuffleBlockResolver /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8..a0d8abc2eecb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,108 +19,167 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory * from this pool and release it as it spills data out. When a task ends, all its memory will be * released by the Executor. * - * This class tries to ensure that each thread gets a reasonable share of memory, instead of some - * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * This class tries to ensure that each task gets a reasonable share of memory, instead of some + * task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. + * + * Use `ShuffleMemoryManager.create()` factory method to create a new instance. + * + * @param maxMemory total amount of memory available for execution, in bytes. + * @param pageSizeBytes number of bytes for each page, by default. */ -private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes +private[spark] +class ShuffleMemoryManager protected ( + val maxMemory: Long, + val pageSizeBytes: Long) + extends Logging { - def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes + + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } /** - * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * Try to acquire up to numBytes memory for the current task, and return the number of bytes * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active threads) before it is forced to spill. This can - * happen if the number of threads increases but an older thread had a lot of memory already. + * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active tasks) before it is forced to spill. This can + * happen if the number of tasks increases but an older task had a lot of memory already. */ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numThreads again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. */ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currentTaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } + + /** Returns the memory consumption, in bytes, for the current task */ + def getMemoryConsumptionForThisTask(): Long = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.getOrElse(taskAttemptId, 0L) + } } -private object ShuffleMemoryManager { + +private[spark] object ShuffleMemoryManager { + + def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = { + val maxMemory = ShuffleMemoryManager.getMaxMemory(conf) + val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) + new ShuffleMemoryManager(maxMemory, pageSize) + } + + def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, pageSizeBytes) + } + + @VisibleForTesting + def createForTesting(maxMemory: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024) + } + /** * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction * of the memory pool and a safety factor since collections can sometimes grow bigger than * the size we target before we estimate their sizes again. */ - def getMaxMemory(conf: SparkConf): Long = { + private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } + + /** + * Sets the page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. + */ + private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + // TODO(davies): don't round to next power of 2 + val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index b934480cfb9b..4cc4ef5f1886 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,14 +17,17 @@ package org.apache.spark.shuffle +import java.io.IOException + import org.apache.spark.scheduler.MapStatus /** * Obtained inside a map task to write out records to the shuffle system. */ -private[spark] trait ShuffleWriter[K, V] { - /** Write a bunch of records to this task's output */ - def write(records: Iterator[_ <: Product2[K, V]]): Unit +private[spark] abstract class ShuffleWriter[K, V] { + /** Write a sequence of records to this task's output */ + @throws[IOException] + def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala deleted file mode 100644 index e3e7434df45b..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator - -private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - serializer: Serializer) - : Iterator[T] = - { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager - - val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) - } - - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] - } - case Failure(e) => { - blockId match { - case ShuffleBlockId(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) - } - } - } - } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 62e0629b3440..c089088f409d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.shuffle._ */ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { - private val fileShuffleBlockManager = new FileShuffleBlockManager(conf) + private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( @@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { new HashShuffleWriter( - shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockManager.removeShuffle(shuffleId) + shuffleBlockResolver.removeShuffle(shuffleId) } - override def shuffleBlockManager: FileShuffleBlockManager = { - fileShuffleBlockManager + override def shuffleBlockResolver: FileShuffleBlockResolver = { + fileShuffleBlockResolver } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b..0c8f08f0f3b1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,18 +17,22 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) - extends ShuffleReader[K, C] -{ + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + extends ShuffleReader[K, C] with Logging { + require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") @@ -36,20 +40,57 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. @@ -59,8 +100,10 @@ private[spark] class HashShuffleReader[K, C]( // the ExternalSorter won't spill to disk. val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) sorter.iterator case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 755f17d6aa15..41df70c602c3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,10 +22,10 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( - shuffleBlockManager: FileShuffleBlockManager, + shuffleBlockResolver: FileShuffleBlockResolver, handle: BaseShuffleHandle[K, V, _], mapId: Int, context: TaskContext) @@ -45,11 +45,11 @@ private[spark] class HashShuffleWriter[K, V]( private val blockManager = SparkEnv.get.blockManager private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) - private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, + private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { val iter = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { dep.aggregator.get.combineValuesByKey(records, context) @@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V]( for (elem <- iter) { val bucketId = dep.partitioner.getPartition(elem._1) - shuffle.writers(bucketId).write(elem) + shuffle.writers(bucketId).write(elem._1, elem._2) } } @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). - val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bda30a56d808..d7fab351ca3b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { - private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf) + private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() /** @@ -58,7 +58,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) new SortShuffleWriter( - shuffleBlockManager, baseShuffleHandle, mapId, context) + shuffleBlockResolver, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -66,18 +66,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager if (shuffleMapNumber.containsKey(shuffleId)) { val numMaps = shuffleMapNumber.remove(shuffleId) (0 until numMaps).map{ mapId => - shuffleBlockManager.removeDataByMap(shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override def shuffleBlockManager: IndexShuffleBlockManager = { - indexShuffleBlockManager + override val shuffleBlockResolver: IndexShuffleBlockResolver = { + indexShuffleBlockResolver } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 27496c5a289c..5865e7640c1c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,15 +17,16 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{IndexShuffleBlockManager, ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( - shuffleBlockManager: IndexShuffleBlockManager, + shuffleBlockResolver: IndexShuffleBlockResolver, handle: BaseShuffleHandle[K, V, C], mapId: Int, context: TaskContext) @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -48,25 +49,36 @@ private[spark] class SortShuffleWriter[K, V, C]( context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) - val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). + val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) - shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) + shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } @@ -82,15 +94,29 @@ private[spark] class SortShuffleWriter[K, V, C]( return Option(mapStatus) } else { // The map task failed, so delete our output data. - shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(dep.shuffleId, mapId) return None } } finally { // Clean up our sorter, which may have its own intermediate files if (sorter != null) { + val startTime = System.nanoTime() sorter.stop() + context.taskMetrics.shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - startTime)) sorter = null } } } } + +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala new file mode 100644 index 000000000000..df7bbd64247d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +private[spark] object UnsafeShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. + */ + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } + } +} + +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * - No individual record is larger than 128 MB when serialized. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. + * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + + "manager; its optimized shuffles will continue to spill to disk when necessary.") + } + + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + private[this] val shufflesThatFellBackToSortShuffle = + Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) + private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() + + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + sortShuffleManager.getReader(handle, startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) + val env = SparkEnv.get + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case other => + shufflesThatFellBackToSortShuffle.add(handle.shuffleId) + sortShuffleManager.getWriter(handle, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { + sortShuffleManager.unregisterShuffle(shuffleId) + } else { + Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + } + + override val shuffleBlockResolver: IndexShuffleBlockResolver = { + sortShuffleManager.shuffleBlockResolver + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + sortShuffleManager.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala new file mode 100644 index 000000000000..5783df5d8220 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.util.{Arrays, Date, List => JList} +import javax.ws.rs._ +import javax.ws.rs.core.MediaType + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.JobProgressListener +import org.apache.spark.ui.jobs.UIData.JobUIData + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class AllJobsResource(ui: SparkUI) { + + @GET + def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { + val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = + AllJobsResource.getStatusToJobs(ui) + val adjStatuses: JList[JobExecutionStatus] = { + if (statuses.isEmpty) { + Arrays.asList(JobExecutionStatus.values(): _*) + } else { + statuses + } + } + val jobInfos = for { + (status, jobs) <- statusToJobs + job <- jobs if adjStatuses.contains(status) + } yield { + AllJobsResource.convertJobData(job, ui.jobProgressListener, false) + } + jobInfos.sortBy{- _.jobId} + } + +} + +private[v1] object AllJobsResource { + + def getStatusToJobs(ui: SparkUI): Seq[(JobExecutionStatus, Seq[JobUIData])] = { + val statusToJobs = ui.jobProgressListener.synchronized { + Seq( + JobExecutionStatus.RUNNING -> ui.jobProgressListener.activeJobs.values.toSeq, + JobExecutionStatus.SUCCEEDED -> ui.jobProgressListener.completedJobs.toSeq, + JobExecutionStatus.FAILED -> ui.jobProgressListener.failedJobs.reverse.toSeq + ) + } + statusToJobs + } + + def convertJobData( + job: JobUIData, + listener: JobProgressListener, + includeStageDetails: Boolean): JobData = { + listener.synchronized { + val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageData = lastStageInfo.flatMap { s => + listener.stageIdToData.get((s.stageId, s.attemptId)) + } + val lastStageName = lastStageInfo.map { _.name }.getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap { _.description } + new JobData( + jobId = job.jobId, + name = lastStageName, + description = lastStageDescription, + submissionTime = job.submissionTime.map{new Date(_)}, + completionTime = job.completionTime.map{new Date(_)}, + stageIds = job.stageIds, + jobGroup = job.jobGroup, + status = job.status, + numTasks = job.numTasks, + numActiveTasks = job.numActiveTasks, + numCompletedTasks = job.numCompletedTasks, + numSkippedTasks = job.numCompletedTasks, + numFailedTasks = job.numFailedTasks, + numActiveStages = job.numActiveStages, + numCompletedStages = job.completedStageIndices.size, + numSkippedStages = job.numSkippedStages, + numFailedStages = job.numFailedStages + ) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala new file mode 100644 index 000000000000..645ede26a087 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.storage.{RDDInfo, StorageStatus, StorageUtils} +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.storage.StorageListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class AllRDDResource(ui: SparkUI) { + + @GET + def rddList(): Seq[RDDStorageInfo] = { + val storageStatusList = ui.storageListener.storageStatusList + val rddInfos = ui.storageListener.rddInfoList + rddInfos.map{rddInfo => + AllRDDResource.getRDDStorageInfo(rddInfo.id, rddInfo, storageStatusList, + includeDetails = false) + } + } + +} + +private[spark] object AllRDDResource { + + def getRDDStorageInfo( + rddId: Int, + listener: StorageListener, + includeDetails: Boolean): Option[RDDStorageInfo] = { + val storageStatusList = listener.storageStatusList + listener.rddInfoList.find { _.id == rddId }.map { rddInfo => + getRDDStorageInfo(rddId, rddInfo, storageStatusList, includeDetails) + } + } + + def getRDDStorageInfo( + rddId: Int, + rddInfo: RDDInfo, + storageStatusList: Seq[StorageStatus], + includeDetails: Boolean): RDDStorageInfo = { + val workers = storageStatusList.map { (rddId, _) } + val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList) + val blocks = storageStatusList + .flatMap { _.rddBlocksById(rddId) } + .sortWith { _._1.name < _._1.name } + .map { case (blockId, status) => + (blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown"))) + } + + val dataDistribution = if (includeDetails) { + Some(storageStatusList.map { status => + new RDDDataDistribution( + address = status.blockManagerId.hostPort, + memoryUsed = status.memUsedByRdd(rddId), + memoryRemaining = status.memRemaining, + diskUsed = status.diskUsedByRdd(rddId) + ) } ) + } else { + None + } + val partitions = if (includeDetails) { + Some(blocks.map { case (id, block, locations) => + new RDDPartitionInfo( + blockName = id.name, + storageLevel = block.storageLevel.description, + memoryUsed = block.memSize, + diskUsed = block.diskSize, + executors = locations + ) + } ) + } else { + None + } + + new RDDStorageInfo( + id = rddId, + name = rddInfo.name, + numPartitions = rddInfo.numPartitions, + numCachedPartitions = rddInfo.numCachedPartitions, + storageLevel = rddInfo.storageLevel.description, + memoryUsed = rddInfo.memSize, + diskUsed = rddInfo.diskSize, + dataDistribution = dataDistribution, + partitions = partitions + ) + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala new file mode 100644 index 000000000000..390c136df79b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.util.{Arrays, Date, List => JList} +import javax.ws.rs.{GET, PathParam, Produces, QueryParam} +import javax.ws.rs.core.MediaType + +import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} +import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} +import org.apache.spark.util.Distribution + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class AllStagesResource(ui: SparkUI) { + + @GET + def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { + val listener = ui.jobProgressListener + val stageAndStatus = AllStagesResource.stagesAndStatus(ui) + val adjStatuses = { + if (statuses.isEmpty()) { + Arrays.asList(StageStatus.values(): _*) + } else { + statuses + } + } + for { + (status, stageList) <- stageAndStatus + stageInfo: StageInfo <- stageList if adjStatuses.contains(status) + stageUiData: StageUIData <- listener.synchronized { + listener.stageIdToData.get((stageInfo.stageId, stageInfo.attemptId)) + } + } yield { + AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, includeDetails = false) + } + } +} + +private[v1] object AllStagesResource { + def stageUiToStageData( + status: StageStatus, + stageInfo: StageInfo, + stageUiData: StageUIData, + includeDetails: Boolean): StageData = { + + val taskData = if (includeDetails) { + Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) + } else { + None + } + val executorSummary = if (includeDetails) { + Some(stageUiData.executorSummary.map { case (k, summary) => + k -> new ExecutorStageSummary( + taskTime = summary.taskTime, + failedTasks = summary.failedTasks, + succeededTasks = summary.succeededTasks, + inputBytes = summary.inputBytes, + outputBytes = summary.outputBytes, + shuffleRead = summary.shuffleRead, + shuffleWrite = summary.shuffleWrite, + memoryBytesSpilled = summary.memoryBytesSpilled, + diskBytesSpilled = summary.diskBytesSpilled + ) + }) + } else { + None + } + + val accumulableInfo = stageUiData.accumulables.values.map { convertAccumulableInfo }.toSeq + + new StageData( + status = status, + stageId = stageInfo.stageId, + attemptId = stageInfo.attemptId, + numActiveTasks = stageUiData.numActiveTasks, + numCompleteTasks = stageUiData.numCompleteTasks, + numFailedTasks = stageUiData.numFailedTasks, + executorRunTime = stageUiData.executorRunTime, + inputBytes = stageUiData.inputBytes, + inputRecords = stageUiData.inputRecords, + outputBytes = stageUiData.outputBytes, + outputRecords = stageUiData.outputRecords, + shuffleReadBytes = stageUiData.shuffleReadTotalBytes, + shuffleReadRecords = stageUiData.shuffleReadRecords, + shuffleWriteBytes = stageUiData.shuffleWriteBytes, + shuffleWriteRecords = stageUiData.shuffleWriteRecords, + memoryBytesSpilled = stageUiData.memoryBytesSpilled, + diskBytesSpilled = stageUiData.diskBytesSpilled, + schedulingPool = stageUiData.schedulingPool, + name = stageInfo.name, + details = stageInfo.details, + accumulatorUpdates = accumulableInfo, + tasks = taskData, + executorSummary = executorSummary + ) + } + + def stagesAndStatus(ui: SparkUI): Seq[(StageStatus, Seq[StageInfo])] = { + val listener = ui.jobProgressListener + listener.synchronized { + Seq( + StageStatus.ACTIVE -> listener.activeStages.values.toSeq, + StageStatus.COMPLETE -> listener.completedStages.reverse.toSeq, + StageStatus.FAILED -> listener.failedStages.reverse.toSeq, + StageStatus.PENDING -> listener.pendingStages.values.toSeq + ) + } + } + + def convertTaskData(uiData: TaskUIData): TaskData = { + new TaskData( + taskId = uiData.taskInfo.taskId, + index = uiData.taskInfo.index, + attempt = uiData.taskInfo.attempt, + launchTime = new Date(uiData.taskInfo.launchTime), + executorId = uiData.taskInfo.executorId, + host = uiData.taskInfo.host, + taskLocality = uiData.taskInfo.taskLocality.toString(), + speculative = uiData.taskInfo.speculative, + accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, + errorMessage = uiData.errorMessage, + taskMetrics = uiData.taskMetrics.map { convertUiTaskMetrics } + ) + } + + def taskMetricDistributions( + allTaskData: Iterable[TaskUIData], + quantiles: Array[Double]): TaskMetricDistributions = { + + val rawMetrics = allTaskData.flatMap{_.taskMetrics}.toSeq + + def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] = + Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) + + // We need to do a lot of similar munging to nested metrics here. For each one, + // we want (a) extract the values for nested metrics (b) make a distribution for each metric + // (c) shove the distribution into the right field in our return type and (d) only return + // a result if the option is defined for any of the tasks. MetricHelper is a little util + // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just + // implement one "build" method, which just builds the quantiles for each field. + + val inputMetrics: Option[InputMetricDistributions] = + new MetricHelper[InternalInputMetrics, InputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalInputMetrics] = { + raw.inputMetrics + } + + def build: InputMetricDistributions = new InputMetricDistributions( + bytesRead = submetricQuantiles(_.bytesRead), + recordsRead = submetricQuantiles(_.recordsRead) + ) + }.metricOption + + val outputMetrics: Option[OutputMetricDistributions] = + new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { + raw.outputMetrics + } + def build: OutputMetricDistributions = new OutputMetricDistributions( + bytesWritten = submetricQuantiles(_.bytesWritten), + recordsWritten = submetricQuantiles(_.recordsWritten) + ) + }.metricOption + + val shuffleReadMetrics: Option[ShuffleReadMetricDistributions] = + new MetricHelper[InternalShuffleReadMetrics, ShuffleReadMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalShuffleReadMetrics] = { + raw.shuffleReadMetrics + } + def build: ShuffleReadMetricDistributions = new ShuffleReadMetricDistributions( + readBytes = submetricQuantiles(_.totalBytesRead), + readRecords = submetricQuantiles(_.recordsRead), + remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), + localBlocksFetched = submetricQuantiles(_.localBlocksFetched), + totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), + fetchWaitTime = submetricQuantiles(_.fetchWaitTime) + ) + }.metricOption + + val shuffleWriteMetrics: Option[ShuffleWriteMetricDistributions] = + new MetricHelper[InternalShuffleWriteMetrics, ShuffleWriteMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalShuffleWriteMetrics] = { + raw.shuffleWriteMetrics + } + def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions( + writeBytes = submetricQuantiles(_.shuffleBytesWritten), + writeRecords = submetricQuantiles(_.shuffleRecordsWritten), + writeTime = submetricQuantiles(_.shuffleWriteTime) + ) + }.metricOption + + new TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), + executorRunTime = metricQuantiles(_.executorRunTime), + resultSize = metricQuantiles(_.resultSize), + jvmGcTime = metricQuantiles(_.jvmGCTime), + resultSerializationTime = metricQuantiles(_.resultSerializationTime), + memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), + diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), + inputMetrics = inputMetrics, + outputMetrics = outputMetrics, + shuffleReadMetrics = shuffleReadMetrics, + shuffleWriteMetrics = shuffleWriteMetrics + ) + } + + def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = { + new AccumulableInfo(acc.id, acc.name, acc.update, acc.value) + } + + def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { + new TaskMetrics( + executorDeserializeTime = internal.executorDeserializeTime, + executorRunTime = internal.executorRunTime, + resultSize = internal.resultSize, + jvmGcTime = internal.jvmGCTime, + resultSerializationTime = internal.resultSerializationTime, + memoryBytesSpilled = internal.memoryBytesSpilled, + diskBytesSpilled = internal.diskBytesSpilled, + inputMetrics = internal.inputMetrics.map { convertInputMetrics }, + outputMetrics = Option(internal.outputMetrics).flatten.map { convertOutputMetrics }, + shuffleReadMetrics = internal.shuffleReadMetrics.map { convertShuffleReadMetrics }, + shuffleWriteMetrics = internal.shuffleWriteMetrics.map { convertShuffleWriteMetrics } + ) + } + + def convertInputMetrics(internal: InternalInputMetrics): InputMetrics = { + new InputMetrics( + bytesRead = internal.bytesRead, + recordsRead = internal.recordsRead + ) + } + + def convertOutputMetrics(internal: InternalOutputMetrics): OutputMetrics = { + new OutputMetrics( + bytesWritten = internal.bytesWritten, + recordsWritten = internal.recordsWritten + ) + } + + def convertShuffleReadMetrics(internal: InternalShuffleReadMetrics): ShuffleReadMetrics = { + new ShuffleReadMetrics( + remoteBlocksFetched = internal.remoteBlocksFetched, + localBlocksFetched = internal.localBlocksFetched, + fetchWaitTime = internal.fetchWaitTime, + remoteBytesRead = internal.remoteBytesRead, + totalBlocksFetched = internal.totalBlocksFetched, + recordsRead = internal.recordsRead + ) + } + + def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): ShuffleWriteMetrics = { + new ShuffleWriteMetrics( + bytesWritten = internal.shuffleBytesWritten, + writeTime = internal.shuffleWriteTime, + recordsWritten = internal.shuffleRecordsWritten + ) + } +} + +/** + * Helper for getting distributions from nested metric types. Many of the metrics we want are + * contained in options inside TaskMetrics (eg., ShuffleWriteMetrics). This makes it easy to handle + * the options (returning None if the metrics are all empty), and extract the quantiles for each + * metric. After creating an instance, call metricOption to get the result type. + */ +private[v1] abstract class MetricHelper[I, O]( + rawMetrics: Seq[InternalTaskMetrics], + quantiles: Array[Double]) { + + def getSubmetrics(raw: InternalTaskMetrics): Option[I] + + def build: O + + val data: Seq[I] = rawMetrics.flatMap(getSubmetrics) + + /** applies the given function to all input metrics, and returns the quantiles */ + def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { + Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) + } + + def metricOption: Option[O] = { + if (data.isEmpty) { + None + } else { + Some(build) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala new file mode 100644 index 000000000000..50b6ba67e993 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.util.zip.ZipOutputStream +import javax.servlet.ServletContext +import javax.ws.rs._ +import javax.ws.rs.core.{Context, Response} + +import com.sun.jersey.api.core.ResourceConfig +import com.sun.jersey.spi.container.servlet.ServletContainer +import org.eclipse.jetty.server.handler.ContextHandler +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} + +import org.apache.spark.SecurityManager +import org.apache.spark.ui.SparkUI + +/** + * Main entry point for serving spark application metrics as json, using JAX-RS. + * + * Each resource should have endpoints that return **public** classes defined in api.scala. Mima + * binary compatibility checks ensure that we don't inadvertently make changes that break the api. + * The returned objects are automatically converted to json by jackson with JacksonMessageWriter. + * In addition, there are a number of tests in HistoryServerSuite that compare the json to "golden + * files". Any changes and additions should be reflected there as well -- see the notes in + * HistoryServerSuite. + */ +@Path("/v1") +private[v1] class ApiRootResource extends UIRootFromServletContext { + + @Path("applications") + def getApplicationList(): ApplicationListResource = { + new ApplicationListResource(uiRoot) + } + + @Path("applications/{appId}") + def getApplication(): OneApplicationResource = { + new OneApplicationResource(uiRoot) + } + + @Path("applications/{appId}/{attemptId}/jobs") + def getJobs( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): AllJobsResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new AllJobsResource(ui) + } + } + + @Path("applications/{appId}/jobs") + def getJobs(@PathParam("appId") appId: String): AllJobsResource = { + uiRoot.withSparkUI(appId, None) { ui => + new AllJobsResource(ui) + } + } + + @Path("applications/{appId}/jobs/{jobId: \\d+}") + def getJob(@PathParam("appId") appId: String): OneJobResource = { + uiRoot.withSparkUI(appId, None) { ui => + new OneJobResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/jobs/{jobId: \\d+}") + def getJob( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): OneJobResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new OneJobResource(ui) + } + } + + @Path("applications/{appId}/executors") + def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { + uiRoot.withSparkUI(appId, None) { ui => + new ExecutorListResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/executors") + def getExecutors( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): ExecutorListResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new ExecutorListResource(ui) + } + } + + + @Path("applications/{appId}/stages") + def getStages(@PathParam("appId") appId: String): AllStagesResource = { + uiRoot.withSparkUI(appId, None) { ui => + new AllStagesResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/stages") + def getStages( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): AllStagesResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new AllStagesResource(ui) + } + } + + @Path("applications/{appId}/stages/{stageId: \\d+}") + def getStage(@PathParam("appId") appId: String): OneStageResource = { + uiRoot.withSparkUI(appId, None) { ui => + new OneStageResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/stages/{stageId: \\d+}") + def getStage( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): OneStageResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new OneStageResource(ui) + } + } + + @Path("applications/{appId}/storage/rdd") + def getRdds(@PathParam("appId") appId: String): AllRDDResource = { + uiRoot.withSparkUI(appId, None) { ui => + new AllRDDResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/storage/rdd") + def getRdds( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): AllRDDResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new AllRDDResource(ui) + } + } + + @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") + def getRdd(@PathParam("appId") appId: String): OneRDDResource = { + uiRoot.withSparkUI(appId, None) { ui => + new OneRDDResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/storage/rdd/{rddId: \\d+}") + def getRdd( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): OneRDDResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new OneRDDResource(ui) + } + } + + @Path("applications/{appId}/logs") + def getEventLogs( + @PathParam("appId") appId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, None) + } + + @Path("applications/{appId}/{attemptId}/logs") + def getEventLogs( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } +} + +private[spark] object ApiRootResource { + + def getServletHandler(uiRoot: UIRoot): ServletContextHandler = { + val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) + jerseyContext.setContextPath("/api") + val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) + holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", + "com.sun.jersey.api.core.PackagesResourceConfig") + holder.setInitParameter("com.sun.jersey.config.property.packages", + "org.apache.spark.status.api.v1") + holder.setInitParameter(ResourceConfig.PROPERTY_CONTAINER_REQUEST_FILTERS, + classOf[SecurityFilter].getCanonicalName) + UIRootFromServletContext.setUiRoot(jerseyContext, uiRoot) + jerseyContext.addServlet(holder, "/*") + jerseyContext + } +} + +/** + * This trait is shared by the all the root containers for application UI information -- + * the HistoryServer, the Master UI, and the application UI. This provides the common + * interface needed for them all to expose application info as json. + */ +private[spark] trait UIRoot { + def getSparkUI(appKey: String): Option[SparkUI] + def getApplicationInfoList: Iterator[ApplicationInfo] + + /** + * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is + * [[None]], event logs for all attempts of this application will be written out. + */ + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { + Response.serverError() + .entity("Event logs are only available through the history server.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + + /** + * Get the spark UI with the given appID, and apply a function + * to it. If there is no such app, throw an appropriate exception + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + getSparkUI(appKey) match { + case Some(ui) => + f(ui) + case None => throw new NotFoundException("no such app: " + appId) + } + } + def securityManager: SecurityManager +} + +private[v1] object UIRootFromServletContext { + + private val attribute = getClass.getCanonicalName + + def setUiRoot(contextHandler: ContextHandler, uiRoot: UIRoot): Unit = { + contextHandler.setAttribute(attribute, uiRoot) + } + + def getUiRoot(context: ServletContext): UIRoot = { + context.getAttribute(attribute).asInstanceOf[UIRoot] + } +} + +private[v1] trait UIRootFromServletContext { + @Context + var servletContext: ServletContext = _ + + def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) +} + +private[v1] class NotFoundException(msg: String) extends WebApplicationException( + new NoSuchElementException(msg), + Response + .status(Response.Status.NOT_FOUND) + .entity(ErrorWrapper(msg)) + .build() +) + +private[v1] class BadParameterException(msg: String) extends WebApplicationException( + new IllegalArgumentException(msg), + Response + .status(Response.Status.BAD_REQUEST) + .entity(ErrorWrapper(msg)) + .build() +) { + def this(param: String, exp: String, actual: String) = { + this(raw"""Bad value for parameter "$param". Expected a $exp, got "$actual"""") + } +} + +/** + * Signal to JacksonMessageWriter to not convert the message into json (which would result in an + * extra set of quotes). + */ +private[v1] case class ErrorWrapper(s: String) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala new file mode 100644 index 000000000000..17b521f3e1d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.util.{Arrays, Date, List => JList} +import javax.ws.rs.{DefaultValue, GET, Produces, QueryParam} +import javax.ws.rs.core.MediaType + +import org.apache.spark.deploy.history.ApplicationHistoryInfo +import org.apache.spark.deploy.master.{ApplicationInfo => InternalApplicationInfo} + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ApplicationListResource(uiRoot: UIRoot) { + + @GET + def appList( + @QueryParam("status") status: JList[ApplicationStatus], + @DefaultValue("2010-01-01") @QueryParam("minDate") minDate: SimpleDateParam, + @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam) + : Iterator[ApplicationInfo] = { + val allApps = uiRoot.getApplicationInfoList + val adjStatus = { + if (status.isEmpty) { + Arrays.asList(ApplicationStatus.values(): _*) + } else { + status + } + } + val includeCompleted = adjStatus.contains(ApplicationStatus.COMPLETED) + val includeRunning = adjStatus.contains(ApplicationStatus.RUNNING) + allApps.filter { app => + val anyRunning = app.attempts.exists(!_.completed) + // if any attempt is still running, we consider the app to also still be running + val statusOk = (!anyRunning && includeCompleted) || + (anyRunning && includeRunning) + // keep the app if *any* attempts fall in the right time window + val dateOk = app.attempts.exists { attempt => + attempt.startTime.getTime >= minDate.timestamp && + attempt.startTime.getTime <= maxDate.timestamp + } + statusOk && dateOk + } + } +} + +private[spark] object ApplicationsListResource { + def appHistoryInfoToPublicAppInfo(app: ApplicationHistoryInfo): ApplicationInfo = { + new ApplicationInfo( + id = app.id, + name = app.name, + attempts = app.attempts.map { internalAttemptInfo => + new ApplicationAttemptInfo( + attemptId = internalAttemptInfo.attemptId, + startTime = new Date(internalAttemptInfo.startTime), + endTime = new Date(internalAttemptInfo.endTime), + sparkUser = internalAttemptInfo.sparkUser, + completed = internalAttemptInfo.completed + ) + } + ) + } + + def convertApplicationInfo( + internal: InternalApplicationInfo, + completed: Boolean): ApplicationInfo = { + // standalone application info always has just one attempt + new ApplicationInfo( + id = internal.id, + name = internal.desc.name, + attempts = Seq(new ApplicationAttemptInfo( + attemptId = None, + startTime = new Date(internal.startTime), + endTime = new Date(internal.endTime), + sparkUser = internal.desc.user, + completed = completed + )) + ) + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala new file mode 100644 index 000000000000..22e21f0c62a2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.io.OutputStream +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil + +@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) +private[v1] class EventLogDownloadResource( + val uIRoot: UIRoot, + val appId: String, + val attemptId: Option[String]) extends Logging { + val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) + + @GET + def getEventLogs(): Response = { + try { + val fileName = { + attemptId match { + case Some(id) => s"eventLogs-$appId-$id.zip" + case None => s"eventLogs-$appId.zip" + } + } + + val stream = new StreamingOutput { + override def write(output: OutputStream): Unit = { + val zipStream = new ZipOutputStream(output) + try { + uIRoot.writeEventLogs(appId, attemptId, zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala new file mode 100644 index 000000000000..8ad4656b4dad --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -0,0 +1,36 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.status.api.v1 + +import javax.ws.rs.{GET, PathParam, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.exec.ExecutorsPage + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ExecutorListResource(ui: SparkUI) { + + @GET + def executorList(): Seq[ExecutorSummary] = { + val listener = ui.executorsListener + val storageStatusList = listener.storageStatusList + (0 until storageStatusList.size).map { statusId => + ExecutorsPage.getExecInfo(listener, statusId) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala new file mode 100644 index 000000000000..202a5191ad57 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.io.OutputStream +import java.lang.annotation.Annotation +import java.lang.reflect.Type +import java.text.SimpleDateFormat +import java.util.{Calendar, SimpleTimeZone} +import javax.ws.rs.Produces +import javax.ws.rs.core.{MediaType, MultivaluedMap} +import javax.ws.rs.ext.{MessageBodyWriter, Provider} + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature} + +/** + * This class converts the POJO metric responses into json, using jackson. + * + * This doesn't follow the standard jersey-jackson plugin options, because we want to stick + * with an old version of jersey (since we have it from yarn anyway) and don't want to pull in lots + * of dependencies from a new plugin. + * + * Note that jersey automatically discovers this class based on its package and its annotations. + */ +@Provider +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ + + val mapper = new ObjectMapper() { + override def writeValueAsString(t: Any): String = { + super.writeValueAsString(t) + } + } + mapper.registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) + mapper.enable(SerializationFeature.INDENT_OUTPUT) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + mapper.setDateFormat(JacksonMessageWriter.makeISODateFormat) + + override def isWriteable( + aClass: Class[_], + `type`: Type, + annotations: Array[Annotation], + mediaType: MediaType): Boolean = { + true + } + + override def writeTo( + t: Object, + aClass: Class[_], + `type`: Type, + annotations: Array[Annotation], + mediaType: MediaType, + multivaluedMap: MultivaluedMap[String, AnyRef], + outputStream: OutputStream): Unit = { + t match { + case ErrorWrapper(err) => outputStream.write(err.getBytes("utf-8")) + case _ => mapper.writeValue(outputStream, t) + } + } + + override def getSize( + t: Object, + aClass: Class[_], + `type`: Type, + annotations: Array[Annotation], + mediaType: MediaType): Long = { + -1L + } +} + +private[spark] object JacksonMessageWriter { + def makeISODateFormat: SimpleDateFormat = { + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) + iso8601.setCalendar(cal) + iso8601 + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala new file mode 100644 index 000000000000..b5ef72649e29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs.core.MediaType +import javax.ws.rs.{Produces, PathParam, GET} + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneApplicationResource(uiRoot: UIRoot) { + + @GET + def getApp(@PathParam("appId") appId: String): ApplicationInfo = { + val apps = uiRoot.getApplicationInfoList.find { _.id == appId } + apps.getOrElse(throw new NotFoundException("unknown app: " + appId)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala new file mode 100644 index 000000000000..6d8a60d480ae --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.UIData.JobUIData + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneJobResource(ui: SparkUI) { + + @GET + def oneJob(@PathParam("jobId") jobId: Int): JobData = { + val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = + AllJobsResource.getStatusToJobs(ui) + val jobOpt = statusToJobs.map {_._2} .flatten.find { jobInfo => jobInfo.jobId == jobId} + jobOpt.map { job => + AllJobsResource.convertJobData(job, ui.jobProgressListener, false) + }.getOrElse { + throw new NotFoundException("unknown job: " + jobId) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala new file mode 100644 index 000000000000..dfdc09c6caf3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.ui.SparkUI + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneRDDResource(ui: SparkUI) { + + @GET + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { + AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( + throw new NotFoundException(s"no rdd found w/ id $rddId") + ) + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala new file mode 100644 index 000000000000..f9812f06cf52 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs._ +import javax.ws.rs.core.MediaType + +import org.apache.spark.SparkException +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.status.api.v1.StageStatus._ +import org.apache.spark.status.api.v1.TaskSorting._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.JobProgressListener +import org.apache.spark.ui.jobs.UIData.StageUIData + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneStageResource(ui: SparkUI) { + + @GET + @Path("") + def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = { + withStage(stageId){ stageAttempts => + stageAttempts.map { stage => + AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, + includeDetails = true) + } + } + } + + @GET + @Path("/{stageAttemptId: \\d+}") + def oneAttemptData( + @PathParam("stageId") stageId: Int, + @PathParam("stageAttemptId") stageAttemptId: Int): StageData = { + withStageAttempt(stageId, stageAttemptId) { stage => + AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, + includeDetails = true) + } + } + + @GET + @Path("/{stageAttemptId: \\d+}/taskSummary") + def taskSummary( + @PathParam("stageId") stageId: Int, + @PathParam("stageAttemptId") stageAttemptId: Int, + @DefaultValue("0.05,0.25,0.5,0.75,0.95") @QueryParam("quantiles") quantileString: String) + : TaskMetricDistributions = { + withStageAttempt(stageId, stageAttemptId) { stage => + val quantiles = quantileString.split(",").map { s => + try { + s.toDouble + } catch { + case nfe: NumberFormatException => + throw new BadParameterException("quantiles", "double", s) + } + } + AllStagesResource.taskMetricDistributions(stage.ui.taskData.values, quantiles) + } + } + + @GET + @Path("/{stageAttemptId: \\d+}/taskList") + def taskList( + @PathParam("stageId") stageId: Int, + @PathParam("stageAttemptId") stageAttemptId: Int, + @DefaultValue("0") @QueryParam("offset") offset: Int, + @DefaultValue("20") @QueryParam("length") length: Int, + @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { + withStageAttempt(stageId, stageAttemptId) { stage => + val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq + .sorted(OneStageResource.ordering(sortBy)) + tasks.slice(offset, offset + length) + } + } + + private case class StageStatusInfoUi(status: StageStatus, info: StageInfo, ui: StageUIData) + + private def withStage[T](stageId: Int)(f: Seq[StageStatusInfoUi] => T): T = { + val stageAttempts = findStageStatusUIData(ui.jobProgressListener, stageId) + if (stageAttempts.isEmpty) { + throw new NotFoundException("unknown stage: " + stageId) + } else { + f(stageAttempts) + } + } + + private def findStageStatusUIData( + listener: JobProgressListener, + stageId: Int): Seq[StageStatusInfoUi] = { + listener.synchronized { + def getStatusInfoUi(status: StageStatus, infos: Seq[StageInfo]): Seq[StageStatusInfoUi] = { + infos.filter { _.stageId == stageId }.map { info => + val ui = listener.stageIdToData.getOrElse((info.stageId, info.attemptId), + // this is an internal error -- we should always have uiData + throw new SparkException( + s"no stage ui data found for stage: ${info.stageId}:${info.attemptId}") + ) + StageStatusInfoUi(status, info, ui) + } + } + getStatusInfoUi(ACTIVE, listener.activeStages.values.toSeq) ++ + getStatusInfoUi(COMPLETE, listener.completedStages) ++ + getStatusInfoUi(FAILED, listener.failedStages) ++ + getStatusInfoUi(PENDING, listener.pendingStages.values.toSeq) + } + } + + private def withStageAttempt[T]( + stageId: Int, + stageAttemptId: Int) + (f: StageStatusInfoUi => T): T = { + withStage(stageId) { attempts => + val oneAttempt = attempts.find { stage => stage.info.attemptId == stageAttemptId } + oneAttempt match { + case Some(stage) => + f(stage) + case None => + val stageAttempts = attempts.map { _.info.attemptId } + throw new NotFoundException(s"unknown attempt for stage $stageId. " + + s"Found attempts: ${stageAttempts.mkString("[", ",", "]")}") + } + } + } +} + +object OneStageResource { + def ordering(taskSorting: TaskSorting): Ordering[TaskData] = { + val extractor: (TaskData => Long) = td => + taskSorting match { + case ID => td.taskId + case INCREASING_RUNTIME => td.taskMetrics.map{_.executorRunTime}.getOrElse(-1L) + case DECREASING_RUNTIME => -td.taskMetrics.map{_.executorRunTime}.getOrElse(-1L) + } + Ordering.by(extractor) + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala new file mode 100644 index 000000000000..95fbd96ade5a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs.WebApplicationException +import javax.ws.rs.core.Response + +import com.sun.jersey.spi.container.{ContainerRequest, ContainerRequestFilter} + +private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { + def filter(req: ContainerRequest): ContainerRequest = { + val user = Option(req.getUserPrincipal).map { _.getName }.orNull + if (uiRoot.securityManager.checkUIViewPermissions(user)) { + req + } else { + throw new WebApplicationException( + Response + .status(Response.Status.FORBIDDEN) + .entity(raw"""user "$user"is not authorized""") + .build() + ) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala new file mode 100644 index 000000000000..0c71cd238222 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.text.{ParseException, SimpleDateFormat} +import java.util.TimeZone +import javax.ws.rs.WebApplicationException +import javax.ws.rs.core.Response +import javax.ws.rs.core.Response.Status + +private[v1] class SimpleDateParam(val originalValue: String) { + + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala new file mode 100644 index 000000000000..2bec64f2ef02 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.util.Date + +import scala.collection.Map + +import org.apache.spark.JobExecutionStatus + +class ApplicationInfo private[spark]( + val id: String, + val name: String, + val attempts: Seq[ApplicationAttemptInfo]) + +class ApplicationAttemptInfo private[spark]( + val attemptId: Option[String], + val startTime: Date, + val endTime: Date, + val sparkUser: String, + val completed: Boolean = false) + +class ExecutorStageSummary private[spark]( + val taskTime : Long, + val failedTasks : Int, + val succeededTasks : Int, + val inputBytes : Long, + val outputBytes : Long, + val shuffleRead : Long, + val shuffleWrite : Long, + val memoryBytesSpilled : Long, + val diskBytesSpilled : Long) + +class ExecutorSummary private[spark]( + val id: String, + val hostPort: String, + val rddBlocks: Int, + val memoryUsed: Long, + val diskUsed: Long, + val activeTasks: Int, + val failedTasks: Int, + val completedTasks: Int, + val totalTasks: Int, + val totalDuration: Long, + val totalInputBytes: Long, + val totalShuffleRead: Long, + val totalShuffleWrite: Long, + val maxMemory: Long, + val executorLogs: Map[String, String]) + +class JobData private[spark]( + val jobId: Int, + val name: String, + val description: Option[String], + val submissionTime: Option[Date], + val completionTime: Option[Date], + val stageIds: Seq[Int], + val jobGroup: Option[String], + val status: JobExecutionStatus, + val numTasks: Int, + val numActiveTasks: Int, + val numCompletedTasks: Int, + val numSkippedTasks: Int, + val numFailedTasks: Int, + val numActiveStages: Int, + val numCompletedStages: Int, + val numSkippedStages: Int, + val numFailedStages: Int) + +// Q: should Tachyon size go in here as well? currently the UI only shows it on the overall storage +// page ... does anybody pay attention to it? +class RDDStorageInfo private[spark]( + val id: Int, + val name: String, + val numPartitions: Int, + val numCachedPartitions: Int, + val storageLevel: String, + val memoryUsed: Long, + val diskUsed: Long, + val dataDistribution: Option[Seq[RDDDataDistribution]], + val partitions: Option[Seq[RDDPartitionInfo]]) + +class RDDDataDistribution private[spark]( + val address: String, + val memoryUsed: Long, + val memoryRemaining: Long, + val diskUsed: Long) + +class RDDPartitionInfo private[spark]( + val blockName: String, + val storageLevel: String, + val memoryUsed: Long, + val diskUsed: Long, + val executors: Seq[String]) + +class StageData private[spark]( + val status: StageStatus, + val stageId: Int, + val attemptId: Int, + val numActiveTasks: Int , + val numCompleteTasks: Int, + val numFailedTasks: Int, + + val executorRunTime: Long, + + val inputBytes: Long, + val inputRecords: Long, + val outputBytes: Long, + val outputRecords: Long, + val shuffleReadBytes: Long, + val shuffleReadRecords: Long, + val shuffleWriteBytes: Long, + val shuffleWriteRecords: Long, + val memoryBytesSpilled: Long, + val diskBytesSpilled: Long, + + val name: String, + val details: String, + val schedulingPool: String, + + val accumulatorUpdates: Seq[AccumulableInfo], + val tasks: Option[Map[Long, TaskData]], + val executorSummary: Option[Map[String, ExecutorStageSummary]]) + +class TaskData private[spark]( + val taskId: Long, + val index: Int, + val attempt: Int, + val launchTime: Date, + val executorId: String, + val host: String, + val taskLocality: String, + val speculative: Boolean, + val accumulatorUpdates: Seq[AccumulableInfo], + val errorMessage: Option[String] = None, + val taskMetrics: Option[TaskMetrics] = None) + +class TaskMetrics private[spark]( + val executorDeserializeTime: Long, + val executorRunTime: Long, + val resultSize: Long, + val jvmGcTime: Long, + val resultSerializationTime: Long, + val memoryBytesSpilled: Long, + val diskBytesSpilled: Long, + val inputMetrics: Option[InputMetrics], + val outputMetrics: Option[OutputMetrics], + val shuffleReadMetrics: Option[ShuffleReadMetrics], + val shuffleWriteMetrics: Option[ShuffleWriteMetrics]) + +class InputMetrics private[spark]( + val bytesRead: Long, + val recordsRead: Long) + +class OutputMetrics private[spark]( + val bytesWritten: Long, + val recordsWritten: Long) + +class ShuffleReadMetrics private[spark]( + val remoteBlocksFetched: Int, + val localBlocksFetched: Int, + val fetchWaitTime: Long, + val remoteBytesRead: Long, + val totalBlocksFetched: Int, + val recordsRead: Long) + +class ShuffleWriteMetrics private[spark]( + val bytesWritten: Long, + val writeTime: Long, + val recordsWritten: Long) + +class TaskMetricDistributions private[spark]( + val quantiles: IndexedSeq[Double], + + val executorDeserializeTime: IndexedSeq[Double], + val executorRunTime: IndexedSeq[Double], + val resultSize: IndexedSeq[Double], + val jvmGcTime: IndexedSeq[Double], + val resultSerializationTime: IndexedSeq[Double], + val memoryBytesSpilled: IndexedSeq[Double], + val diskBytesSpilled: IndexedSeq[Double], + + val inputMetrics: Option[InputMetricDistributions], + val outputMetrics: Option[OutputMetricDistributions], + val shuffleReadMetrics: Option[ShuffleReadMetricDistributions], + val shuffleWriteMetrics: Option[ShuffleWriteMetricDistributions]) + +class InputMetricDistributions private[spark]( + val bytesRead: IndexedSeq[Double], + val recordsRead: IndexedSeq[Double]) + +class OutputMetricDistributions private[spark]( + val bytesWritten: IndexedSeq[Double], + val recordsWritten: IndexedSeq[Double]) + +class ShuffleReadMetricDistributions private[spark]( + val readBytes: IndexedSeq[Double], + val readRecords: IndexedSeq[Double], + val remoteBlocksFetched: IndexedSeq[Double], + val localBlocksFetched: IndexedSeq[Double], + val fetchWaitTime: IndexedSeq[Double], + val remoteBytesRead: IndexedSeq[Double], + val totalBlocksFetched: IndexedSeq[Double]) + +class ShuffleWriteMetricDistributions private[spark]( + val writeBytes: IndexedSeq[Double], + val writeRecords: IndexedSeq[Double], + val writeTime: IndexedSeq[Double]) + +class AccumulableInfo private[spark]( + val id: Long, + val name: String, + val update: Option[String], + val value: String) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 1f012941c85a..524f6970992a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -35,13 +35,13 @@ sealed abstract class BlockId { def name: String // convenience methods - def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None - def isRDD = isInstanceOf[RDDBlockId] - def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD: Boolean = isInstanceOf[RDDBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] - override def toString = name - override def hashCode = name.hashCode + override def toString: String = name + override def hashCode: Int = name.hashCode override def equals(other: Any): Boolean = other match { case o: BlockId => getClass == o.getClass && name.equals(o.name) case _ => false @@ -50,54 +50,54 @@ sealed abstract class BlockId { @DeveloperApi case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - def name = "rdd_" + rddId + "_" + splitIndex + override def name: String = "rdd_" + rddId + "_" + splitIndex } // Format of the shuffle block ids (including data and index) should be kept in sync with -// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). +// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { - def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) + override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @DeveloperApi case class TaskResultBlockId(taskId: Long) extends BlockId { - def name = "taskresult_" + taskId + override def name: String = "taskresult_" + taskId } @DeveloperApi case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { - def name = "input-" + streamId + "-" + uniqueId + override def name: String = "input-" + streamId + "-" + uniqueId } /** Id associated with temporary local data managed as blocks. Not serializable. */ private[spark] case class TempLocalBlockId(id: UUID) extends BlockId { - def name = "temp_local_" + id + override def name: String = "temp_local_" + id } /** Id associated with temporary shuffle data managed as blocks. Not serializable. */ private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId { - def name = "temp_shuffle_" + id + override def name: String = "temp_shuffle_" + id } // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { - def name = "test_" + id + override def name: String = "test_" + id } @DeveloperApi @@ -112,7 +112,7 @@ object BlockId { val TEST = "test_(.*)".r /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String) = id match { + def apply(id: String): BlockId = id match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8bc5a1cd18b6..fefaef0ab82c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,27 +17,26 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} +import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor._ +import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo -import org.apache.spark.serializer.Serializer +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ @@ -50,11 +49,8 @@ private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], - readMethod: DataReadMethod.Value, - bytes: Long) { - val inputMetrics = new InputMetrics(readMethod) - inputMetrics.addBytesRead(bytes) -} + val readMethod: DataReadMethod.Value, + val bytes: Long) /** * Manager running on every node (driver and executors) which provides interfaces for putting and @@ -64,7 +60,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -80,19 +76,16 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) + // Actual storage of where blocks are kept - private var tachyonInitialized = false + private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, maxMemory) private[spark] val diskStore = new DiskStore(this, diskBlockManager) - private[spark] lazy val tachyonStore: TachyonStore = { - val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon") - val appFolderName = conf.get("spark.tachyonStore.folderName") - val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}" - val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998") - val tachyonBlockManager = - new TachyonBlockManager(this, tachyonStorePath, tachyonMaster) - tachyonInitialized = true - new TachyonStore(this, tachyonBlockManager) + private[spark] lazy val externalBlockStore: ExternalBlockStore = { + externalBlockStoreInitialized = true + new ExternalBlockStore(this, executorId) } private[spark] @@ -100,8 +93,17 @@ private[spark] class BlockManager( // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. - private val externalShuffleServicePort = - Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + private val externalShuffleServicePort = { + val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + if (tmpPort == 0) { + // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds + // an open port. But we still need to tell our spark apps the right port to use. So + // only if the yarn config has the port set to 0, we prefer the value in the spark config + conf.get("spark.shuffle.service.port").toInt + } else { + tmpPort + } + } // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled @@ -122,7 +124,8 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) } else { blockTransferService } @@ -136,9 +139,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +170,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +179,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +189,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -197,12 +200,13 @@ private[spark] class BlockManager( executorId, blockTransferService.hostName, blockTransferService.port) shuffleServerId = if (externalShuffleServiceEnabled) { + logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -228,7 +232,7 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) } @@ -265,7 +269,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -276,11 +280,13 @@ private[spark] class BlockManager( asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool reregister() asyncReregisterLock.synchronized { asyncReregisterTask = null } - } + }(futureExecutionContext) } } } @@ -301,7 +307,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -323,13 +329,13 @@ private[spark] class BlockManager( /** * Get the BlockStatus for the block identified by the given ID, if it exists. - * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. + * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfo.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L - // Assume that block is not in Tachyon + // Assume that block is not in external block store BlockStatus(info.level, memSize, diskSize, 0L) } } @@ -379,10 +385,10 @@ private[spark] class BlockManager( if (info.tellMaster) { val storageLevel = status.storageLevel val inMemSize = Math.max(status.memSize, droppedMemorySize) - val inTachyonSize = status.tachyonSize + val inExternalBlockStoreSize = status.externalBlockStoreSize val onDiskSize = status.diskSize master.updateBlockInfo( - blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inTachyonSize) + blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inExternalBlockStoreSize) } else { true } @@ -400,15 +406,17 @@ private[spark] class BlockManager( BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) - val inTachyon = level.useOffHeap && tachyonStore.contains(blockId) + val inExternalBlockStore = level.useOffHeap && externalBlockStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false - val replication = if (inMem || inTachyon || onDisk) level.replication else 1 - val storageLevel = StorageLevel(onDisk, inMem, inTachyon, deserialized, replication) + val replication = if (inMem || inExternalBlockStore || onDisk) level.replication else 1 + val storageLevel = + StorageLevel(onDisk, inMem, inExternalBlockStore, deserialized, replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L - val tachyonSize = if (inTachyon) tachyonStore.getSize(blockId) else 0L + val externalBlockStoreSize = + if (inExternalBlockStore) externalBlockStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - BlockStatus(storageLevel, memSize, diskSize, tachyonSize) + BlockStatus(storageLevel, memSize, diskSize, externalBlockStoreSize) } } } @@ -439,14 +447,11 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - val shuffleBlockManager = shuffleManager.shuffleBlockManager - shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { - case Some(bytes) => - Some(bytes) - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } + val shuffleBlockResolver = shuffleManager.shuffleBlockResolver + // TODO: This should gracefully handle case where local block is not available. Currently + // downstream code will throw an exception. + Option( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } @@ -492,20 +497,21 @@ private[spark] class BlockManager( } } - // Look for the block in Tachyon + // Look for the block in external block store if (level.useOffHeap) { - logDebug(s"Getting block $blockId from tachyon") - if (tachyonStore.contains(blockId)) { - tachyonStore.getBytes(blockId) match { - case Some(bytes) => - if (!asBlockResult) { - return Some(bytes) - } else { - return Some(new BlockResult( - dataDeserialize(blockId, bytes), DataReadMethod.Memory, info.size)) - } + logDebug(s"Getting block $blockId from ExternalBlockStore") + if (externalBlockStore.contains(blockId)) { + val result = if (asBlockResult) { + externalBlockStore.getValues(blockId) + .map(new BlockResult(_, DataReadMethod.Memory, info.size)) + } else { + externalBlockStore.getBytes(blockId) + } + result match { + case Some(values) => + return result case None => - logDebug(s"Block $blockId not found in tachyon") + logDebug(s"Block $blockId not found in ExternalBlockStore") } } } @@ -535,9 +541,14 @@ private[spark] class BlockManager( /* We'll store the bytes in memory if the block's storage level includes * "memory serialized", or if it should be cached as objects in memory * but we only requested its serialized bytes. */ - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) + memoryStore.putBytes(blockId, bytes.limit, () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we cannot + // put it into MemoryStore, copyForMemory should not be created. That's why this + // action is put into a `() => ByteBuffer` and created lazily. + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + }) bytes.rewind() } if (!asBlockResult) { @@ -645,13 +656,13 @@ private[spark] class BlockManager( def getDiskWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, - writeMetrics) + new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, + syncWrites, writeMetrics) } /** @@ -750,7 +761,11 @@ private[spark] class BlockManager( case b: ByteBufferValues if putLevel.replication > 1 => // Duplicate doesn't copy the bytes, but just creates a wrapper val bufferView = b.buffer.duplicate() - Future { replicate(blockId, bufferView, putLevel) } + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bufferView, putLevel) + }(futureExecutionContext) case _ => null } @@ -768,8 +783,8 @@ private[spark] class BlockManager( // We will drop it to disk later if the memory store can't hold it. (true, memoryStore) } else if (putLevel.useOffHeap) { - // Use tachyon for off-heap storage - (false, tachyonStore) + // Use external block store + (false, externalBlockStore) } else if (putLevel.useDisk) { // Don't get back the bytes from put unless we replicate them (putLevel.replication > 1, diskStore) @@ -804,7 +819,7 @@ private[spark] class BlockManager( val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) if (putBlockStatus.storageLevel != StorageLevel.NONE) { - // Now that the block is in either the memory, tachyon, or disk store, + // Now that the block is in either the memory, externalBlockStore, or disk store, // let other threads read it, and tell the master about it. marked = true putBlockInfo.markReady(size) @@ -991,15 +1006,23 @@ private[spark] class BlockManager( putIterator(blockId, Iterator(value), level, tellMaster) } + def dropFromMemory( + blockId: BlockId, + data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + dropFromMemory(blockId, () => data) + } + /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * + * If `data` is not put on disk, it won't be created. + * * Return the block status if the given block has been updated, else None. */ def dropFromMemory( blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -1023,7 +1046,7 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") - data match { + data() match { case Left(elements) => diskStore.putArray(blockId, elements, level, returnValues = false) case Right(bytes) => @@ -1074,7 +1097,7 @@ private[spark] class BlockManager( * Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { - logInfo(s"Removing broadcast $broadcastId") + logDebug(s"Removing broadcast $broadcastId") val blocksToRemove = blockInfo.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } @@ -1086,17 +1109,18 @@ private[spark] class BlockManager( * Remove a block from both memory and disk. */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { - logInfo(s"Removing block $blockId") + logDebug(s"Removing block $blockId") val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) - val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false - if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) { + val removedFromExternalBlockStore = + if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false + if (!removedFromMemory && !removedFromDisk && !removedFromExternalBlockStore) { logWarning(s"Block $blockId could not be removed as it was not found in either " + - "the disk, memory, or tachyon store") + "the disk, memory, or external block store") } blockInfo.remove(blockId) if (tellMaster && info.tellMaster) { @@ -1130,7 +1154,7 @@ private[spark] class BlockManager( val level = info.level if (level.useMemory) { memoryStore.remove(id) } if (level.useDisk) { diskStore.remove(id) } - if (level.useOffHeap) { tachyonStore.remove(id) } + if (level.useOffHeap) { externalBlockStore.remove(id) } iterator.remove() logInfo(s"Dropped block $id") } @@ -1195,8 +1219,19 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator + dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserializeStream( + blockId: BlockId, + inputStream: InputStream, + serializer: Serializer = defaultSerializer): Iterator[Any] = { + val stream = new BufferedInputStream(inputStream) + serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator } def stop(): Unit = { @@ -1206,15 +1241,16 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() - if (tachyonInitialized) { - tachyonStore.clear() + if (externalBlockStoreInitialized) { + externalBlockStore.clear() } metadataCleaner.cancel() broadcastCleaner.cancel() + futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } } @@ -1245,10 +1281,10 @@ private[spark] object BlockManager extends Logging { } } - def blockIdsToBlockManagers( + def blockIdsToHosts( blockIds: Array[BlockId], env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = { + blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { // blockManagerMaster != null is used in tests assert(env != null || blockManagerMaster != null) @@ -1258,24 +1294,10 @@ private[spark] object BlockManager extends Logging { blockManagerMaster.getLocations(blockIds) } - val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]] + val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i) + blockManagers(blockIds(i)) = blockLocations(i).map(_.host) } blockManagers.toMap } - - def blockIdsToExecutorIds( - blockIds: Array[BlockId], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) - } - - def blockIdsToHosts( - blockIds: Array[BlockId], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) - } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b177a59c721d..69ac37511e73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -60,7 +60,10 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } + def isDriver: Boolean = { + executorId == SparkContext.DRIVER_IDENTIFIER || + executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) @@ -77,11 +80,11 @@ class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port)" override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case id: BlockManagerId => executorId == id.executorId && port == id.port && host == id.host case _ => @@ -100,10 +103,10 @@ private[spark] object BlockManagerId { * @param port Port of the block manager. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int) = + def apply(execId: String, host: String, port: Int): BlockManagerId = getCachedBlockManagerId(new BlockManagerId(execId, host, port)) - def apply(in: ObjectInput) = { + def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() obj.readExternal(in) getCachedBlockManagerId(obj) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index b63c7f191155..f45bff34d4db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,38 +17,35 @@ package org.apache.spark.storage +import scala.collection.Iterable +import scala.collection.generic.CanBuildFrom import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global - -import akka.actor._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{ThreadUtils, RpcUtils} private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askRpcTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -58,38 +55,40 @@ class BlockManagerMaster( storageLevel: StorageLevel, memSize: Long, diskSize: Long, - tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( - UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) - logInfo("Updated info of block " + blockId) + externalBlockStoreSize: Long): Boolean = { + val res = driverEndpoint.askWithRetry[Boolean]( + UpdateBlockInfo(blockManagerId, blockId, storageLevel, + memSize, diskSize, externalBlockStoreSize)) + logDebug(s"Updated info of block $blockId") res } /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { + driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + GetLocationsMultipleBlockIds(blockIds)) } /** * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. */ - def contains(blockId: BlockId) = { + def contains(blockId: BlockId): Boolean = { !getLocations(blockId).isEmpty } /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,44 +96,44 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") - } + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") - } + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}") - } + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -145,11 +144,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,17 +164,24 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip - val result = Await.result(Future.sequence(futures), timeout) - if (result == null) { + implicit val sameThread = ThreadUtils.sameThread + val cbf = + implicitly[ + CanBuildFrom[Iterable[Future[Option[BlockStatus]]], + Option[BlockStatus], + Iterable[Option[BlockStatus]]]] + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) + if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } - val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => status.map { s => (blockManagerId, s) } }.toMap @@ -193,33 +199,36 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) - Await.result(future, timeout) + val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) + timeout.awaitResult(future) } - /** Stop the driver actor, called only on the Spark driver node */ + /** + * Find out if the executor has cached blocks. This method does not consider broadcast blocks, + * since they are not reported the master. + */ + def hasCachedBlocks(executorId: String): Boolean = { + driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) + } + + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithRetry[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala deleted file mode 100644 index 64133464d8da..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ /dev/null @@ -1,546 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable -import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ - -import akka.actor.{Actor, ActorRef, Cancellable} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} - -/** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. - */ -private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { - - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] - - // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] - - // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - - private val akkaTimeout = AkkaUtils.askTimeout(conf) - - val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000) - - val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) - - var timeoutCheckingTask: Cancellable = null - - override def preStart() { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - super.preStart() - } - - override def receiveWithLogging = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true - - case UpdateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) - - case GetLocations(blockId) => - sender ! getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) - - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) - - case GetMemoryStatus => - sender ! memoryStatus - - case GetStorageStatus => - sender ! storageStatus - - case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) - - case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) - - case RemoveRdd(rddId) => - sender ! removeRdd(rddId) - - case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) - - case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) - - case RemoveBlock(blockId) => - removeBlockFromWorkers(blockId) - sender ! true - - case RemoveExecutor(execId) => - removeExecutor(execId) - sender ! true - - case StopBlockManagerMaster => - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() - - case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) - - case other => - logWarning("Got unknown message: " + other) - } - - private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks - // from the slaves. - - // Find all blocks for the given RDD, remove the block from both blockLocations and - // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocks.foreach { blockId => - val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) - bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) - blockLocations.remove(blockId) - } - - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher - val removeMsg = RemoveRdd(rddId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq - ) - } - - private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher - val removeMsg = RemoveShuffle(shuffleId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] - }.toSeq - ) - } - - /** - * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified - * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed - * from the executors, but not from the driver. - */ - private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher - val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) - val requiredBlockManagers = blockManagerInfo.values.filter { info => - removeFromDriver || !info.blockManagerId.isDriver - } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq - ) - } - - private def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - - // Remove the block manager from blockManagerIdByExecutor. - blockManagerIdByExecutor -= blockManagerId.executorId - - // Remove it from blockManagerInfo and remove all the blocks. - blockManagerInfo.remove(blockManagerId) - val iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockLocations.get(blockId) - locations -= blockManagerId - if (locations.size == 0) { - blockLocations.remove(blockId) - } - } - listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) - logInfo(s"Removing block manager $blockManagerId") - } - - private def expireDeadHosts() { - logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") - val now = System.currentTimeMillis() - val minSeenTime = now - slaveTimeout - val toRemove = new mutable.HashSet[BlockManagerId] - for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) { - logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " - + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") - toRemove += info.blockManagerId - } - } - toRemove.foreach(removeBlockManager) - } - - private def removeExecutor(execId: String) { - logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") - blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - } - - /** - * Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. - */ - private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.isDriver && !isLocal - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - true - } - } - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: BlockId) { - val locations = blockLocations.get(blockId) - if (locations != null) { - locations.foreach { blockManagerId: BlockManagerId => - val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. - // Doesn't actually wait for a confirmation and the message might get lost. - // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) - } - } - } - } - - // Return a map from the block manager id to max memory and remaining memory. - private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - } - - private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) - }.toArray - } - - /** - * Return the block's status for all block managers, if any. NOTE: This is a - * potentially expensive operation and should only be used for testing. - * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block - * managers. - */ - private def blockStatus( - blockId: BlockId, - askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher - val getBlockStatus = GetBlockStatus(blockId) - /* - * Rather than blocking on the block status query, master actor should simply return - * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. - */ - blockManagerInfo.values.map { info => - val blockStatusFuture = - if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] - } else { - Future { info.getStatus(blockId) } - } - (info.blockManagerId, blockStatusFuture) - }.toMap - } - - /** - * Return the ids of blocks present in all the block managers that match the given filter. - * NOTE: This is a potentially expensive operation and should only be used for testing. - * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block - * managers. - */ - private def getMatchingBlockIds( - filter: BlockId => Boolean, - askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher - val getMatchingBlockIds = GetMatchingBlockIds(filter) - Future.sequence( - blockManagerInfo.values.map { info => - val future = - if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] - } else { - Future { info.blocks.keys.filter(filter).toSeq } - } - future - } - ).map(_.flatten.toSeq) - } - - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val time = System.currentTimeMillis() - if (!blockManagerInfo.contains(id)) { - blockManagerIdByExecutor.get(id.executorId) match { - case Some(oldId) => - // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " - + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) - case None => - } - logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) - - blockManagerIdByExecutor(id.executorId) = id - - blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) - } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) - } - - private def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): Boolean = { - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.isDriver && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - return true - } else { - return false - } - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - return true - } - - blockManagerInfo(blockManagerId).updateBlockInfo( - blockId, storageLevel, memSize, diskSize, tachyonSize) - - var locations: mutable.HashSet[BlockManagerId] = null - if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId) - } else { - locations = new mutable.HashSet[BlockManagerId] - blockLocations.put(blockId, locations) - } - - if (storageLevel.isValid) { - locations.add(blockManagerId) - } else { - locations.remove(blockManagerId) - } - - // Remove the block from master tracking if it has been removed on all slaves. - if (locations.size == 0) { - blockLocations.remove(blockId) - } - true - } - - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty - } - - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - blockIds.map(blockId => getLocations(blockId)) - } - - /** Get the list of the peers of the given block manager */ - private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - val blockManagerIds = blockManagerInfo.keySet - if (blockManagerIds.contains(blockManagerId)) { - blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq - } else { - Seq.empty - } - } - - /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * BlockManagerSlaveActor. - */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - for ( - blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port - ) yield { - (host, port) - } - } -} - -@DeveloperApi -case class BlockStatus( - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long) { - def isCached: Boolean = memSize + diskSize + tachyonSize > 0 -} - -@DeveloperApi -object BlockStatus { - def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) -} - -private[spark] class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) - extends Logging { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[BlockId, BlockStatus] - - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo( - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long) { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize - - if (originalLevel.useMemory) { - _remainingMem += originalMemSize - } - } - - if (storageLevel.isValid) { - /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * The memSize here indicates the data size in or dropped from memory, - * tachyonSize here indicates the data size in or dropped from Tachyon, - * and the diskSize here indicates the data size in or dropped to disk. - * They can be both larger than 0, when a block is dropped from memory to disk. - * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ - if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) - _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) - } - if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize)) - logInfo("Added %s on tachyon on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) - } - if (blockStatus.storageLevel.useOffHeap) { - logInfo("Removed %s on %s on tachyon (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize))) - } - } - } - - def removeBlock(blockId: BlockId) { - if (_blocks.containsKey(blockId)) { - _remainingMem += _blocks.get(blockId).memSize - _blocks.remove(blockId) - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[BlockId, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala new file mode 100644 index 000000000000..7db6035553ae --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -0,0 +1,539 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.immutable.HashSet +import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. + */ +private[spark] +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { + + // Mapping from block manager id to the block manager's information. + private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] + + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] + + // Mapping from block id to the set of block managers that have the block. + private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) + + case _updateBlockInfo @ UpdateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) + + case GetLocations(blockId) => + context.reply(getLocations(blockId)) + + case GetLocationsMultipleBlockIds(blockIds) => + context.reply(getLocationsMultipleBlockIds(blockIds)) + + case GetPeers(blockManagerId) => + context.reply(getPeers(blockManagerId)) + + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) + + case GetMemoryStatus => + context.reply(memoryStatus) + + case GetStorageStatus => + context.reply(storageStatus) + + case GetBlockStatus(blockId, askSlaves) => + context.reply(blockStatus(blockId, askSlaves)) + + case GetMatchingBlockIds(filter, askSlaves) => + context.reply(getMatchingBlockIds(filter, askSlaves)) + + case RemoveRdd(rddId) => + context.reply(removeRdd(rddId)) + + case RemoveShuffle(shuffleId) => + context.reply(removeShuffle(shuffleId)) + + case RemoveBroadcast(broadcastId, removeFromDriver) => + context.reply(removeBroadcast(broadcastId, removeFromDriver)) + + case RemoveBlock(blockId) => + removeBlockFromWorkers(blockId) + context.reply(true) + + case RemoveExecutor(execId) => + removeExecutor(execId) + context.reply(true) + + case StopBlockManagerMaster => + context.reply(true) + stop() + + case BlockManagerHeartbeat(blockManagerId) => + context.reply(heartbeatReceived(blockManagerId)) + + case HasCachedBlocks(executorId) => + blockManagerIdByExecutor.get(executorId) match { + case Some(bm) => + if (blockManagerInfo.contains(bm)) { + val bmInfo = blockManagerInfo(bm) + context.reply(bmInfo.cachedBlocks.nonEmpty) + } else { + context.reply(false) + } + case None => context.reply(false) + } + } + + private def removeRdd(rddId: Int): Future[Seq[Int]] = { + // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // from the slaves. + + // Find all blocks for the given RDD, remove the block from both blockLocations and + // the blockManagerInfo that is tracking the blocks. + val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocks.foreach { blockId => + val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) + bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) + blockLocations.remove(blockId) + } + + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. + val removeMsg = RemoveRdd(rddId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg) + }.toSeq + ) + } + + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + // Nothing to do in the BlockManagerMasterEndpoint data structures + val removeMsg = RemoveShuffle(shuffleId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveEndpoint.ask[Boolean](removeMsg) + }.toSeq + ) + } + + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { + val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) + val requiredBlockManagers = blockManagerInfo.values.filter { info => + removeFromDriver || !info.blockManagerId.isDriver + } + Future.sequence( + requiredBlockManagers.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg) + }.toSeq + ) + } + + private def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId + + // Remove it from blockManagerInfo and remove all the blocks. + blockManagerInfo.remove(blockManagerId) + val iterator = info.blocks.keySet.iterator + while (iterator.hasNext) { + val blockId = iterator.next + val locations = blockLocations.get(blockId) + locations -= blockManagerId + if (locations.size == 0) { + blockLocations.remove(blockId) + } + } + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) + logInfo(s"Removing block manager $blockManagerId") + } + + private def removeExecutor(execId: String) { + logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) + } + + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { + if (!blockManagerInfo.contains(blockManagerId)) { + blockManagerId.isDriver && !isLocal + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + true + } + } + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlockFromWorkers(blockId: BlockId) { + val locations = blockLocations.get(blockId) + if (locations != null) { + locations.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveEndpoint.ask[Boolean](RemoveBlock(blockId)) + } + } + } + } + + // Return a map from the block manager id to max memory and remaining memory. + private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { + blockManagerInfo.map { case(blockManagerId, info) => + (blockManagerId, (info.maxMem, info.remainingMem)) + }.toMap + } + + private def storageStatus: Array[StorageStatus] = { + blockManagerInfo.map { case (blockManagerId, info) => + new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) + }.toArray + } + + /** + * Return the block's status for all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master endpoint should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master endpoint's response to a previous message. + */ + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveEndpoint.ask[Option[BlockStatus]](getBlockStatus) + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) + }.toMap + } + + /** + * Return the ids of blocks present in all the block managers that match the given filter. + * NOTE: This is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Future[Seq[BlockId]] = { + val getMatchingBlockIds = GetMatchingBlockIds(filter) + Future.sequence( + blockManagerInfo.values.map { info => + val future = + if (askSlaves) { + info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) + } else { + Future { info.blocks.asScala.keys.filter(filter).toSeq } + } + future + } + ).map(_.flatten.toSeq) + } + + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { + val time = System.currentTimeMillis() + if (!blockManagerInfo.contains(id)) { + blockManagerIdByExecutor.get(id.executorId) match { + case Some(oldId) => + // A block manager of the same executor already exists, so remove it (assumed dead) + logError("Got two different block manager registrations on same executor - " + + s" will replace old one $oldId with new one $id") + removeExecutor(id.executorId) + case None => + } + logInfo("Registering block manager %s with %s RAM, %s".format( + id.hostPort, Utils.bytesToString(maxMemSize), id)) + + blockManagerIdByExecutor(id.executorId) = id + + blockManagerInfo(id) = new BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + } + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + } + + private def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long): Boolean = { + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.isDriver && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + return true + } else { + return false + } + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + return true + } + + blockManagerInfo(blockManagerId).updateBlockInfo( + blockId, storageLevel, memSize, diskSize, externalBlockStoreSize) + + var locations: mutable.HashSet[BlockManagerId] = null + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId) + } else { + locations = new mutable.HashSet[BlockManagerId] + blockLocations.put(blockId, locations) + } + + if (storageLevel.isValid) { + locations.add(blockManagerId) + } else { + locations.remove(blockManagerId) + } + + // Remove the block from master tracking if it has been removed on all slaves. + if (locations.size == 0) { + blockLocations.remove(blockId) + } + true + } + + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty + } + + private def getLocationsMultipleBlockIds( + blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { + blockIds.map(blockId => getLocations(blockId)) + } + + /** Get the list of the peers of the given block manager */ + private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + val blockManagerIds = blockManagerInfo.keySet + if (blockManagerIds.contains(blockManagerId)) { + blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq + } else { + Seq.empty + } + } + + /** + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. + */ + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId) + ) yield { + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + } + } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } +} + +@DeveloperApi +case class BlockStatus( + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) { + def isCached: Boolean = memSize + diskSize + externalBlockStoreSize > 0 +} + +@DeveloperApi +object BlockStatus { + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) +} + +private[spark] class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveEndpoint: RpcEndpointRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[BlockId, BlockStatus] + + // Cached blocks held by this BlockManager. This does not include broadcast blocks. + private val _cachedBlocks = new mutable.HashSet[BlockId] + + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo( + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize + + if (originalLevel.useMemory) { + _remainingMem += originalMemSize + } + } + + if (storageLevel.isValid) { + /* isValid means it is either stored in-memory, on-disk or on-externalBlockStore. + * The memSize here indicates the data size in or dropped from memory, + * externalBlockStoreSize here indicates the data size in or dropped from externalBlockStore, + * and the diskSize here indicates the data size in or dropped to disk. + * They can be both larger than 0, when a block is dropped from memory to disk. + * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ + var blockStatus: BlockStatus = null + if (storageLevel.useMemory) { + blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + _blocks.put(blockId, blockStatus) + _remainingMem -= memSize + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), + Utils.bytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + _blocks.put(blockId, blockStatus) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + } + if (storageLevel.useOffHeap) { + blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) + _blocks.put(blockId, blockStatus) + logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) + } + if (!blockId.isBroadcast && blockStatus.isCached) { + _cachedBlocks += blockId + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + _cachedBlocks -= blockId + if (blockStatus.storageLevel.useMemory) { + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), + Utils.bytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + } + if (blockStatus.storageLevel.useOffHeap) { + logInfo("Removed %s on %s on externalBlockStore (size: %s)".format( + blockId, blockManagerId.hostPort, + Utils.bytesToString(blockStatus.externalBlockStoreSize))) + } + } + } + + def removeBlock(blockId: BlockId) { + if (_blocks.containsKey(blockId)) { + _remainingMem += _blocks.get(blockId).memSize + _blocks.remove(blockId) + } + _cachedBlocks -= blockId + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[BlockId, BlockStatus] = _blocks + + // This does not include broadcast blocks. + def cachedBlocks: collection.Set[BlockId] = _cachedBlocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3f32099d08cc..376e9eb48843 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -43,7 +42,6 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -52,7 +50,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -61,7 +59,7 @@ private[spark] object BlockManagerMessages { var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long, - var tachyonSize: Long) + var externalBlockStoreSize: Long) extends ToBlockManagerMaster with Externalizable { @@ -73,7 +71,7 @@ private[spark] object BlockManagerMessages { storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) - out.writeLong(tachyonSize) + out.writeLong(externalBlockStoreSize) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -82,7 +80,7 @@ private[spark] object BlockManagerMessages { storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() - tachyonSize = in.readLong() + externalBlockStoreSize = in.readLong() } } @@ -92,7 +90,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster @@ -110,5 +108,5 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case object ExpireDeadHosts extends ToBlockManagerMaster + case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala deleted file mode 100644 index 8462871e798a..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} - -import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive - -/** - * An actor to take commands from the master to execute options. For example, - * this is used to remove blocks from the slave's BlockManager. - */ -private[storage] -class BlockManagerSlaveActor( - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { - - import context.dispatcher - - // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { - case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { - blockManager.removeBlock(blockId) - true - } - - case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { - blockManager.removeRdd(rddId) - } - - case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { - if (mapOutputTracker != null) { - mapOutputTracker.unregisterShuffle(shuffleId) - } - SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) - } - - case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { - blockManager.removeBroadcast(broadcastId, tellMaster = true) - } - - case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) - - case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) - } - - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { - val future = Future { - logDebug(actionMessage) - body - } - future.onSuccess { case response => - logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) - } - future.onFailure { case t: Throwable => - logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala new file mode 100644 index 000000000000..7478ab0fc2f7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.ThreadUtils +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.storage.BlockManagerMessages._ + +/** + * An RpcEndpoint to take commands from the master to execute options. For example, + * this is used to remove blocks from the slave's BlockManager. + */ +private[storage] +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + extends RpcEndpoint with Logging { + + private val asyncThreadPool = + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) + + // Operations that involve removing blocks may be slow and should be done asynchronously + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RemoveBlock(blockId) => + doAsync[Boolean]("removing block " + blockId, context) { + blockManager.removeBlock(blockId) + true + } + + case RemoveRdd(rddId) => + doAsync[Int]("removing RDD " + rddId, context) { + blockManager.removeRdd(rddId) + } + + case RemoveShuffle(shuffleId) => + doAsync[Boolean]("removing shuffle " + shuffleId, context) { + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) + } + + case RemoveBroadcast(broadcastId, _) => + doAsync[Int]("removing broadcast " + broadcastId, context) { + blockManager.removeBroadcast(broadcastId, tellMaster = true) + } + + case GetBlockStatus(blockId, _) => + context.reply(blockManager.getStatus(blockId)) + + case GetMatchingBlockIds(filter, _) => + context.reply(blockManager.getMatchingBlockIds(filter)) + } + + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { + val future = Future { + logDebug(actionMessage) + body + } + future.onSuccess { case response => + logDebug("Done " + actionMessage + ", response is " + response) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) + } + future.onFailure { case t: Throwable => + logError("Error in " + actionMessage, t) + context.sendFailure(t) + } + } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 8569c6f3cbbc..c5ba9af3e265 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala deleted file mode 100644 index 3198d766fca3..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} -import java.nio.channels.FileChannel - -import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializationStream, Serializer} -import org.apache.spark.executor.ShuffleWriteMetrics - -/** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. - * - * This interface does not support concurrent writes. - */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes an object. - */ - def write(value: Any) - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. - */ -private[spark] class DiskBlockObjectWriter( - blockId: BlockId, - file: File, - serializer: Serializer, - bufferSize: Int, - compressStream: OutputStream => OutputStream, - syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who - // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ - private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]) = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) - override def close() = out.close() - override def flush() = out.flush() - } - - /** The file channel, used for repositioning / truncating the file. */ - private var channel: FileChannel = null - private var bs: OutputStream = null - private var fos: FileOutputStream = null - private var ts: TimeTrackingOutputStream = null - private var objOut: SerializationStream = null - private var initialized = false - - /** - * Cursors used to represent positions in the file. - * - * xxxxxxxx|--------|--- | - * ^ ^ ^ - * | | finalPosition - * | reportedPosition - * initialPosition - * - * initialPosition: Offset in the file where we start writing. Immutable. - * reportedPosition: Position at the time of the last update to the write metrics. - * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed. - * -----: Current writes to the underlying file. - * xxxxx: Existing contents of the file. - */ - private val initialPosition = file.length() - private var finalPosition: Long = -1 - private var reportedPosition = initialPosition - - /** Calling channel.position() to update the write metrics can be a little bit expensive, so we - * only call it every N writes */ - private var writesSinceMetricsUpdate = 0 - - override def open(): BlockObjectWriter = { - fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) - channel = fos.getChannel() - bs = compressStream(new BufferedOutputStream(ts, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - def sync = fos.getFD.sync() - callWithTiming(sync) - } - objOut.close() - - channel = null - bs = null - fos = null - ts = null - objOut = null - initialized = false - } - } - - override def isOpen: Boolean = objOut != null - - override def commitAndClose(): Unit = { - if (initialized) { - // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the - // serializer stream and the lower level stream. - objOut.flush() - bs.flush() - close() - } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) - } - - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. - override def revertPartialWritesAndClose() { - try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - - if (initialized) { - objOut.flush() - bs.flush() - close() - } - - val truncateStream = new FileOutputStream(file, true) - try { - truncateStream.getChannel.truncate(initialPosition) - } finally { - truncateStream.close() - } - } catch { - case e: Exception => - logError("Uncaught exception while reverting partial writes to file " + file, e) - } - } - - override def write(value: Any) { - if (!initialized) { - open() - } - - objOut.writeObject(value) - - if (writesSinceMetricsUpdate == 32) { - writesSinceMetricsUpdate = 0 - updateBytesWritten() - } else { - writesSinceMetricsUpdate += 1 - } - } - - override def fileSegment(): FileSegment = { - new FileSegment(file, initialPosition, finalPosition - initialPosition) - } - - /** - * Report the number of bytes written in this writer's shuffle write metrics. - * Note that this is only valid before the underlying streams are closed. - */ - private def updateBytesWritten() { - val pos = channel.position() - writeMetrics.incShuffleBytesWritten(pos - reportedPosition) - reportedPosition = pos - } - - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - - // For testing - private[spark] def flush() { - objOut.flush() - bs.flush() - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala new file mode 100644 index 000000000000..2789e25b8d3a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +private[spark] case class BlockUIData( + blockId: BlockId, + location: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +/** + * The aggregated status of stream blocks in an executor + */ +private[spark] case class ExecutorStreamBlockStatus( + executorId: String, + location: String, + blocks: Seq[BlockUIData]) { + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum + + def numStreamBlocks: Int = blocks.size + +} + +private[spark] class BlockStatusListener extends SparkListener { + + private val blockManagers = + new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val blockId = blockUpdated.blockUpdatedInfo.blockId + if (!blockId.isInstanceOf[StreamBlockId]) { + // Now we only monitor StreamBlocks + return + } + val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize + + synchronized { + // Drop the update info if the block manager is not registered + blockManagers.get(blockManagerId).foreach { blocksInBlockManager => + if (storageLevel.isValid) { + blocksInBlockManager.put(blockId, + BlockUIData( + blockId, + blockManagerId.hostPort, + storageLevel, + memSize, + diskSize, + externalBlockStoreSize) + ) + } else { + // If isValid is not true, it means we should drop the block. + blocksInBlockManager -= blockId + } + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + synchronized { + blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { + blockManagers -= blockManagerRemoved.blockManagerId + } + + def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { + blockManagers.map { case (blockManagerId, blocks) => + ExecutorStreamBlockStatus( + blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) + }.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala new file mode 100644 index 000000000000..a5790e4454a8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo + +/** + * :: DeveloperApi :: + * Stores information about a block status in a block manager. + */ +@DeveloperApi +case class BlockUpdatedInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +private[spark] object BlockUpdatedInfo { + + private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = { + BlockUpdatedInfo( + updateBlockInfo.blockManagerId, + updateBlockInfo.blockId, + updateBlockInfo.storageLevel, + updateBlockInfo.memSize, + updateBlockInfo.diskSize, + updateBlockInfo.externalBlockStoreSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 53eaedacbf29..3f8d26e1d4ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -47,13 +47,15 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content + // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - addShutdownHook() + private val shutdownHook = addShutdownHook() /** Looks up a file by hashing it into one of our local subdirectories. */ // This method should be kept in sync with - // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile(). + // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getFile(). def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -61,20 +63,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon val subDirId = (hash / localDirs.length) % subDirsPerLocalDir // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") - } - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -91,7 +90,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories - subDirs.flatten.filter(_ != null).flatMap { dir => + subDirs.flatMap { dir => + dir.synchronized { + // Copy the content of dir because it may be modified in other threads + dir.clone() + } + }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files else Seq.empty } @@ -120,8 +124,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") logInfo(s"Created local directory at $localDir") @@ -134,23 +143,34 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - DiskBlockManager.this.stop() - } - }) + private def addShutdownHook(): AnyRef = { + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + logInfo("Shutdown hook called") + DiskBlockManager.this.doStop() + } } /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { + // Remove the shutdown hook. It causes memory leaks if we leave it around. + try { + ShutdownHookManager.removeShutdownHook(shutdownHook) + } catch { + case e: Exception => + logError(s"Exception while removing shutdown hook.", e) + } + doStop() + } + + private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala new file mode 100644 index 000000000000..49d9154f95a5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} +import java.nio.channels.FileChannel + +import org.apache.spark.Logging +import org.apache.spark.serializer.{SerializerInstance, SerializationStream} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils + +/** + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. + * + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. + */ +private[spark] class DiskBlockObjectWriter( + val blockId: BlockId, + file: File, + serializerInstance: SerializerInstance, + bufferSize: Int, + compressStream: OutputStream => OutputStream, + syncWrites: Boolean, + // These write metrics concurrently shared with other active DiskBlockObjectWriters who + // are themselves performing writes. All updates must be relative. + writeMetrics: ShuffleWriteMetrics) + extends OutputStream + with Logging { + + /** The file channel, used for repositioning / truncating the file. */ + private var channel: FileChannel = null + private var bs: OutputStream = null + private var fos: FileOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var objOut: SerializationStream = null + private var initialized = false + private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false + + /** + * Cursors used to represent positions in the file. + * + * xxxxxxxx|--------|--- | + * ^ ^ ^ + * | | finalPosition + * | reportedPosition + * initialPosition + * + * initialPosition: Offset in the file where we start writing. Immutable. + * reportedPosition: Position at the time of the last update to the write metrics. + * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed. + * -----: Current writes to the underlying file. + * xxxxx: Existing contents of the file. + */ + private val initialPosition = file.length() + private var finalPosition: Long = -1 + private var reportedPosition = initialPosition + + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + */ + private var numRecordsWritten = 0 + + def open(): DiskBlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } + fos = new FileOutputStream(file, true) + ts = new TimeTrackingOutputStream(writeMetrics, fos) + channel = fos.getChannel() + bs = compressStream(new BufferedOutputStream(ts, bufferSize)) + objOut = serializerInstance.serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) + } + } { + objOut.close() + } + + channel = null + bs = null + fos = null + ts = null + objOut = null + initialized = false + hasBeenClosed = true + } + } + + def isOpen: Boolean = objOut != null + + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { + if (initialized) { + // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the + // serializer stream and the lower level stream. + objOut.flush() + bs.flush() + close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() + } + commitAndCloseHasBeenCalled = true + } + + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. + try { + if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) + objOut.flush() + bs.flush() + close() + } + + val truncateStream = new FileOutputStream(file, true) + try { + truncateStream.getChannel.truncate(initialPosition) + } finally { + truncateStream.close() + } + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) + } + } + + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { + if (!initialized) { + open() + } + + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + override def write(b: Int): Unit = throw new UnsupportedOperationException() + + override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { + if (!initialized) { + open() + } + + bs.write(kvBytes, offs, len) + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) + + if (numRecordsWritten % 32 == 0) { + updateBytesWritten() + } + } + + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. + */ + def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } + new FileSegment(file, initialPosition, finalPosition - initialPosition) + } + + /** + * Report the number of bytes written in this writer's shuffle write metrics. + * Note that this is only valid before the underlying streams are closed. + */ + private def updateBytesWritten() { + val pos = channel.position() + writeMetrics.incShuffleBytesWritten(pos - reportedPosition) + reportedPosition = pos + } + + // For testing + private[spark] override def flush() { + objOut.flush() + bs.flush() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff16879..1f4595628216 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length @@ -46,10 +45,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel - while (bytes.remaining > 0) { - channel.write(bytes) + Utils.tryWithSafeFinally { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } { + channel.close() } - channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) @@ -75,9 +77,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) try { - try { + Utils.tryWithSafeFinally { blockManager.dataSerializeStream(blockId, outputStream, values) - } finally { + } { // Close outputStream here because it should be closed before file is deleted. outputStream.close() } @@ -106,8 +108,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel - - try { + Utils.tryWithSafeFinally { // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) @@ -123,7 +124,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } else { Some(channel.map(MapMode.READ_ONLY, offset, length)) } - } finally { + } { channel.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala new file mode 100644 index 000000000000..f39325a12d24 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +/** + * An abstract class that the concrete external block manager has to inherit. + * The class has to have a no-argument constructor, and will be initialized by init, + * which is invoked by ExternalBlockStore. The main input parameter is blockId for all + * the methods, which is the unique identifier for Block in one Spark application. + * + * The underlying external block manager should avoid any name space conflicts among multiple + * Spark applications. For example, creating different directory for different applications + * by randomUUID + * + */ +private[spark] abstract class ExternalBlockManager { + + protected var blockManager: BlockManager = _ + + override def toString: String = {"External Block Store"} + + /** + * Initialize a concrete block manager implementation. Subclass should initialize its internal + * data structure, e.g, file system, in this function, which is invoked by ExternalBlockStore + * right after the class is constructed. The function should throw IOException on failure + * + * @throws java.io.IOException if there is any file system failure during the initialization. + */ + def init(blockManager: BlockManager, executorId: String): Unit = { + this.blockManager = blockManager + } + + /** + * Drop the block from underlying external block store, if it exists.. + * @return true on successfully removing the block + * false if the block could not be removed as it was not found + * + * @throws java.io.IOException if there is any file system failure in removing the block. + */ + def removeBlock(blockId: BlockId): Boolean + + /** + * Used by BlockManager to check the existence of the block in the underlying external + * block store. + * @return true if the block exists. + * false if the block does not exists. + * + * @throws java.io.IOException if there is any file system failure in checking + * the block existence. + */ + def blockExists(blockId: BlockId): Boolean + + /** + * Put the given block to the underlying external block store. Note that in normal case, + * putting a block should never fail unless something wrong happens to the underlying + * external block store, e.g., file system failure, etc. In this case, IOException + * should be thrown. + * + * @throws java.io.IOException if there is any file system failure in putting the block. + */ + def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit + + def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val bytes = blockManager.dataSerialize(blockId, values) + putBytes(blockId, bytes) + } + + /** + * Retrieve the block bytes. + * @return Some(ByteBuffer) if the block bytes is successfully retrieved + * None if the block does not exist in the external block store. + * + * @throws java.io.IOException if there is any file system failure in getting the block. + */ + def getBytes(blockId: BlockId): Option[ByteBuffer] + + /** + * Retrieve the block data. + * @return Some(Iterator[Any]) if the block data is successfully retrieved + * None if the block does not exist in the external block store. + * + * @throws java.io.IOException if there is any file system failure in getting the block. + */ + def getValues(blockId: BlockId): Option[Iterator[_]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + } + + /** + * Get the size of the block saved in the underlying external block store, + * which is saved before by putBytes. + * @return size of the block + * 0 if the block does not exist + * + * @throws java.io.IOException if there is any file system failure in getting the block size. + */ + def getSize(blockId: BlockId): Long + + /** + * Clean up any information persisted in the underlying external block store, + * e.g., the directory, files, etc,which is invoked by the shutdown hook of ExternalBlockStore + * during system shutdown. + * + */ + def shutdown() +} diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala new file mode 100644 index 000000000000..db965d54bafd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + + +/** + * Stores BlockManager blocks on ExternalBlockStore. + * We capture any potential exception from underlying implementation + * and return with the expected failure value + */ +private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: String) + extends BlockStore(blockManager: BlockManager) with Logging { + + lazy val externalBlockManager: Option[ExternalBlockManager] = createBlkManager() + + logInfo("ExternalBlockStore started") + + override def getSize(blockId: BlockId): Long = { + try { + externalBlockManager.map(_.getSize(blockId)).getOrElse(0) + } catch { + case NonFatal(t) => + logError(s"Error in getSize($blockId)", t) + 0L + } + } + + override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { + putIntoExternalBlockStore(blockId, bytes, returnValues = true) + } + + override def putArray( + blockId: BlockId, + values: Array[Any], + level: StorageLevel, + returnValues: Boolean): PutResult = { + putIntoExternalBlockStore(blockId, values.toIterator, returnValues) + } + + override def putIterator( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean): PutResult = { + putIntoExternalBlockStore(blockId, values, returnValues) + } + + private def putIntoExternalBlockStore( + blockId: BlockId, + values: Iterator[_], + returnValues: Boolean): PutResult = { + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") + // we should never hit here if externalBlockManager is None. Handle it anyway for safety. + try { + val startTime = System.currentTimeMillis + if (externalBlockManager.isDefined) { + externalBlockManager.get.putValues(blockId, values) + val size = getSize(blockId) + val data = if (returnValues) { + Left(getValues(blockId).get) + } else { + null + } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) + } else { + logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } catch { + case NonFatal(t) => + logError(s"Error in putValues($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } + + private def putIntoExternalBlockStore( + blockId: BlockId, + bytes: ByteBuffer, + returnValues: Boolean): PutResult = { + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") + // we should never hit here if externalBlockManager is None. Handle it anyway for safety. + try { + val startTime = System.currentTimeMillis + if (externalBlockManager.isDefined) { + val byteBuffer = bytes.duplicate() + byteBuffer.rewind() + externalBlockManager.get.putBytes(blockId, byteBuffer) + val size = bytes.limit() + val data = if (returnValues) { + Right(bytes) + } else { + null + } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) + } else { + logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } catch { + case NonFatal(t) => + logError(s"Error in putBytes($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } + + // We assume the block is removed even if exception thrown + override def remove(blockId: BlockId): Boolean = { + try { + externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true) + } catch { + case NonFatal(t) => + logError(s"Error in removeBlock($blockId)", t) + true + } + } + + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { + try { + externalBlockManager.flatMap(_.getValues(blockId)) + } catch { + case NonFatal(t) => + logError(s"Error in getValues($blockId)", t) + None + } + } + + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + try { + externalBlockManager.flatMap(_.getBytes(blockId)) + } catch { + case NonFatal(t) => + logError(s"Error in getBytes($blockId)", t) + None + } + } + + override def contains(blockId: BlockId): Boolean = { + try { + val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false) + if (!ret) { + logInfo(s"Remove block $blockId") + blockManager.removeBlock(blockId, true) + } + ret + } catch { + case NonFatal(t) => + logError(s"Error in getBytes($blockId)", t) + false + } + } + + private def addShutdownHook() { + Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") { + override def run(): Unit = Utils.logUncaughtExceptions { + logDebug("Shutdown hook called") + externalBlockManager.map(_.shutdown()) + } + }) + } + + // Create concrete block manager and fall back to Tachyon by default for backward compatibility. + private def createBlkManager(): Option[ExternalBlockManager] = { + val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) + .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) + + try { + val instance = Utils.classForName(clsName) + .newInstance() + .asInstanceOf[ExternalBlockManager] + instance.init(blockManager, executorId) + addShutdownHook(); + Some(instance) + } catch { + case NonFatal(t) => + logError("Cannot initialize external block store", t) + None + } + } +} + +private[spark] object ExternalBlockStore extends Logging { + val MAX_DIR_CREATION_ATTEMPTS = 10 + val SUB_DIRS_PER_DIR = "64" + val BASE_DIR = "spark.externalBlockStore.baseDir" + val FOLD_NAME = "spark.externalBlockStore.folderName" + val MASTER_URL = "spark.externalBlockStore.url" + val BLOCK_MANAGER_NAME = "spark.externalBlockStore.blockManager" + val DEFAULT_BLOCK_MANAGER_NAME = "org.apache.spark.storage.TachyonBlockManager" +} diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 132502b75f8c..021a9facfb0b 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,5 +24,9 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 71305a46bf57..6f27f00307f8 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,9 +44,17 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() + // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. + // Pending unroll memory refers to the intermediate memory occupied by a task + // after the unroll but before the actual putting of the block in the cache. + // This chunk of memory is expected to be released *as soon as* we finish + // caching the corresponding block as opposed to until after the task finishes. + // This is only used if a block is successfully unrolled in its entirety in + // memory (SPARK-4777). + private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() /** * The amount of space ensured for unrolling values in memory, shared across all cores. @@ -90,6 +99,26 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and + * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. + * + * The caller should guarantee that `size` is correct. + */ + def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { + // Work on a duplicate - since the original input might be used elsewhere. + lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] + val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val data = + if (putAttempt.success) { + assert(bytes.limit == size) + Right(bytes.duplicate()) + } else { + null + } + PutResult(size, data, putAttempt.droppedBlocks) + } + override def putArray( blockId: BlockId, values: Array[Any], @@ -184,7 +213,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val entry = entries.remove(blockId) if (entry != null) { currentMemory -= entry.size - logInfo(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") + logDebug(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") true } else { false @@ -222,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + // Previous unroll memory held by this task, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -255,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -263,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -283,12 +312,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } finally { - // If we return an array, the values returned do not depend on the underlying vector and - // we can immediately free up space for other threads. Otherwise, if we return an iterator, - // we release the memory claimed by this thread later on when the task finishes. + // If we return an array, the values returned will later be cached in `tryToPut`. + // In this case, we should release the memory after we cache the block there. + // Otherwise, if we return an iterator, we release the memory reserved here + // later when the task finishes. if (keepUnrolling) { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) + accountingLock.synchronized { + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) + } } } } @@ -300,11 +333,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId.asRDDId.map(_.rddId) } + private def tryToPut( + blockId: BlockId, + value: Any, + size: Long, + deserialized: Boolean): ResultWithDroppedBlocks = { + tryToPut(blockId, () => value, size, deserialized) + } + /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size * must also be passed by the caller. * + * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be + * created to avoid OOM since it may be a big ByteBuffer. + * * Synchronize on `accountingLock` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for @@ -314,7 +358,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ private def tryToPut( blockId: BlockId, - value: Any, + value: () => Any, size: Long, deserialized: Boolean): ResultWithDroppedBlocks = { @@ -333,7 +377,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new MemoryEntry(value, size, deserialized) + val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size @@ -345,14 +389,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[Array[Any]]) + lazy val data = if (deserialized) { + Left(value().asInstanceOf[Array[Any]]) } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) + Right(value().asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } + // Release the unroll memory used because we no longer need the underlying Array + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -381,7 +427,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } // Take into account the amount of memory currently occupied by unrolling blocks - val actualFreeMemory = freeMemory - currentUnrollMemory + // and minus the pending unroll memory for that block on current thread. + val taskAttemptId = currentTaskAttemptId() + val actualFreeMemory = freeMemory - currentUnrollMemory + + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -407,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -434,58 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } } /** - * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory - // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory + // If this task claims no more unroll memory, release it completely + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Reserve the unroll memory of current unroll successful block used by this task + * until actually put the block into memory entry. + */ + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() + accountingLock.synchronized { + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory + } + } + + /** + * Release pending unroll memory of current unroll successful block used by this task + */ + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() + accountingLock.synchronized { + pendingUnrollMemoryMap.remove(taskAttemptId) + } + } + + /** + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { - unrollMemoryMap.values.sum + unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -497,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 120c327a7e58..96062626b504 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDDOperationScope, RDD} import org.apache.spark.util.Utils @DeveloperApi @@ -26,32 +26,36 @@ class RDDInfo( val id: Int, val name: String, val numPartitions: Int, - var storageLevel: StorageLevel) + var storageLevel: StorageLevel, + val parentIds: Seq[Int], + val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L - var tachyonSize = 0L + var externalBlockStoreSize = 0L - def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 + def isCached: Boolean = + (memSize + diskSize + externalBlockStoreSize > 0) && numCachedPartitions > 0 - override def toString = { + override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + - "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( + "MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s").format( name, id, storageLevel.toString, numCachedPartitions, numPartitions, - bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) + bytesToString(memSize), bytesToString(externalBlockStoreSize), bytesToString(diskSize)) } - override def compare(that: RDDInfo) = { + override def compare(that: RDDInfo): Int = { this.id - that.id } } private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { - val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) + val parentIds = rdd.dependencies.map(_.rdd.id) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ab9ee4f0096b..a759ceb96ec1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,25 +17,24 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.{Logging, SparkException, TaskContext} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.Serializer -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -46,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -55,9 +53,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -85,7 +82,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -104,7 +101,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,23 +111,29 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } } @@ -155,7 +158,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } @@ -164,7 +167,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), e)) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } } ) @@ -234,13 +237,14 @@ final class ShuffleBlockFetcherIterator( try { val buf = blockManager.getBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, e)) + results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) return } } @@ -271,7 +275,15 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -280,7 +292,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size case _ => } // Send fetch requests up to maxBytesInFlight @@ -289,29 +301,64 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { - case FailureFetchResult(_, e) => - Failure(e) - case SuccessFetchResult(blockId, _, buf) => - // There is a chance that createInputStream can fail (e.g. fetching a local file that does - // not exist, SPARK-4085). In that case, we should propagate the right exception so - // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializer.newInstance().deserializeStream(is).asIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) + result match { + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + + case SuccessFetchResult(blockId, address, _, buf) => + try { + (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) + } catch { + case NonFatal(t) => + throwFetchFailedException(blockId, address, t) } } + } - (result.blockId, iteratorTry) + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) + } } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { @@ -331,16 +378,22 @@ object ShuffleBlockFetcherIterator { */ private[storage] sealed trait FetchResult { val blockId: BlockId + val address: BlockManagerId } /** * Result of a fetch from a remote block successfully. * @param blockId block id + * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. */ - private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + private[storage] case class SuccessFetchResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer) extends FetchResult { require(buf != null) require(size >= 0) @@ -349,8 +402,12 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. * @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from * @param e the failure exception */ - private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + private[storage] case class FailureFetchResult( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index e5e1cf5a69a1..703bce3e6b85 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -26,9 +26,9 @@ import org.apache.spark.util.Utils /** * :: DeveloperApi :: * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, - * or Tachyon, whether to drop the RDD to disk if it falls out of memory or Tachyon , whether to - * keep the data in memory in a serialized format, and whether to replicate the RDD partitions on - * multiple nodes. + * or ExternalBlockStore, whether to drop the RDD to disk if it falls out of memory or + * ExternalBlockStore, whether to keep the data in memory in a serialized format, and whether + * to replicate the RDD partitions on multiple nodes. * * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants * for commonly useful storage levels. To create your own storage level object, use the @@ -50,11 +50,11 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = _useDisk - def useMemory = _useMemory - def useOffHeap = _useOffHeap - def deserialized = _deserialized - def replication = _replication + def useDisk: Boolean = _useDisk + def useMemory: Boolean = _useMemory + def useOffHeap: Boolean = _useOffHeap + def deserialized: Boolean = _deserialized + def replication: Int = _replication assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") @@ -80,7 +80,7 @@ class StorageLevel private( false } - def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 @@ -126,7 +126,7 @@ class StorageLevel private( var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") - result += (if (useOffHeap) "Tachyon " else "") + result += (if (useOffHeap) "ExternalBlockStore " else "") result += (if (deserialized) "Deserialized " else "Serialized ") result += s"${replication}x Replicated" result @@ -183,7 +183,7 @@ object StorageLevel { useMemory: Boolean, useOffHeap: Boolean, deserialized: Boolean, - replication: Int) = { + replication: Int): StorageLevel = { getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) } @@ -197,7 +197,7 @@ object StorageLevel { useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, - replication: Int = 1) = { + replication: Int = 1): StorageLevel = { getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index def49e80a360..ec711480ebf3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,20 +19,23 @@ package org.apache.spark.storage import scala.collection.mutable -import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ /** * :: DeveloperApi :: * A SparkListener that maintains executor storage status. + * + * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi class StorageStatusListener extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - def storageStatusList = executorIdToStorageStatus.values.toSeq + def storageStatusList: Seq[StorageStatus] = synchronized { + executorIdToStorageStatus.values.toSeq + } /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { @@ -56,7 +59,7 @@ class StorageStatusListener extends SparkListener { } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { @@ -67,7 +70,7 @@ class StorageStatusListener extends SparkListener { } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { updateStorageStatus(unpersistRDD.rddId) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 2bd6b749be26..c4ac30092f80 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -199,33 +199,34 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val oldBlockStatus = getBlock(blockId).getOrElse(BlockStatus.empty) val changeInMem = newBlockStatus.memSize - oldBlockStatus.memSize val changeInDisk = newBlockStatus.diskSize - oldBlockStatus.diskSize - val changeInTachyon = newBlockStatus.tachyonSize - oldBlockStatus.tachyonSize + val changeInExternalBlockStore = + newBlockStatus.externalBlockStoreSize - oldBlockStatus.externalBlockStoreSize val level = newBlockStatus.storageLevel // Compute new info from old info - val (oldMem, oldDisk, oldTachyon) = blockId match { + val (oldMem, oldDisk, oldExternalBlockStore) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, tachyon, _) => (mem, disk, tachyon) } + .map { case (mem, disk, externalBlockStore, _) => (mem, disk, externalBlockStore) } .getOrElse((0L, 0L, 0L)) case _ => _nonRddStorageInfo } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) - val newTachyon = math.max(oldTachyon + changeInTachyon, 0L) + val newExternalBlockStore = math.max(oldExternalBlockStore + changeInExternalBlockStore, 0L) // Set the correct info blockId match { case RDDBlockId(rddId, _) => // If this RDD is no longer persisted, remove it - if (newMem + newDisk + newTachyon == 0) { + if (newMem + newDisk + newExternalBlockStore == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, newTachyon, level) + _rddStorageInfo(rddId) = (newMem, newDisk, newExternalBlockStore, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk, newTachyon) + _nonRddStorageInfo = (newMem, newDisk, newExternalBlockStore) } } @@ -247,13 +248,13 @@ private[spark] object StorageUtils { val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum val memSize = statuses.map(_.memUsedByRdd(rddId)).sum val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - val tachyonSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum + val externalBlockStoreSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum rddInfo.storageLevel = storageLevel rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.diskSize = diskSize - rddInfo.tachyonSize = tachyonSize + rddInfo.externalBlockStoreSize = externalBlockStoreSize } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index af873034215a..22878783fca6 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -17,54 +17,157 @@ package org.apache.spark.storage +import java.io.IOException +import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.client.TachyonFS -import tachyon.client.TachyonFile +import scala.util.control.NonFatal + +import com.google.common.io.ByteStreams + +import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} +import tachyon.conf.TachyonConf +import tachyon.TachyonURI import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By * default, one block is mapped to one file with a name given by its BlockId. * - * @param rootDirs The directories to use for storing block files. Data will be hashed among these. */ -private[spark] class TachyonBlockManager( - blockManager: BlockManager, - rootDirs: String, - val master: String) - extends Logging { - - val client = if (master != null && master != "") TachyonFS.get(master) else null - - if (client == null) { - logError("Failed to connect to the Tachyon as the master address is not configured") - System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_INITIALIZE) - } +private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging { - private val MAX_DIR_CREATION_ATTEMPTS = 10 - private val subDirsPerTachyonDir = - blockManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt + var rootDirs: String = _ + var master: String = _ + var client: tachyon.client.TachyonFS = _ + private var subDirsPerTachyonDir: Int = _ // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; // then, inside this directory, create multiple subdirectories that we will hash files into, // in order to avoid having really large inodes at the top level in Tachyon. - private val tachyonDirs: Array[TachyonFile] = createTachyonDirs() - private val subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) + private var tachyonDirs: Array[TachyonFile] = _ + private var subDirs: Array[Array[tachyon.client.TachyonFile]] = _ + + + override def init(blockManager: BlockManager, executorId: String): Unit = { + super.init(blockManager, executorId) + val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon") + val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME) + + rootDirs = s"$storeDir/$appFolderName/$executorId" + master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") + client = if (master != null && master != "") { + TachyonFS.get(new TachyonURI(master), new TachyonConf()) + } else { + null + } + // original implementation call System.exit, we change it to run without extblkstore support + if (client == null) { + logError("Failed to connect to the Tachyon as the master address is not configured") + throw new IOException("Failed to connect to the Tachyon as the master " + + "address is not configured") + } + subDirsPerTachyonDir = blockManager.conf.get("spark.externalBlockStore.subDirectories", + ExternalBlockStore.SUB_DIRS_PER_DIR).toInt + + // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; + // then, inside this directory, create multiple subdirectories that we will hash files into, + // in order to avoid having really large inodes at the top level in Tachyon. + tachyonDirs = createTachyonDirs() + subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) + } + + override def toString: String = {"ExternalBlockStore-Tachyon"} + + override def removeBlock(blockId: BlockId): Boolean = { + val file = getFile(blockId) + if (fileExists(file)) { + removeFile(file) + } else { + false + } + } + + override def blockExists(blockId: BlockId): Boolean = { + val file = getFile(blockId) + fileExists(file) + } - addShutdownHook() + override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { + val file = getFile(blockId) + val os = file.getOutStream(WriteType.TRY_CACHE) + try { + os.write(bytes.array()) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } + } + + override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val file = getFile(blockId) + val os = file.getOutStream(WriteType.TRY_CACHE) + try { + blockManager.dataSerializeStream(blockId, os, values) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put values of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } + } + + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + val file = getFile(blockId) + if (file == null || file.getLocationHosts.size == 0) { + return None + } + val is = file.getInStream(ReadType.CACHE) + try { + val size = file.length + val bs = new Array[Byte](size.asInstanceOf[Int]) + ByteStreams.readFully(is, bs) + Some(ByteBuffer.wrap(bs)) + } catch { + case NonFatal(e) => + logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) + None + } finally { + is.close() + } + } + + override def getValues(blockId: BlockId): Option[Iterator[_]] = { + val file = getFile(blockId) + if (file == null || file.getLocationHosts().size() == 0) { + return None + } + val is = file.getInStream(ReadType.CACHE) + Option(is).map { is => + blockManager.dataDeserializeStream(blockId, is) + } + } + + override def getSize(blockId: BlockId): Long = { + getFile(blockId.name).length + } def removeFile(file: TachyonFile): Boolean = { - client.delete(file.getPath(), false) + client.delete(new TachyonURI(file.getPath()), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(file.getPath()) + client.exist(new TachyonURI(file.getPath())) } def getFile(filename: String): TachyonFile = { @@ -81,7 +184,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +192,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = subDir + "/" + filename + val filePath = new TachyonURI(s"$subDir/$filename") if(!client.exist(filePath)) { client.createFile(filePath) } @@ -109,47 +212,42 @@ private[spark] class TachyonBlockManager( var tachyonDirId: String = null var tries = 0 val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + while (!foundLocalDir && tries < ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS) { tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) } } catch { - case e: Exception => + case NonFatal(e) => logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) } } if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " + - rootDir) - System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR) + logError("Failed " + ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS + + " attempts to create tachyon dir in " + rootDir) + System.exit(ExecutorExitCode.EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR) } logInfo("Created tachyon directory at " + tachyonDir) tachyonDir } } - private def addShutdownHook() { - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - tachyonDirs.foreach { tachyonDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) - } - } catch { - case e: Exception => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) - } + override def shutdown() { + logDebug("Shutdown hook called") + tachyonDirs.foreach { tachyonDir => + try { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { + Utils.deleteRecursively(tachyonDir, client) } - client.close() + } catch { + case NonFatal(e) => + logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } - }) + } + client.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala deleted file mode 100644 index b86abbda1d3e..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import tachyon.client.TachyonFile - -/** - * References a particular segment of a file (potentially the entire file), based off an offset and - * a length. - */ -private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) -} diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala deleted file mode 100644 index 233d1e2b7c61..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.IOException -import java.nio.ByteBuffer - -import com.google.common.io.ByteStreams -import tachyon.client.{ReadType, WriteType} - -import org.apache.spark.Logging -import org.apache.spark.util.Utils - -/** - * Stores BlockManager blocks on Tachyon. - */ -private[spark] class TachyonStore( - blockManager: BlockManager, - tachyonManager: TachyonBlockManager) - extends BlockStore(blockManager: BlockManager) with Logging { - - logInfo("TachyonStore started") - - override def getSize(blockId: BlockId): Long = { - tachyonManager.getFile(blockId.name).length - } - - override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { - putIntoTachyonStore(blockId, bytes, returnValues = true) - } - - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIterator(blockId, values.toIterator, level, returnValues) - } - - override def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - logDebug(s"Attempting to write values for block $blockId") - val bytes = blockManager.dataSerialize(blockId, values) - putIntoTachyonStore(blockId, bytes, returnValues) - } - - private def putIntoTachyonStore( - blockId: BlockId, - bytes: ByteBuffer, - returnValues: Boolean): PutResult = { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val byteBuffer = bytes.duplicate() - byteBuffer.rewind() - logDebug(s"Attempting to put block $blockId into Tachyon") - val startTime = System.currentTimeMillis - val file = tachyonManager.getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) - os.write(byteBuffer.array()) - os.close() - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file in Tachyon in %d ms".format( - blockId, Utils.bytesToString(byteBuffer.limit), finishTime - startTime)) - - if (returnValues) { - PutResult(bytes.limit(), Right(bytes.duplicate())) - } else { - PutResult(bytes.limit(), null) - } - } - - override def remove(blockId: BlockId): Boolean = { - val file = tachyonManager.getFile(blockId) - if (tachyonManager.fileExists(file)) { - tachyonManager.removeFile(file) - } else { - false - } - } - - override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) - } - - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val file = tachyonManager.getFile(blockId) - if (file == null || file.getLocationHosts.size == 0) { - return None - } - val is = file.getInStream(ReadType.CACHE) - assert (is != null) - try { - val size = file.length - val bs = new Array[Byte](size.asInstanceOf[Int]) - ByteStreams.readFully(is, bs) - Some(ByteBuffer.wrap(bs)) - } catch { - case ioe: IOException => - logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) - None - } finally { - is.close() - } - } - - override def contains(blockId: BlockId): Boolean = { - val file = tachyonManager.getFile(blockId) - tachyonManager.fileExists(file) - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 27ba9e18237b..77c0bc8b5360 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -28,7 +28,6 @@ import org.apache.spark._ * of them will be combined together, showed in one line. */ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { - // Carrige return val CR = '\r' // Update period of progress bar, in milliseconds @@ -66,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -82,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => @@ -121,4 +120,10 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { clear() lastFinishTime = System.currentTimeMillis() } + + /** + * Tear down the timer thread. The timer thread is a GC root, and it retains the entire + * SparkContext if it's not terminated. + */ + def stop(): Unit = timer.cancel() } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 88fed833f922..779c0ba08359 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -62,19 +62,33 @@ private[spark] object JettyUtils extends Logging { securityMgr: SecurityManager): HttpServlet = { new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { - response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) - response.setStatus(HttpServletResponse.SC_OK) - val result = servletParams.responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter.println(servletParams.extractFn(result)) - } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, - "User is not authorized to access this page.") + try { + if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + // scalastyle:off println + response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page.") + } + } catch { + case e: IllegalArgumentException => + response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage) + case e: Exception => + logWarning(s"GET ${request.getRequestURI} failed: $e", e) + throw e } } + // SPARK-5983 ensure TRACE is not supported + protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } } } @@ -92,7 +106,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -105,15 +123,34 @@ private[spark] object JettyUtils extends Logging { srcPath: String, destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), - basePath: String = ""): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + basePath: String = "", + httpMethods: Set[String] = Set("GET")): ServletContextHandler = { + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { - override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { + if (httpMethods.contains("GET")) { + doRequest(request, response) + } else { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { + if (httpMethods.contains("POST")) { + doRequest(request, response) + } else { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString response.sendRedirect(newUrl) } + // SPARK-5983 ensure TRACE is not supported + protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } } createServletHandler(srcPath, servlet, basePath) } @@ -179,16 +216,25 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) val pool = new QueuedThreadPool pool.setDaemon(true) server.setThreadPool(pool) + val errorHandler = new ErrorHandler() + errorHandler.setShowStacks(true) + server.addBean(errorHandler) server.setHandler(collection) try { server.start() @@ -204,11 +250,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala new file mode 100644 index 000000000000..6e2375477a68 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.{Node, Unparsed} + +/** + * A data source that provides data for a page. + * + * @param pageSize the number of rows in a page + */ +private[ui] abstract class PagedDataSource[T](val pageSize: Int) { + + if (pageSize <= 0) { + throw new IllegalArgumentException("Page size must be positive") + } + + /** + * Return the size of all data. + */ + protected def dataSize: Int + + /** + * Slice a range of data. + */ + protected def sliceData(from: Int, to: Int): Seq[T] + + /** + * Slice the data for this page + */ + def pageData(page: Int): PageData[T] = { + val totalPages = (dataSize + pageSize - 1) / pageSize + if (page <= 0 || page > totalPages) { + throw new IndexOutOfBoundsException( + s"Page $page is out of range. Please select a page number between 1 and $totalPages.") + } + val from = (page - 1) * pageSize + val to = dataSize.min(page * pageSize) + PageData(totalPages, sliceData(from, to)) + } + +} + +/** + * The data returned by `PagedDataSource.pageData`, including the page number, the number of total + * pages and the data in this page. + */ +private[ui] case class PageData[T](totalPage: Int, data: Seq[T]) + +/** + * A paged table that will generate a HTML table for a specified page and also the page navigation. + */ +private[ui] trait PagedTable[T] { + + def tableId: String + + def tableCssClass: String + + def dataSource: PagedDataSource[T] + + def headers: Seq[Node] + + def row(t: T): Seq[Node] + + def table(page: Int): Seq[Node] = { + val _dataSource = dataSource + try { + val PageData(totalPages, data) = _dataSource.pageData(page) +
      + {pageNavigation(page, _dataSource.pageSize, totalPages)} + + {headers} + + {data.map(row)} + +
      +
      + } catch { + case e: IndexOutOfBoundsException => + val PageData(totalPages, _) = _dataSource.pageData(1) +
      + {pageNavigation(1, _dataSource.pageSize, totalPages)} +
      {e.getMessage}
      +
      + } + } + + /** + * Return a page navigation. + *
        + *
      • If the totalPages is 1, the page navigation will be empty
      • + *
      • + * If the totalPages is more than 1, it will create a page navigation including a group of + * page numbers and a form to submit the page number. + *
      • + *
      + * + * Here are some examples of the page navigation: + * {{{ + * << < 11 12 13* 14 15 16 17 18 19 20 > >> + * + * This is the first group, so "<<" is hidden. + * < 1 2* 3 4 5 6 7 8 9 10 > >> + * + * This is the first group and the first page, so "<<" and "<" are hidden. + * 1* 2 3 4 5 6 7 8 9 10 > >> + * + * Assume totalPages is 19. This is the last group, so ">>" is hidden. + * << < 11 12 13* 14 15 16 17 18 19 > + * + * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden. + * << < 11 12 13 14 15 16 17 18 19* + * + * * means the current page number + * << means jumping to the first page of the previous group. + * < means jumping to the previous page. + * >> means jumping to the first page of the next group. + * > means jumping to the next page. + * }}} + */ + private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = { + if (totalPages == 1) { + Nil + } else { + // A group includes all page numbers will be shown in the page navigation. + // The size of group is 10 means there are 10 page numbers will be shown. + // The first group is 1 to 10, the second is 2 to 20, and so on + val groupSize = 10 + val firstGroup = 0 + val lastGroup = (totalPages - 1) / groupSize + val currentGroup = (page - 1) / groupSize + val startPage = currentGroup * groupSize + 1 + val endPage = totalPages.min(startPage + groupSize - 1) + val pageTags = (startPage to endPage).map { p => + if (p == page) { + // The current page should be disabled so that it cannot be clicked. +
    • {p}
    • + } else { +
    • {p}
    • + } + } + val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction + // When clicking the "Go" button, it will call this javascript method and then call + // "goButtonJsFuncName" + val formJs = + s"""$$(function(){ + | $$( "#form-$tableId-page" ).submit(function(event) { + | var page = $$("#form-$tableId-page-no").val() + | var pageSize = $$("#form-$tableId-page-size").val() + | pageSize = pageSize ? pageSize: 100; + | if (page != "") { + | ${goButtonJsFuncName}(page, pageSize); + | } + | event.preventDefault(); + | }); + |}); + """.stripMargin + +
      +
      +
      + + + + + + +
      +
      + + +
      + } + } + + /** + * Return a link to jump to a page. + */ + def pageLink(page: Int): String + + /** + * Only the implementation knows how to create the url with a page number and the page size, so we + * leave this one to the implementation. The implementation should create a JavaScript method that + * accepts a page number along with the page size and jumps to the page. The return value is this + * method name and its JavaScript codes. + */ + def goButtonJavascriptFunction: (String, String) +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0c24ad2760e0..d8b90568b7b9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,6 +17,10 @@ package org.apache.spark.ui +import java.util.Date + +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, + UIRoot} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -25,6 +29,7 @@ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener /** * Top level user interface for a Spark application. @@ -32,35 +37,43 @@ import org.apache.spark.ui.storage.{StorageListener, StorageTab} private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, - val securityManager: SecurityManager, + securityManager: SecurityManager, val environmentListener: EnvironmentListener, val storageStatusListener: StorageStatusListener, val executorsListener: ExecutorsListener, val jobProgressListener: JobProgressListener, val storageListener: StorageListener, + val operationGraphListener: RDDOperationGraphListener, var appName: String, - val basePath: String) + val basePath: String, + val startTime: Long) extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") - with Logging { + with Logging + with UIRoot { val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + + val stagesTab = new StagesTab(this) + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) - val stagesTab = new StagesTab(this) attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) + attachHandler(ApiRootResource.getServletHandler(this)) + // This should be POST only, but, the YARN AM proxy won't proxy POSTs + attachHandler(createRedirectHandler( + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, + httpMethods = Set("GET", "POST"))) } initialize() - def getAppName = appName + def getAppName: String = appName /** Set the app name for this UI. */ def setAppName(name: String) { @@ -79,6 +92,24 @@ private[spark] class SparkUI private ( private[spark] def appUIHostPort = publicHostName + ":" + boundPort private[spark] def appUIAddress = s"http://$appUIHostPort" + + def getSparkUI(appId: String): Option[SparkUI] = { + if (appId == appName) Some(this) else None + } + + def getApplicationInfoList: Iterator[ApplicationInfo] = { + Iterator(new ApplicationInfo( + id = appName, + name = appName, + attempts = Seq(new ApplicationAttemptInfo( + attemptId = None, + startTime = new Date(startTime), + endTime = new Date(-1), + sparkUser = "", + completed = false + )) + )) + } } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) @@ -91,6 +122,9 @@ private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) private[spark] object SparkUI { val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val DEFAULT_POOL_NAME = "default" + val DEFAULT_RETAINED_STAGES = 1000 + val DEFAULT_RETAINED_JOBS = 1000 def getUIPort(conf: SparkConf): Int = { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) @@ -102,9 +136,10 @@ private[spark] object SparkUI { listenerBus: SparkListenerBus, jobProgressListener: JobProgressListener, securityManager: SecurityManager, - appName: String): SparkUI = { + appName: String, + startTime: Long): SparkUI = { create(Some(sc), conf, listenerBus, securityManager, appName, - jobProgressListener = Some(jobProgressListener)) + jobProgressListener = Some(jobProgressListener), startTime = startTime) } def createHistoryUI( @@ -112,8 +147,9 @@ private[spark] object SparkUI { listenerBus: SparkListenerBus, securityManager: SecurityManager, appName: String, - basePath: String): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath) + basePath: String, + startTime: Long): SparkUI = { + create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) } /** @@ -130,7 +166,8 @@ private[spark] object SparkUI { securityManager: SecurityManager, appName: String, basePath: String = "", - jobProgressListener: Option[JobProgressListener] = None): SparkUI = { + jobProgressListener: Option[JobProgressListener] = None, + startTime: Long): SparkUI = { val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { val listener = new JobProgressListener(conf) @@ -142,13 +179,16 @@ private[spark] object SparkUI { val storageStatusListener = new StorageStatusListener val executorsListener = new ExecutorsListener(storageStatusListener) val storageListener = new StorageListener(storageStatusListener) + val operationGraphListener = new RDDOperationGraphListener(conf) listenerBus.addListener(environmentListener) listenerBus.addListener(storageStatusListener) listenerBus.addListener(executorsListener) listenerBus.addListener(storageListener) + listenerBus.addListener(operationGraphListener) new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, _jobProgressListener, storageListener, appName, basePath) + executorsListener, _jobProgressListener, storageListener, operationGraphListener, + appName, basePath, startTime) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 4307029d44fb..cb122eaed83d 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,20 +24,31 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + val TASK_DESERIALIZATION_TIME = + """Time spent deserializing the task closure on the executor, including the time to read the + broadcasted task.""" val SHUFFLE_READ_BLOCKED_TIME = "Time that the task spent blocked waiting for shuffle data to be read from remote machines." - val INPUT = "Bytes read from Hadoop or from Spark storage." + val INPUT = "Bytes and records read from Hadoop or from Spark storage." - val OUTPUT = "Bytes written to Hadoop." + val OUTPUT = "Bytes and records written to Hadoop." - val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + + val SHUFFLE_WRITE = + "Bytes and records written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = - """Bytes read from remote executors. Typically less than shuffle write bytes - because this does not include shuffle data read locally.""" + """Total shuffle bytes and records read (includes both data read locally and data read from + remote executors). """ + + val SHUFFLE_READ_REMOTE_SIZE = + """Total shuffle bytes read from remote executors. This is a subset of the shuffle + read bytes; the remaining shuffle data is read locally. """ val GETTING_RESULT_TIME = """Time that the driver spends fetching task results from workers. If this is large, consider @@ -50,4 +61,30 @@ private[spark] object ToolTips { val GC_TIME = """Time that the executor spent paused for Java garbage collection while the task was running.""" + + val PEAK_EXECUTION_MEMORY = + """Execution memory refers to the memory used by internal data structures created during + shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator + should be approximately the sum of the peak sizes across all such data structures created + in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and + external sort.""" + + val JOB_TIMELINE = + """Shows when jobs started and ended and when executors joined or left. Drag to scroll. + Click Enable Zooming and use mouse wheel to zoom in/out.""" + + val STAGE_TIMELINE = + """Shows when stages started and ended and when executors joined or left. Drag to scroll. + Click Enable Zooming and use mouse wheel to zoom in/out.""" + + val JOB_DAG = + """Shows a graph of stages executed for this job, each of which can contain + multiple RDD operations (e.g. map() and filter()), and of RDDs inside each operation + (shown as dots).""" + + val STAGE_DAG = + """Shows a graph of RDD operations in this stage, and RDDs inside each one. A stage can run + multiple operations (e.g. two map() functions) if they can be pipelined. Some operations + also create multiple RDDs internally. Cached RDDs are shown in green. + """ } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index b5022fe853c4..f2da41772410 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,13 +20,14 @@ package org.apache.spark.ui import java.text.SimpleDateFormat import java.util.{Locale, Date} -import scala.xml.{Node, Text} +import scala.xml.{Node, Text, Unparsed} import org.apache.spark.Logging +import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. @@ -149,20 +150,32 @@ private[spark] object UIUtils extends Logging { } } - def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource + def prependBaseUri(basePath: String = "", resource: String = ""): String = { + uiRoot + basePath + resource + } - def commonHeaderNodes = { + def commonHeaderNodes: Seq[Node] = { - - + + + + + + + } + + def vizHeaderNodes: Seq[Node] = { + + + + + } /** Returns a spark page with correctly formatted headers */ @@ -171,7 +184,8 @@ private[spark] object UIUtils extends Logging { content: => Seq[Node], activeTab: SparkUITab, refreshInterval: Option[Int] = None, - helpText: Option[String] = None): Seq[Node] = { + helpText: Option[String] = None, + showVisualization: Boolean = false): Seq[Node] = { val appName = activeTab.appName val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." @@ -180,15 +194,12 @@ private[spark] object UIUtils extends Logging { {tab.name} } - val helpButton: Seq[Node] = helpText.map { helpText => - - (?) - - }.getOrElse(Seq.empty) + val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) {commonHeaderNodes} + {if (showVisualization) vizHeaderNodes else Seq.empty} {appName} - {title} @@ -235,7 +246,7 @@ private[spark] object UIUtils extends Logging {

      - {org.apache.spark.SPARK_VERSION} {title} @@ -256,9 +267,17 @@ private[spark] object UIUtils extends Logging { fixedWidth: Boolean = false, id: Option[String] = None, headerClasses: Seq[String] = Seq.empty, - stripeRowsWithCss: Boolean = true): Seq[Node] = { + stripeRowsWithCss: Boolean = true, + sortable: Boolean = true): Seq[Node] = { - val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + val listingTableClass = { + val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + if (sortable) { + _tableClass + " sortable" + } else { + _tableClass + } + } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" @@ -298,7 +317,7 @@ private[spark] object UIUtils extends Logging { started: Int, completed: Int, failed: Int, - skipped:Int, + skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100) @@ -313,4 +332,67 @@ private[spark] object UIUtils extends Logging {

    } + + /** Return a "DAG visualization" DOM element that expands into a visualization for a stage. */ + def showDagVizForStage(stageId: Int, graph: Option[RDDOperationGraph]): Seq[Node] = { + showDagViz(graph.toSeq, forJob = false) + } + + /** Return a "DAG visualization" DOM element that expands into a visualization for a job. */ + def showDagVizForJob(jobId: Int, graphs: Seq[RDDOperationGraph]): Seq[Node] = { + showDagViz(graphs, forJob = true) + } + + /** + * Return a "DAG visualization" DOM element that expands into a visualization on the UI. + * + * This populates metadata necessary for generating the visualization on the front-end in + * a format that is expected by spark-dag-viz.js. Any changes in the format here must be + * reflected there. + */ + private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = { +
    + + + + DAG Visualization + + +
    + +
    + } + + def tooltip(text: String, position: String): Seq[Node] = { + + (?) + + } + + /** Return a script element that automatically expands the DAG visualization on page load. */ + def expandDagVizOnLoad(forJob: Boolean): Seq[Node] = { + + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index fc1844600f1c..5a8c2914314c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui +import java.util.concurrent.Semaphore + import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} @@ -36,9 +38,11 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 3) { + // scalastyle:off println println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") + // scalastyle:on println System.exit(1) } @@ -51,15 +55,15 @@ private[spark] object UIWorkloadGenerator { val nJobSet = args(2).toInt val sc = new SparkContext(conf) - def setProperties(s: String) = { - if(schedulingMode == SchedulingMode.FAIR) { + def setProperties(s: String): Unit = { + if (schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = new Random().nextFloat() + def nextFloat(): Float = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), @@ -88,10 +92,13 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) + val barrier = new Semaphore(-nJobSet * jobs.size + 1) + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { override def run() { + // scalastyle:off println try { setProperties(desc) job() @@ -99,12 +106,18 @@ private[spark] object UIWorkloadGenerator { } catch { case e: Exception => println("Job Failed: " + desc) + } finally { + barrier.release() } + // scalastyle:on println } }.start Thread.sleep(INTER_JOB_WAIT_MS) } } + + // Waiting for threads. + barrier.acquire() sc.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 9be65a4a39a0..61449847add3 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -20,14 +20,15 @@ package org.apache.spark.ui import javax.servlet.http.HttpServletRequest import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} -import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * The top level component of the UI hierarchy that contains the server. @@ -36,7 +37,7 @@ import org.apache.spark.util.Utils * pages. The use of tabs is optional, however; a WebUI may choose to include pages directly. */ private[spark] abstract class WebUI( - securityManager: SecurityManager, + val securityManager: SecurityManager, port: Int, conf: SparkConf, basePath: String = "", @@ -45,9 +46,10 @@ private[spark] abstract class WebUI( protected val tabs = ArrayBuffer[WebUITab]() protected val handlers = ArrayBuffer[ServletContextHandler]() + protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None - protected val localHostName = Utils.localHostName() - protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) + protected val localHostName = Utils.localHostNameForURI() + protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath @@ -61,13 +63,26 @@ private[spark] abstract class WebUI( tabs += tab } + def detachTab(tab: WebUITab) { + tab.pages.foreach(detachPage) + tabs -= tab + } + + def detachPage(page: WebUIPage) { + pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) + } + /** Attach a page to this UI. */ def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix - attachHandler(createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath)) - attachHandler(createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath)) + val renderHandler = createServletHandler(pagePath, + (request: HttpServletRequest) => page.render(request), securityManager, basePath) + val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", + (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + attachHandler(renderHandler) + attachHandler(renderJsonHandler) + pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) + .append(renderHandler) } /** Attach a handler to this UI. */ @@ -82,7 +97,7 @@ private[spark] abstract class WebUI( } /** Detach a handler from this UI. */ - protected def detachHandler(handler: ServletContextHandler) { + def detachHandler(handler: ServletContextHandler) { handlers -= handler serverInfo.foreach { info => info.rootHandler.removeHandler(handler) @@ -92,6 +107,25 @@ private[spark] abstract class WebUI( } } + /** + * Add a handler for static content. + * + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. + */ + def addStaticHandler(resourceBase: String, path: String): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + } + + /** + * Remove a static content handler. + * + * @param path Path in UI to unmount. + */ + def removeStaticHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) + } + /** Initialize all components of the server. */ def initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c82730f524eb..b0a2cb4aa4d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -43,17 +43,35 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } id }.getOrElse { - return Text(s"Missing executorId parameter") + throw new IllegalArgumentException(s"Missing executorId parameter") } val time = System.currentTimeMillis() val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.map { thread => + val dumpRows = threadDump.sortWith { + case (threadTrace1, threadTrace2) => { + val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + } else { + v1 > v2 + } + } + }.map { thread => + val threadName = thread.threadName + val className = "accordion-heading " + { + if (threadName.contains("Executor task launch")) { + "executor-thread" + } else { + "non-executor-thread" + } + }
    -
    + + +As you add new text files with data the cluster centers will update. Each training +point should be formatted as `[x1, x2, x3]`, and each test data point +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` +you will see predictions. With new data, the cluster centers will change! diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index ef18cec9371d..eedc23424ad5 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,6 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. {% highlight scala %} import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel import org.apache.spark.mllib.recommendation.Rating // Load and parse the data @@ -76,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -95,6 +96,10 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => err * err }.mean() println("Mean Squared Error = " + MSE) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") {% endhighlight %} If the rating matrix is derived from another source of information (e.g., it is inferred from @@ -102,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 -val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) +val lambda = 0.01 +val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha) {% endhighlight %}
    @@ -143,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -181,6 +187,10 @@ public class CollaborativeFiltering { } ).rdd()).mean(); System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -192,7 +202,7 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. {% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") @@ -200,15 +210,19 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = MatrixFactorizationModel.load(sc, "myModelPath") {% endhighlight %} If the rating matrix is derived from other source of information (i.e., it is inferred from other diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 101dc2f8695f..f0e8d5495675 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -31,7 +31,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseVector) and [`SparseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseVector). We recommend using the factory methods implemented in -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) to create local vectors. +[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -57,7 +57,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/java/org/apache/spark/mllib/linalg/DenseVector.html) and [`SparseVector`](api/java/org/apache/spark/mllib/linalg/SparseVector.html). We recommend using the factory methods implemented in -[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vector.html) to create local vectors. +[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vectors.html) to create local vectors. {% highlight java %} import org.apache.spark.mllib.linalg.Vector; @@ -78,13 +78,13 @@ MLlib recognizes the following types as dense vectors: and the following as sparse vectors: -* MLlib's [`SparseVector`](api/python/pyspark.mllib.linalg.SparseVector-class.html). +* MLlib's [`SparseVector`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseVector). * SciPy's [`csc_matrix`](http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix) with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.linalg.Vectors-class.html) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) to create sparse vectors. {% highlight python %} import numpy as np @@ -151,7 +151,7 @@ LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new
    A labeled point is represented by -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html). +[`LabeledPoint`](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint). {% highlight python %} from pyspark.mllib.linalg import SparseVector @@ -211,7 +211,7 @@ JavaRDD examples =
    -[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.util.MLUtils-class.html) reads training +[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) reads training examples stored in LIBSVM format. {% highlight python %} @@ -226,7 +226,8 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") A local matrix has integer-typed row and column indices and double-typed values, stored on a single machine. MLlib supports dense matrices, whose entry values are stored in a single double array in -column major. For example, the following matrix `\[ \begin{pmatrix} +column-major order, and sparse matrices, whose non-zero entry values are stored in the Compressed Sparse +Column (CSC) format in column-major order. For example, the following dense matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ 5.0 & 6.0 @@ -238,28 +239,33 @@ is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the m
    The base class of local matrices is -[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one -implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). +[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseMatrix). We recommend using the factory methods implemented -in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local -matrices. +in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +val sm: Matrix = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9, 6, 8)) {% endhighlight %}
    The base class of local matrices is -[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one -implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). +[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide two +implementations: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html), +and [`SparseMatrix`](api/java/org/apache/spark/mllib/linalg/SparseMatrix.html). We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight java %} import org.apache.spark.mllib.linalg.Matrix; @@ -267,6 +273,30 @@ import org.apache.spark.mllib.linalg.Matrices; // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +Matrix sm = Matrices.sparse(3, 2, new int[] {0, 1, 3}, new int[] {0, 2, 1}, new double[] {9, 6, 8}); +{% endhighlight %} +
    + +
    + +The base class of local matrices is +[`Matrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseMatrix). +We recommend using the factory methods implemented +in [`Matrices`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) to create local +matrices. Remember, local matrices in MLlib are stored in column-major order. + +{% highlight python %} +import org.apache.spark.mllib.linalg.{Matrix, Matrices} + +// Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) +dm2 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) {% endhighlight %}
    @@ -342,12 +372,37 @@ long m = mat.numRows(); long n = mat.numCols(); {% endhighlight %}
    + +
    + +A [`RowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) can be +created from an `RDD` of vectors. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import RowMatrix + +# Create an RDD of vectors. +rows = sc.parallelize([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) + +# Create a RowMatrix from an RDD of vectors. +mat = RowMatrix(rows) + +# Get its size. +m = mat.numRows() # 4 +n = mat.numCols() # 3 + +# Get the rows as an RDD of vectors again. +rowsRDD = mat.rows +{% endhighlight %} +
    +
    ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local +vector.
    @@ -401,7 +456,51 @@ long n = mat.numCols(); // Drop its row indices. RowMatrix rowMat = mat.toRowMatrix(); {% endhighlight %} -
    +
    + +
    + +An [`IndexedRowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRowMatrix) +can be created from an `RDD` of `IndexedRow`s, where +[`IndexedRow`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRow) is a +wrapper over `(long, vector)`. An `IndexedRowMatrix` can be converted to a `RowMatrix` by dropping +its row indices. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix + +# Create an RDD of indexed rows. +# - This can be done explicitly with the IndexedRow class: +indexedRows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + IndexedRow(1, [4, 5, 6]), + IndexedRow(2, [7, 8, 9]), + IndexedRow(3, [10, 11, 12])]) +# - or by using (long, vector) tuples: +indexedRows = sc.parallelize([(0, [1, 2, 3]), (1, [4, 5, 6]), + (2, [7, 8, 9]), (3, [10, 11, 12])]) + +# Create an IndexedRowMatrix from an RDD of IndexedRows. +mat = IndexedRowMatrix(indexedRows) + +# Get its size. +m = mat.numRows() # 4 +n = mat.numCols() # 3 + +# Get the rows as an RDD of IndexedRows. +rowsRDD = mat.rows + +# Convert to a RowMatrix by dropping the row indices. +rowMat = mat.toRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() +{% endhighlight %} +
    + +
    ### CoordinateMatrix @@ -465,4 +564,142 @@ long n = mat.numCols(); IndexedRowMatrix indexedRowMatrix = mat.toIndexedRowMatrix(); {% endhighlight %} + +
    + +A [`CoordinateMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.CoordinateMatrix) +can be created from an `RDD` of `MatrixEntry` entries, where +[`MatrixEntry`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.MatrixEntry) is a +wrapper over `(long, long, float)`. A `CoordinateMatrix` can be converted to a `RowMatrix` by +calling `toRowMatrix`, or to an `IndexedRowMatrix` with sparse rows by calling `toIndexedRowMatrix`. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry + +# Create an RDD of coordinate entries. +# - This can be done explicitly with the MatrixEntry class: +entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]) +# - or using (long, long, float) tuples: +entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]) + +# Create an CoordinateMatrix from an RDD of MatrixEntries. +mat = CoordinateMatrix(entries) + +# Get its size. +m = mat.numRows() # 3 +n = mat.numCols() # 2 + +# Get the entries as an RDD of MatrixEntries. +entriesRDD = mat.entries + +# Convert to a RowMatrix. +rowMat = mat.toRowMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() +{% endhighlight %} +
    + + + +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
    +
    + +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix) +can be created from an `RDD` of sub-matrix blocks, where a sub-matrix block is a +`((blockRowIndex, blockColIndex), sub-matrix)` tuple. + +{% highlight python %} +from pyspark.mllib.linalg import Matrices +from pyspark.mllib.linalg.distributed import BlockMatrix + +# Create an RDD of sub-matrix blocks. +blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + +# Create a BlockMatrix from an RDD of sub-matrix blocks. +mat = BlockMatrix(blocks, 3, 2) + +# Get its size. +m = mat.numRows() # 6 +n = mat.numCols() # 2 + +# Get the blocks as an RDD of sub-matrix blocks. +blocksRDD = mat.blocks + +# Convert to a LocalMatrix. +localMat = mat.toLocalMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() +{% endhighlight %} +
    diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index fc8e732251a3..c1d0f8a6b1cd 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Tree - MLlib -displayTitle: MLlib - Decision Tree +title: Decision Trees - MLlib +displayTitle: MLlib - Decision Trees --- * Table of contents @@ -54,8 +54,8 @@ impurity measure for regression (variance). Variance Regression - $\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$y_i$ is label for an instance, - $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N x_i$. + $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$. @@ -194,6 +194,7 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    {% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -221,6 +222,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification tree model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -279,13 +284,18 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification tree model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -305,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes print('Test Error = ' + str(testErr)) print('Learned classification tree model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -324,6 +338,7 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
    {% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -350,6 +365,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression tree model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -414,13 +433,18 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -440,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression tree model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 870fed6cc502..05f51168d837 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -137,7 +137,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format. +MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
    @@ -157,6 +157,23 @@ val pc: Matrix = mat.computePrincipalComponents(10) // Principal components are // Project the rows to the linear space spanned by the top 10 principal components. val projected: RowMatrix = mat.multiply(pc) {% endhighlight %} + +The following code demonstrates how to compute principal components on source vectors +and use them to project the vectors into a low-dimensional space while keeping associated labels: + +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.feature.PCA + +val data: RDD[LabeledPoint] = ... + +// Compute the top 10 principal components. +val pca = new PCA(10).fit(data.map(_.features)) + +// Project vectors to the linear space spanned by the top 10 principal components, keeping the label +val projected = data.map(p => p.copy(features = pca.transform(p.features))) +{% endhighlight %} +
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5..1e00b2083ed7 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBosotedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -98,6 +98,7 @@ The test error is calculated to measure the algorithm accuracy.
    {% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -127,6 +128,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification forest model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -188,12 +193,17 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification forest model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} -from pyspark.mllib.tree import RandomForest +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -216,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes print('Test Error = ' + str(testErr)) print('Learned classification forest model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -235,6 +249,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    {% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -264,6 +279,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression forest model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -328,12 +347,17 @@ Double testMSE = }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression forest model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} -from pyspark.mllib.tree import RandomForest +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -356,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -427,10 +455,19 @@ We omit some decision tree parameters since those are covered in the [decision t * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. +#### Validation while training -### Examples +Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while +training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDD's as arguments, the +first one being the training dataset and the second being the validation dataset. -GBTs currently have APIs in Scala and Java. Examples in both languages are shown below. +The training is stopped when the improvement in the validation error is not more than a certain tolerance +(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error +decreases initially and later increases. There might be cases in which the validation error does not change monotonically, +and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration` +(which gives the error or loss per iteration) to tune the number of iterations. + +### Examples #### Classification @@ -446,6 +483,7 @@ The test error is calculated to measure the algorithm accuracy. {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -458,7 +496,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // The defaultParams for Classification use LogLoss by default. val boostingStrategy = BoostingStrategy.defaultParams("Classification") boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClassesForClassification = 2 +boostingStrategy.treeStrategy.numClasses = 2 boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() @@ -473,6 +511,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification GBT model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %} @@ -534,6 +576,41 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} + + +
    + +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) +print('Test Error = ' + str(testErr)) +print('Learned classification GBT model:') +print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -554,6 +631,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -580,6 +658,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression GBT model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %} @@ -647,6 +729,41 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} + + +
    + +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) +print('Test Mean Squared Error = ' + str(testMSE)) +print('Learned regression GBT model:') +print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md new file mode 100644 index 000000000000..7066d5c97418 --- /dev/null +++ b/docs/mllib-evaluation-metrics.md @@ -0,0 +1,1497 @@ +--- +layout: global +title: Evaluation Metrics - MLlib +displayTitle: MLlib - Evaluation Metrics +--- + +* Table of contents +{:toc} + +Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance +of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +suite of metrics for the purpose of evaluating the performance of machine learning models. + +Specific machine learning algorithms fall under broader types of machine learning applications like classification, +regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +metrics that are currently available in Spark's MLlib are detailed in this section. + +## Classification model evaluation + +While there are many different types of classification algorithms, the evaluation of classification models all share +similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification), +there exists a true output and a model-generated predicted output for each data point. For this reason, the results for +each data point can be assigned to one of four categories: + +* True Positive (TP) - label is positive and prediction is also positive +* True Negative (TN) - label is negative and prediction is also negative +* False Positive (FP) - label is negative but prediction is positive +* False Negative (FN) - label is positive but prediction is negative + +These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering +classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric. The +reason for this is because a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from +a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier +that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like +[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into +account the *type* of error. In most applications there is some desired balance between precision and recall, which can +be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score). + +### Binary classification + +[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given +dataset into one of two possible groups (e.g. fraud or not fraud) and is a special case of multiclass classification. +Most binary classification metrics can be generalized to multiclass classification metrics. + +#### Threshold tuning + +It is import to understand that many classification models actually output a "score" (often times a probability) for +each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for +each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where +the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a +credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold* +which determines what the predicted class will be based on the probabilities that the model outputs. + +Tuning the prediction threshold will change the precision and recall of the model and is an important part of model +optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold it is +common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision, +recall) points for different threshold values, while a +[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve +plots (recall, false positive rate) points. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    MetricDefinition
    Precision (Postive Predictive Value)$PPV=\frac{TP}{TP + FP}$
    Recall (True Positive Rate)$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$
    F-measure$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR} + {\beta^2 \cdot PPV + TPR}\right)$
    Receiver Operating Characteristic (ROC)$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$
    Area Under ROC Curve$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$
    Area Under Precision-Recall Curve$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$
    + + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training) + +// Clear the prediction threshold so the model will return probabilities +model.clearThreshold + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new BinaryClassificationMetrics(predictionAndLabels) + +// Precision by threshold +val precision = metrics.precisionByThreshold +precision.foreach { case (t, p) => + println(s"Threshold: $t, Precision: $p") +} + +// Recall by threshold +val recall = metrics.precisionByThreshold +recall.foreach { case (t, r) => + println(s"Threshold: $t, Recall: $r") +} + +// Precision-Recall Curve +val PRC = metrics.pr + +// F-measure +val f1Score = metrics.fMeasureByThreshold +f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 1") +} + +val beta = 0.5 +val fScore = metrics.fMeasureByThreshold(beta) +f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 0.5") +} + +// AUPRC +val auPRC = metrics.areaUnderPR +println("Area under precision-recall curve = " + auPRC) + +// Compute thresholds used in ROC and PR curves +val thresholds = precision.map(_._1) + +// ROC Curve +val roc = metrics.roc + +// AUROC +val auROC = metrics.areaUnderROC +println("Area under ROC = " + auROC) + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class BinaryClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call (Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print("Area under PR = %s" % metrics.areaUnderPR) + +# Area under ROC curve +print("Area under ROC = %s" % metrics.areaUnderROC) + +{% endhighlight %} + +
    +
    + + +### Multiclass classification + +A [multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification +problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary +classification problem). For example, classifying handwriting samples to the digits 0 to 9, having 10 possible classes. + +For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still +be positive or negative, but they must be considered under the context of a particular class. Each label and prediction +take on the value of one of the multiple classes and so they are said to be positive for their particular class and negative +for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative +occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be +multiple true negatives for a given data sample. The extension of false negatives and false positives from the former +definitions of positive and negative labels is straightforward. + +#### Label based metrics + +Opposed to binary classification where there are only two possible labels, multiclass classification problems have many +possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all +labels - the number of times any class was predicted correctly (true positives) normalized by the number of data +points. Precision by label considers only one class, and measures the number of time a specific label was predicted +correctly normalized by the number of times that label appears in the output. + +**Available metrics** + +Define the class, or label, set as + +$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$ + +The true output vector $\mathbf{y}$ consists of $N$ elements + +$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$ + +A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements + +$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$ + +For this section, a modified delta function $\hat{\delta}(x)$ will prove useful + +$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    MetricDefinition
    Confusion Matrix + $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\ + \left( \begin{array}{ccc} + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots & + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N) \\ + \vdots & \ddots & \vdots \\ + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots & + \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N) + \end{array} \right)$ +
    Overall Precision$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - + \mathbf{y}_i\right)$
    Overall Recall$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - + \mathbf{y}_i\right)$
    Overall F1-measure$F1 = 2 \cdot \left(\frac{PPV \cdot TPR} + {PPV + TPR}\right)$
    Precision by label$PPV(\ell) = \frac{TP}{TP + FP} = + \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)} + {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$
    Recall by label$TPR(\ell)=\frac{TP}{P} = + \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)} + {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$
    F-measure by label$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)} + {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$
    Weighted precision$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) + \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$
    Weighted recall$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) + \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$
    Weighted F-measure$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) + \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$
    + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MulticlassClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print("Summary Stats") +print("Precision = %s" % precision) +print("Recall = %s" % recall) +print("F1 Score = %s" % f1Score) + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + +# Weighted stats +print("Weighted recall = %s" % metrics.weightedRecall) +print("Weighted precision = %s" % metrics.weightedPrecision) +print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) +print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) +print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) +{% endhighlight %} + +
    +
    + +### Multilabel classification + +A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping +each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not +mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both +science and politics. + +Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather +than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to +operations on sets. For example, a true positive for a given class now occurs when that class exists in the predicted +set and it exists in the true label set, for a specific data point. + +**Available metrics** + +Here we define a set $D$ of $N$ documents + +$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$ + +Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$ +to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that +correspond to document $d_i$. + +The set of all unique labels is given by + +$$L = \bigcup_{k=0}^{N-1} L_k$$ + +The following definition of indicator function $I_A(x)$ on a set $A$ will be necessary + +$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    MetricDefinition
    Precision$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$
    Recall$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$
    Accuracy + $\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|} + {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$ +
    Precision by label$PPV(\ell)=\frac{TP}{TP + FP}= + \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)} + {\sum_{i=0}^{N-1} I_{P_i}(\ell)}$
    Recall by label$TPR(\ell)=\frac{TP}{P}= + \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)} + {\sum_{i=0}^{N-1} I_{L_i}(\ell)}$
    F1-measure by label$F1(\ell) = 2 + \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)} + {PPV(\ell) + TPR(\ell)}\right)$
    Hamming Loss + $\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left|L_i\right| + \left|P_i\right| - 2\left|L_i + \cap P_i\right|$ +
    Subset Accuracy$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$
    F1 Measure$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| \cdot \left|L_i\right|}$
    Micro precision$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|} + {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$
    Micro recall$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|} + {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$
    Micro F1 Measure + $2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot + \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1} + \left|P_i - L_i\right|}$ +
    + +**Examples** + +The following code snippets illustrate how to evaluate the performance of a multilabel classifer. The examples +use the fake prediction and label data for multilabel classification that is shown below. + +Document predictions: + +* doc 0 - predict 0, 1 - class 0, 2 +* doc 1 - predict 0, 2 - class 0, 1 +* doc 2 - predict none - class 0 +* doc 3 - predict 2 - class 2 +* doc 4 - predict 2, 0 - class 2, 0 +* doc 5 - predict 0, 1, 2 - class 0, 1 +* doc 6 - predict 1 - class 1, 2 + +Predicted classes: + +* class 0 - doc 0, 1, 4, 5 (total 4) +* class 1 - doc 0, 5, 6 (total 3) +* class 2 - doc 1, 3, 4, 5 (total 4) + +True classes: + +* class 0 - doc 0, 1, 2, 4, 5 (total 5) +* class 1 - doc 1, 5, 6 (total 3) +* class 2 - doc 0, 3, 4, 6 (total 4) + +
    + +
    + +{% highlight scala %} +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD; + +val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + +// Instantiate metrics object +val metrics = new MultilabelMetrics(scoreAndLabels) + +// Summary stats +println(s"Recall = ${metrics.recall}") +println(s"Precision = ${metrics.precision}") +println(s"F1 measure = ${metrics.f1Measure}") +println(s"Accuracy = ${metrics.accuracy}") + +// Individual label stats +metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) +metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) +metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + +// Micro stats +println(s"Micro recall = ${metrics.microRecall}") +println(s"Micro precision = ${metrics.microPrecision}") +println(s"Micro F1 measure = ${metrics.microF1Measure}") + +// Hamming loss +println(s"Hamming loss = ${metrics.hammingLoss}") + +// Subset accuracy +println(s"Subset accuracy = ${metrics.subsetAccuracy}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.SparkConf; +import java.util.Arrays; +import java.util.List; + +public class MultilabelClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); + JavaSparkContext sc = new JavaSparkContext(conf); + + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.evaluation import MultilabelMetrics + +scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + +# Instantiate metrics object +metrics = MultilabelMetrics(scoreAndLabels) + +# Summary stats +print("Recall = %s" % metrics.recall()) +print("Precision = %s" % metrics.precision()) +print("F1 measure = %s" % metrics.f1Measure()) +print("Accuracy = %s" % metrics.accuracy) + +# Individual label stats +labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() +for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + +# Micro stats +print("Micro precision = %s" % metrics.microPrecision) +print("Micro recall = %s" % metrics.microRecall) +print("Micro F1 measure = %s" % metrics.microF1Measure) + +# Hamming loss +print("Hamming loss = %s" % metrics.hammingLoss) + +# Subset accuracy +print("Subset accuracy = %s" % metrics.subsetAccuracy) + +{% endhighlight %} + +
    +
    + +### Ranking systems + +The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system)) +is to return to the user a set of relevant items or documents based on some training data. The definition of relevance +may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these +rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth +set of relevant documents, while other metrics may incorporate numerical ratings explicitly. + +**Available metrics** + +A ranking system usually deals with a set of $M$ users + +$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$ + +Each user ($u_i$) having a set of $N$ ground truth relevant documents + +$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$ + +And a list of $Q$ recommended documents, in order of decreasing relevance + +$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$ + +The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the +sets and the effectiveness of the algorithms can be measured using the metrics listed below. + +It is necessary to define a function which, provided a recommended document and a set of ground truth relevant +documents, returns a relevance score for the recommended document. + +$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + +
    MetricDefinitionNotes
    + Precision at k + + $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$ + + Precision at k is a measure of + how many of the first k recommended documents are in the set of true relevant documents averaged across all + users. In this metric, the order of the recommendations is not taken into account. +
    Mean Average Precision + $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$ + + MAP is a measure of how + many of the recommended documents are in the set of true relevant documents, where the + order of the recommendations is taken into account (i.e. penalty for highly relevant documents is higher). +
    Normalized Discounted Cumulative Gain + $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} + \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+1)}} \\ + \text{Where} \\ + \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\ + \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+1)}$ + + NDCG at k is a + measure of how many of the first k recommended documents are in the set of true relevant documents averaged + across all users. In contrast to precision at k, this metric takes into account the order of the recommendations + (documents are assumed to be in order of decreasing relevance). +
    + +**Examples** + +The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation +model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the +methodology is provided below. + +MovieLens ratings are on a scale of 1-5: + + * 5: Must see + * 4: Will enjoy + * 3: It's okay + * 2: Fairly bad + * 1: Awful + +So we should not recommend a movie if the predicted rating is less than 3. +To map ratings to confidence scores, we use: + + * 5 -> 2.5 + * 4 -> 1.5 + * 3 -> 0.5 + * 2 -> -0.5 + * 1 -> -1.5. + +This mappings means unobserved entries are generally between It's okay and Fairly bad. The semantics of 0 in this +expanded world of non-positive weights are "the same as never having interacted at all." + +
    + +
    + +{% highlight scala %} +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} + +// Read in the ratings data +val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) +}.cache() + +// Map ratings to 1 or 0, 1 indicating a movie that should be recommended +val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() + +// Summarize ratings +val numRatings = ratings.count() +val numUsers = ratings.map(_.user).distinct().count() +val numMovies = ratings.map(_.product).distinct().count() +println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + +// Build the model +val numIterations = 10 +val rank = 10 +val lambda = 0.01 +val model = ALS.train(ratings, rank, numIterations, lambda) + +// Define a function to scale ratings from 0 to 1 +def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) +} + +// Get sorted top ten predictions for each user and then scale from [0, 1] +val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => + (user, recs.map(scaledRating)) +} + +// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document +// Compare with top ten most relevant documents +val userMovies = binarizedRatings.groupBy(_.user) +val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) +} + +// Instantiate metrics object +val metrics = new RankingMetrics(relevantDocuments) + +// Precision at K +Array(1, 3, 5).foreach{ k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") +} + +// Mean average precision +println(s"Mean average precision = ${metrics.meanAveragePrecision}") + +// Normalized discounted cumulative gain +Array(1, 3, 5).foreach{ k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") +} + +// Get predictions for each data point +val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) +val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) +val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => + (predicted, actual) +} + +// Get the RMSE using regression metrics +val regressionMetrics = new RegressionMetrics(predictionsAndLabels) +println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${regressionMetrics.r2}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.Function; +import java.util.*; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.Rating; + +// Read in the ratings data +public class Ranking { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); + JavaSparkContext sc = new JavaSparkContext(conf); + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } + else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r){ + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r){ + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics + +# Read in the ratings data +lines = sc.textFile("data/mllib/sample_movielens_data.txt") + +def parseLine(line): + fields = line.split("::") + return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) +ratings = lines.map(lambda r: parseLine(r)) + +# Train a model on to predict user-product ratings +model = ALS.train(ratings, 10, 10, 0.01) + +# Get predicted ratings on all existing user-product pairs +testData = ratings.map(lambda p: (p.user, p.product)) +predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) + +ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) +scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) + +# Instantiate regression metrics to compare predicted and actual ratings +metrics = RegressionMetrics(scoreAndLabels) + +# Root mean sqaured error +print("RMSE = %s" % metrics.rootMeanSquaredError) + +# R-squared +print("R-squared = %s" % metrics.r2) + +{% endhighlight %} + +
    +
    + +## Regression model evaluation + +[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output +variable from a number of independent variables. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    MetricDefinition
    Mean Squared Error (MSE)$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$
    Root Mean Squared Error (RMSE)$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$
    Mean Absoloute Error (MAE)$MAE=\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$
    Coefficient of Determination $(R^2)$$R^2=1 - \frac{MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1} + (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$
    Explained Variance$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$
    + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data, +and evaluate the performance of the algorithm by several regression metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils + +// Load the data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + +// Build the model +val numIterations = 100 +val model = LinearRegressionWithSGD.train(data, numIterations) + +// Get predictions +val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) +} + +// Instantiate metrics object +val metrics = new RegressionMetrics(valuesAndPreds) + +// Squared error +println(s"MSE = ${metrics.meanSquaredError}") +println(s"RMSE = ${metrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${metrics.r2}") + +// Mean absolute error +println(s"MAE = ${metrics.meanAbsoluteError}") + +// Explained variance +println(s"Explained variance = ${metrics.explainedVariance}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; + +public class LinearRegression { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector + +# Load and parse the data +def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) + +data = sc.textFile("data/mllib/sample_linear_regression_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LinearRegressionWithSGD.train(parsedData) + +# Get predictions +valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + +# Instantiate metrics object +metrics = RegressionMetrics(valuesAndPreds) + +# Squared Error +print("MSE = %s" % metrics.meanSquaredError) +print("RMSE = %s" % metrics.rootMeanSquaredError) + +# R-squared +print("R-squared = %s" % metrics.r2) + +# Mean absolute error +print("MAE = %s" % metrics.meanAbsoluteError) + +# Explained variance +print("Explained variance = %s" % metrics.explainedVariance) + +{% endhighlight %} + +
    +
    \ No newline at end of file diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index d4a61a7fbf3d..de86aba2ae62 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -217,7 +221,7 @@ model = word2vec.fit(inp) synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) {% endhighlight %}
    @@ -375,3 +379,266 @@ data2 = labels.zip(normalizer2.transform(features)) {% endhighlight %} + +## Feature selection +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. + +### ChiSqSelector +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. + +#### Model Fitting + +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the +following parameters in the constructor: + +* `numTopFeatures` number of top features that the selector will select (filter). + +We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in +`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then +return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. + +This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +an `RDD[Vector]` to produce a reduced `RDD[Vector]`. + +Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). + +#### Example + +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature. + +
    +
    +{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector + +// Load some data in libsvm format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category +val discretizedData = data.map { lp => + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) +} +// Create ChiSqSelector that will select top 50 of 692 features +val selector = new ChiSqSelector(50) +// Create ChiSqSelector model (selecting features) +val transformer = selector.fit(discretizedData) +// Filter the top 50 features from each feature vector +val filteredData = discretizedData.map { lp => + LabeledPoint(lp.label, transformer.transform(lp.features)) +} +{% endhighlight %} +
    + +
    +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ChiSqSelector; +import org.apache.spark.mllib.feature.ChiSqSelectorModel; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); +JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), + "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); + +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category +JavaRDD discretizedData = points.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + final double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); + } + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + } + }); + +// Create ChiSqSelector that will select top 50 of 692 features +ChiSqSelector selector = new ChiSqSelector(50); +// Create ChiSqSelector model (selecting features) +final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); +// Filter the top 50 features from each feature vector +JavaRDD filteredData = discretizedData.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + return new LabeledPoint(lp.label(), transformer.transform(lp.features())); + } + } +); + +sc.stop(); +{% endhighlight %} +
    +
    + +## ElementwiseProduct + +ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. + +`\[ \begin{pmatrix} +v_1 \\ +\vdots \\ +v_N +\end{pmatrix} \circ \begin{pmatrix} + w_1 \\ + \vdots \\ + w_N + \end{pmatrix} += \begin{pmatrix} + v_1 w_1 \\ + \vdots \\ + v_N w_N + \end{pmatrix} +\]` + +[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: + +* `w`: the transforming vector. + +`ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. + +### Example + +This example below demonstrates how to transform vectors using a transforming vector value. + +
    +
    +{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct(transformingVector) + +// Batch transform and per-row transform give the same results: +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD transformedData = transformer.transform(data); +JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import ElementwiseProduct + +# Load and parse the data +sc = SparkContext() +data = sc.textFile("data/mllib/kmeans_data.txt") +parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) + +# Create weight vector. +transformingVector = Vectors.dense([0.0, 1.0, 2.0]) +transformer = ElementwiseProduct(transformingVector) + +# Batch transform +transformedData = transformer.transform(parsedData) +# Single-row transform +transformedData2 = transformer.transform(parsedData.first()) + +{% endhighlight %} +
    +
    + + +## PCA + +A feature transformer that projects vectors to a low-dimensional space using PCA. +Details you can read at [dimensionality reduction](mllib-dimensionality-reduction.html). + +### Example + +The following code demonstrates how to compute principal components on a `Vector` +and use them to project the vectors into a low-dimensional space while keeping associated labels +for calculation a [Linear Regression]((mllib-linear-methods.html)) + +
    +
    +{% highlight scala %} +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.feature.PCA + +val data = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) +}.cache() + +val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0).cache() +val test = splits(1) + +val pca = new PCA(training.first().features.size/2).fit(data.map(_.features)) +val training_pca = training.map(p => p.copy(features = pca.transform(p.features))) +val test_pca = test.map(p => p.copy(features = pca.transform(p.features))) + +val numIterations = 100 +val model = LinearRegressionWithSGD.train(training, numIterations) +val model_pca = LinearRegressionWithSGD.train(training_pca, numIterations) + +val valuesAndPreds = test.map { point => + val score = model.predict(point.features) + (score, point.label) +} + +val valuesAndPreds_pca = test_pca.map { point => + val score = model_pca.predict(point.features) + (score, point.label) +} + +val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() +val MSE_pca = valuesAndPreds_pca.map{case(v, p) => math.pow((v - p), 2)}.mean() + +println("Mean Squared Error = " + MSE) +println("PCA Mean Squared Error = " + MSE_pca) +{% endhighlight %} +
    +
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md new file mode 100644 index 000000000000..4d4f5cfdc564 --- /dev/null +++ b/docs/mllib-frequent-pattern-mining.md @@ -0,0 +1,323 @@ +--- +layout: global +title: Frequent Pattern Mining - MLlib +displayTitle: MLlib - Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. +MLlib provides a parallel implementation of FP-growth, +a popular algorithm to mining frequent itemsets. + +## FP-growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In MLlib, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffices of transactions, +and hence more scalable than a single-machine implementation. +We refer users to the papers for more details. + +MLlib's FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `numPartitions`: the number of partitions used to distribute the work. + +**Examples** + +
    +
    + +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the +FP-growth algorithm. +It take a `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. + + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.FPGrowth + +val data = sc.textFile("data/mllib/sample_fpgrowth.txt") + +val transactions: RDD[Array[String]] = data.map(s => s.trim.split(' ')) + +val fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10) +val model = fpg.run(transactions) + +model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) +} + +val minConfidence = 0.8 +model.generateAssociationRules(minConfidence).collect().foreach { rule => + println( + rule.antecedent.mkString("[", ",", "]") + + " => " + rule.consequent .mkString("[", ",", "]") + + ", " + rule.confidence) +} +{% endhighlight %} + +
    + +
    + +[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +FP-growth algorithm. +It take an `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. + +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.AssociationRules; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +SparkConf conf = new SparkConf().setAppName("FP-growth Example"); +JavaSparkContext sc = new JavaSparkContext(conf); + +JavaRDD data = sc.textFile("data/mllib/sample_fpgrowth.txt"); + +JavaRDD> transactions = data.map( + new Function>() { + public List call(String line) { + String[] parts = line.split(" "); + return Arrays.asList(parts); + } + } +); + +FPGrowth fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10); +FPGrowthModel model = fpg.run(transactions); + +for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + itemset.javaItems() + "], " + itemset.freq()); +} + +double minConfidence = 0.8; +for (AssociationRules.Rule rule + : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); +} +{% endhighlight %} + +
    + +
    + +[`FPGrowth`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) implements the +FP-growth algorithm. +It take an `RDD` of transactions, where each transaction is an `List` of items of a generic type. +Calling `FPGrowth.train` with transactions returns an +[`FPGrowthModel`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowthModel) +that stores the frequent itemsets with their frequencies. + +{% highlight python %} +from pyspark.mllib.fpm import FPGrowth + +data = sc.textFile("data/mllib/sample_fpgrowth.txt") + +transactions = data.map(lambda line: line.strip().split(' ')) + +model = FPGrowth.train(transactions, minSupport=0.2, numPartitions=10) + +result = model.freqItemsets().collect() +for fi in result: + print(fi) +{% endhighlight %} + +
    + +
    + +## Association Rules + +
    +
    +[AssociationRules](api/scala/index.html#org.apache.spark.mllib.fpm.AssociationRules) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.AssociationRules +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset + +val freqItemsets = sc.parallelize(Seq( + new FreqItemset(Array("a"), 15L), + new FreqItemset(Array("b"), 35L), + new FreqItemset(Array("a", "b"), 12L) +)); + +val ar = new AssociationRules() + .setMinConfidence(0.8) +val results = ar.run(freqItemsets) + +results.collect().foreach { rule => + println("[" + rule.antecedent.mkString(",") + + "=>" + + rule.consequent.mkString(",") + "]," + rule.confidence) +} +{% endhighlight %} + +
    + +
    +[AssociationRules](api/java/org/apache/spark/mllib/fpm/AssociationRules.html) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.AssociationRules; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + +JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( + new FreqItemset(new String[] {"a"}, 15L), + new FreqItemset(new String[] {"b"}, 35L), + new FreqItemset(new String[] {"a", "b"}, 12L) +)); + +AssociationRules arules = new AssociationRules() + .setMinConfidence(0.8); +JavaRDD> results = arules.run(freqItemsets); + +for (AssociationRules.Rule rule: results.collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); +} +{% endhighlight %} + +
    +
    + +## PrefixSpan + +PrefixSpan is a sequential pattern mining algorithm described in +[Pei et al., Mining Sequential Patterns by Pattern-Growth: The +PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +the reader to the referenced paper for formalizing the sequential +pattern mining problem. + +MLlib's PrefixSpan implementation takes the following parameters: + +* `minSupport`: the minimum support required to be considered a frequent + sequential pattern. +* `maxPatternLength`: the maximum length of a frequent sequential + pattern. Any frequent pattern exceeding this length will not be + included in the results. +* `maxLocalProjDBSize`: the maximum number of items allowed in a + prefix-projected database before local iterative processing of the + projected databse begins. This parameter should be tuned with respect + to the size of your executors. + +**Examples** + +The following example illustrates PrefixSpan running on the sequences +(using same notation as Pei et al): + +~~~ + <(12)3> + <1(32)(12)> + <(12)5> + <6> +~~~ + +
    +
    + +[`PrefixSpan`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpan) implements the +PrefixSpan algorithm. +Calling `PrefixSpan.run` returns a +[`PrefixSpanModel`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpanModel) +that stores the frequent sequences with their frequencies. + +{% highlight scala %} +import org.apache.spark.mllib.fpm.PrefixSpan + +val sequences = sc.parallelize(Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6)) + ), 2).cache() +val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) +val model = prefixSpan.run(sequences) +model.freqSequences.collect().foreach { freqSequence => +println( + freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + ", " + freqSequence.freq) +} +{% endhighlight %} + +
    + +
    + +[`PrefixSpan`](api/java/org/apache/spark/mllib/fpm/PrefixSpan.html) implements the +PrefixSpan algorithm. +Calling `PrefixSpan.run` returns a +[`PrefixSpanModel`](api/java/org/apache/spark/mllib/fpm/PrefixSpanModel.html) +that stores the frequent sequences with their frequencies. + +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.mllib.fpm.PrefixSpan; +import org.apache.spark.mllib.fpm.PrefixSpanModel; + +JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) +), 2); +PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); +PrefixSpanModel model = prefixSpan.run(sequences); +for (PrefixSpan.FreqSequence freqSeq: model.freqSequences().toJavaRDD().collect()) { + System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq()); +} +{% endhighlight %} + +
    +
    + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 39c64d06926b..6330c977552d 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -1,46 +1,71 @@ --- layout: global -title: Machine Learning Library (MLlib) Programming Guide +title: MLlib +displayTitle: Machine Learning Library (MLlib) Guide +description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: +filtering, dimensionality reduction, as well as underlying optimization primitives. +Guides for individual algorithms are listed below. + +The API is divided into 2 parts: + +* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. +* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. + +We list major functionality from both below, with links to detailed guides. + +# MLlib types, algorithms and utilities + +This lists functionality included in `spark.mllib`, the main MLlib API. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) - * summary statistics - * correlations - * stratified sampling - * hypothesis testing - * random data generation + * [summary statistics](mllib-statistics.html#summary-statistics) + * [correlations](mllib-statistics.html#correlations) + * [stratified sampling](mllib-statistics.html#stratified-sampling) + * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) - * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [ensembles of trees (Random Forests and Gradient-Boosted Trees)](mllib-ensembles.html) + * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) - * alternating least squares (ALS) + * [alternating least squares (ALS)](mllib-collaborative-filtering.html#collaborative-filtering) * [Clustering](mllib-clustering.html) - * k-means + * [k-means](mllib-clustering.html#k-means) + * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) + * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) + * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) - * singular value decomposition (SVD) - * principal component analysis (PCA) + * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) + * [principal component analysis (PCA)](mllib-dimensionality-reduction.html#principal-component-analysis-pca) * [Feature extraction and transformation](mllib-feature-extraction.html) +* [Frequent pattern mining](mllib-frequent-pattern-mining.html) + * [FP-growth](mllib-frequent-pattern-mining.html#fp-growth) + * [association rules](mllib-frequent-pattern-mining.html#association-rules) + * [PrefixSpan](mllib-frequent-pattern-mining.html#prefix-span) +* [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) - * stochastic gradient descent - * limited-memory BFGS (L-BFGS) + * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) + * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) +* [PMML model export](mllib-pmml-model-export.html) MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. # spark.ml: high-level APIs for ML pipelines -Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of +Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. -It is currently an alpha component, and we would like to hear back from the community about -how it fits real-world use cases and how it could be improved. + +*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. Note that we will keep supporting and adding features to `spark.mllib` along with the development of `spark.ml`. @@ -48,153 +73,54 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. - -# Dependencies - -MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), -which depends on [netlib-java](https://github.com/fommil/netlib-java), -and [jblas](https://github.com/mikiobraun/jblas). -`netlib-java` and `jblas` depend on native Fortran routines. -You need to install the -[gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries) -if it is not already present on your nodes. -MLlib will throw a linking error if it cannot detect these libraries automatically. -Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's -dependency set under default settings. -If no native library is available at runtime, you will see a warning message. -To use native libraries from `netlib-java`, please build Spark with `-Pnetlib-lgpl` or -include `com.github.fommil.netlib:all:1.1.2` as a dependency of your project. -If you want to use optimized BLAS/LAPACK libraries such as -[OpenBLAS](http://www.openblas.net/), please link its shared libraries to -`/usr/lib/libblas.so.3` and `/usr/lib/liblapack.so.3`, respectively. -BLAS/LAPACK libraries on worker nodes should be built without multithreading. - -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. - ---- - -# Migration Guide - -## From 1.1 to 1.2 - -The only API changes in MLlib v1.2 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.2: - -1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number -of classes. In MLlib v1.1, this argument was called `numClasses` in Python and -`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. -This `numClasses` parameter is specified either via -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. +Guides for `spark.ml` include: -2. *(Breaking change)* The API for -[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. -This should generally not affect user code, unless the user manually constructs decision trees -(instead of using the `trainClassifier` or `trainRegressor` methods). -The tree `Node` now includes more information, including the probability of the predicted label -(for classification). +* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts +* Guides on using algorithms within the Pipelines API: + * [Feature transformers](ml-features.html), including a few not in the lower-level `spark.mllib` API + * [Decision trees](ml-decision-tree.html) + * [Ensembles](ml-ensembles.html) + * [Linear methods](ml-linear-methods.html) -3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. - -Examples in the Spark distribution and examples in the -[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. - -## From 1.0 to 1.1 - -The only API changes in MLlib v1.1 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.1: - -1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match -the implementations of trees in -[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) -and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). -In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. -In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. -This depth is specified by the `maxDepth` parameter in -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` -methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -rather than using the old parameter class `Strategy`. These new training methods explicitly -separate classification and regression, and they replace specialized parameter types with -simple `String` types. - -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the -[Decision Trees Guide](mllib-decision-tree.html#examples). - -## From 0.9 to 1.0 - -In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few -breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. Details are described below. - -
    -
    - -We used to represent a feature vector by `Array[Double]`, which is replaced by -[`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) in v1.0. Algorithms that used -to accept `RDD[Array[Double]]` now take -`RDD[Vector]`. [`LabeledPoint`](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) -is now a wrapper of `(Double, Vector)` instead of `(Double, Array[Double])`. Converting -`Array[Double]` to `Vector` is straightforward: - -{% highlight scala %} -import org.apache.spark.mllib.linalg.{Vector, Vectors} - -val array: Array[Double] = ... // a double array -val vector: Vector = Vectors.dense(array) // a dense vector -{% endhighlight %} - -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to create sparse vectors. +# Dependencies -*Note*: Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. +MLlib uses the linear algebra package +[Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised +numerical processing. If natives are not available at runtime, you +will see a warning message and a pure JVM implementation will be used +instead. -
    +To learn more about the benefits and background of system optimised +natives, you may wish to watch Sam Halliday's ScalaX talk on +[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). -
    +Due to licensing issues with runtime proprietary binaries, we do not +include `netlib-java`'s native proxies by default. To configure +`netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with +`-Pnetlib-lgpl`) as a dependency of your project and read the +[netlib-java](https://github.com/fommil/netlib-java) documentation for +your platform's additional installation instructions. -We used to represent a feature vector by `double[]`, which is replaced by -[`Vector`](api/java/index.html?org/apache/spark/mllib/linalg/Vector.html) in v1.0. Algorithms that used -to accept `RDD` now take -`RDD`. [`LabeledPoint`](api/java/index.html?org/apache/spark/mllib/regression/LabeledPoint.html) -is now a wrapper of `(double, Vector)` instead of `(double, double[])`. Converting `double[]` to -`Vector` is straightforward: +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) +version 1.4 or newer. -{% highlight java %} -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; +--- -double[] array = ... // a double array -Vector vector = Vectors.dense(array); // a dense vector -{% endhighlight %} +# Migration Guide -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to -create sparse vectors. +For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -
    - -
    +## From 1.3 to 1.4 -We used to represent a labeled feature vector in a NumPy array, where the first entry corresponds to -the label and the rest are features. This representation is replaced by class -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html), which takes both -dense and sparse feature vectors. +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: -{% highlight python %} -from pyspark.mllib.linalg import SparseVector -from pyspark.mllib.regression import LabeledPoint +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. -# Create a labeled point with a positive label and a dense feature vector. -pos = LabeledPoint(1.0, [1.0, 0.0, 3.0]) +## Previous Spark Versions -# Create a labeled point with a negative label and a sparse feature vector. -neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) -{% endhighlight %} -
    -
    +Earlier migration guides are archived [on this page](mllib-migration-guides.html). diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md new file mode 100644 index 000000000000..6aa881f74918 --- /dev/null +++ b/docs/mllib-isotonic-regression.md @@ -0,0 +1,198 @@ +--- +layout: global +title: Isotonic regression - MLlib +displayTitle: MLlib - Regression +--- + +## Isotonic regression +[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression) +belongs to the family of regression algorithms. Formally isotonic regression is a problem where +given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses +and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted +finding a function that minimises + +`\begin{equation} + f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 +\end{equation}` + +with respect to complete order subject to +`$x_1\le x_2\le ...\le x_n$` where `$w_i$` are positive weights. +The resulting function is called isotonic regression and it is unique. +It can be viewed as least squares problem under order restriction. +Essentially isotonic regression is a +[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) +best fitting the original data points. + +MLlib supports a +[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +which uses an approach to +[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +The training input is a RDD of tuples of three double values that represent +label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +optional parameter called $isotonic$ defaulting to true. +This argument specifies if the isotonic regression is +isotonic (monotonically increasing) or antitonic (monotonically decreasing). + +Training returns an IsotonicRegressionModel that can be used to predict +labels for both known and unknown features. The result of isotonic regression +is treated as piecewise linear function. The rules for prediction therefore are: + +* If the prediction input exactly matches a training feature + then associated prediction is returned. In case there are multiple predictions with the same + feature then one of them is returned. Which one is undefined + (same as java.util.Arrays.binarySearch). +* If the prediction input is lower or higher than all training features + then prediction with lowest or highest feature is returned respectively. + In case there are multiple predictions with the same feature + then the lowest or highest is returned respectively. +* If the prediction input falls between two training features then prediction is treated + as piecewise linear function and interpolated value is calculated from the + predictions of the two closest features. In case there are multiple values + with the same feature then the same rules as in previous point are used. + +### Examples + +
    +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight scala %} +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} + +val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) +} + +// Split data into training (60%) and test (40%) sets. +val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0) +val test = splits(1) + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +val model = new IsotonicRegression().setIsotonic(true).run(training) + +// Create tuples of predicted and real labels. +val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) +} + +// Calculate mean squared error between predicted and real labels. +val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() +println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") +{% endhighlight %} +
    + +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +import scala.Tuple2; +import scala.Tuple3; + +JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } +); + +// Split data into training (60%) and test (40%) sets. +JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); +JavaRDD> training = splits[0]; +JavaRDD> test = splits[1]; + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + +// Create tuples of predicted and real labels. +JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } +); + +// Calculate mean squared error between predicted and real labels. +Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } +).rdd()).mean(); + +System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} +
    + +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight python %} +import math +from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel + +data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +# Create label, feature, weight tuples from input data with weight set to default value 1.0. +parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,)) + +# Split data into training (60%) and test (40%) sets. +training, test = parsedData.randomSplit([0.6, 0.4], 11) + +# Create isotonic regression model from training data. +# Isotonic parameter defaults to true so it is only shown for demonstration +model = IsotonicRegression.train(training) + +# Create tuples of predicted and real labels. +predictionAndLabel = test.map(lambda p: (model.predict(p[1]), p[0])) + +# Calculate mean squared error between predicted and real labels. +meanSquaredError = predictionAndLabel.map(lambda pl: math.pow((pl[0] - pl[1]), 2)).mean() +print("Mean Squared Error = " + str(meanSquaredError)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = IsotonicRegressionModel.load(sc, "myModelPath") +{% endhighlight %} +
    +
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 44b7f67c5773..e9b2d276cd38 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,18 +10,18 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} \newcommand{\av}{\mathbf{\alpha}} \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib: L1$\|\wv\|_1$$\mathrm{sign}(\wv)$ + + elastic net$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$ + @@ -107,25 +110,33 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. -## Binary classification - -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) -aims to divide items into two categories: positive and negative. MLlib -supports two linear methods for binary classification: linear Support Vector -Machines (SVMs) and logistic regression. For both methods, MLlib supports -L1 and L2 regularized variants. The training data set is represented by an RDD -of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the -mathematical formulation in this guide, a training label $y$ is denoted as -either $+1$ (positive) or $-1$ (negative), which is convenient for the -formulation. *However*, the negative label is represented by $0$ in MLlib -instead of $-1$, to be consistent with multiclass labeling. +## Classification + +[Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into +categories. +The most common classification type is +[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +categories, usually named positive and negative. +If there are more than two categories, it is called +[multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). +MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +and logistic regression. +Linear SVMs supports only binary classification, while logistic regression supports both binary and +multiclass classification problems. +For both methods, MLlib supports L1 and L2 regularized variants. +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +where labels are class indices starting from zero: $0, 1, 2, \ldots$. +Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either +$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. +*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -144,41 +155,7 @@ denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative otherwise. -### Logistic regression - -[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. -It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss -function in the formulation given by the logistic loss: -`\[ -L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). -\]` - -The logistic regression algorithm outputs a logistic regression model. Given a -new data point, denoted by $\x$, the model makes predictions by -applying the logistic function -`\[ -\mathrm{f}(z) = \frac{1}{1 + e^{-z}} -\]` -where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or -negative otherwise, though unlike linear SVMs, the raw output of the logistic regression -model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability -that $\x$ is positive). - -### Evaluation metrics - -MLlib supports common evaluation metrics for binary classification (not available in PySpark). -This -includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), -[receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), -precision-recall curve, and -[area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -AUC is commonly used to compare the performance of various models while -precision/recall/F-measure can help determine the appropriate threshold to use -for prediction purposes. - -### Examples +**Examples**
    @@ -189,11 +166,8 @@ object, and make predictions with the resulting model to compute the training error. {% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.SVMWithSGD +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. @@ -211,7 +185,7 @@ val model = SVMWithSGD.train(training, numIterations) // Clear the default threshold. model.clearThreshold() -// Compute raw scores on the test set. +// Compute raw scores on the test set. val scoreAndLabels = test.map { point => val score = model.predict(point.features) (score, point.label) @@ -222,6 +196,10 @@ val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() println("Area under ROC = " + auROC) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %} The `SVMWithSGD.train()` method by default performs L2 regularization with the @@ -243,8 +221,6 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. -
    @@ -255,15 +231,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} -import java.util.Random; - import scala.Tuple2; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.*; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; + import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; @@ -280,11 +254,11 @@ public class SVMClassifier { JavaRDD training = data.sample(false, 0.6, 11L); training.cache(); JavaRDD test = data.subtract(training); - + // Run training algorithm to build the model. int numIterations = 100; final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); - + // Clear the default threshold. model.clearThreshold(); @@ -297,13 +271,17 @@ public class SVMClassifier { } } ); - + // Get evaluation metrics. - BinaryClassificationMetrics metrics = + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); double auROC = metrics.areaUnderROC(); - + System.out.println("Area under ROC = " + auROC); + + // Save and load model + model.save(sc, "myModelPath"); + SVMModel sameModel = SVMModel.load(sc, "myModelPath"); } } {% endhighlight %} @@ -334,14 +312,198 @@ quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency.
    +
    +The following example shows how to load a sample dataset, build SVM model, +and make predictions with the resulting model to compute the training error. + +{% highlight python %} +from pyspark.mllib.classification import SVMWithSGD, SVMModel +from pyspark.mllib.regression import LabeledPoint + +# Load and parse the data +def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + +data = sc.textFile("data/mllib/sample_svm_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = SVMWithSGD.train(parsedData, iterations=100) + +# Evaluating the model on training data +labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) +trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) +print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") +{% endhighlight %} +
    + + +### Logistic regression + +[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a +binary response. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, +with the loss function in the formulation given by the logistic loss: +`\[ +L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). +\]` + +For binary classification problems, the algorithm outputs a binary logistic regression model. +Given a new data point, denoted by $\x$, the model makes predictions by +applying the logistic function +`\[ +\mathrm{f}(z) = \frac{1}{1 + e^{-z}} +\]` +where $z = \wv^T \x$. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). + +Binary logistic regression can be generalized into +[multinomial logistic regression](http://en.wikipedia.org/wiki/Multinomial_logistic_regression) to +train and predict multiclass classification problems. +For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the +other $K - 1$ outcomes can be separately regressed against the pivot outcome. +In MLlib, the first class $0$ is chosen as the "pivot" class. +See Section 4.4 of +[The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for +references. +Here is an +[detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297). + +For multiclass classification problems, the algorithm will output a multinomial logistic regression +model, which contains $K - 1$ binary logistic regression models regressed against the first class. +Given a new data points, $K - 1$ models will be run, and the class with largest probability will be +chosen as the predicted class. + +We implemented two algorithms to solve logistic regression: mini-batch gradient descent and L-BFGS. +We recommend L-BFGS over mini-batch gradient descent for faster convergence. + +**Examples** + +
    + +
    +The following code illustrates how to load a sample multiclass dataset, split it into train and +test, and use +[LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) +to fit a logistic regression model. +Then the model is evaluated against the test dataset and saved to disk. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + +// Split data into training (60%) and test (40%). +val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0).cache() +val test = splits(1) + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training) + +// Compute raw scores on the test set. +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Get evaluation metrics. +val metrics = new MulticlassMetrics(predictionAndLabels) +val precision = metrics.precision +println("Precision = " + precision) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = LogisticRegressionModel.load(sc, "myModelPath") +{% endhighlight %} + +
    + +
    +The following code illustrates how to load a sample multiclass dataset, split it into train and +test, and use +[LogisticRegressionWithLBFGS](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) +to fit a logistic regression model. +Then the model is evaluated against the test dataset and saved to disk. + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MultinomialLogisticRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + double precision = metrics.precision(); + System.out.println("Precision = " + precision); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} +{% endhighlight %} +
    +
    The following example shows how to load a sample dataset, build Logistic Regression model, and make predictions with the resulting model to compute the training error. +Note that the Python API does not yet support multiclass classification and model save/load but +will in the future. + {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -352,20 +514,26 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = LogisticRegressionWithLBFGS.train(parsedData) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    -## Linear least squares, Lasso, and ridge regression +# Regression + +### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. +Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -373,26 +541,27 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -### Examples +**Examples**
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -413,6 +582,10 @@ val valuesAndPreds = parsedData.map { point => } val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) @@ -443,7 +616,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -463,7 +636,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -483,6 +656,10 @@ public class LinearRegression { } ).rdd()).mean(); System.out.println("training Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -490,13 +667,14 @@ public class LinearRegression {
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from numpy import array +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel # Load and parse the data def parsePoint(line): @@ -513,6 +691,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -523,15 +705,15 @@ section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. -## Streaming linear regression +###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. -### Examples +**Examples** The following example demonstrates how to load training and testing data from two different input streams of text files, parse the streams as labeled points, fit a linear regression model @@ -541,7 +723,7 @@ online to the first stream, and make predictions on the second stream.
    -First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -553,7 +735,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -573,7 +755,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -583,22 +765,74 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
    +
    + +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight python %} +numFeatures = 3 +model = StreamingLinearRegressionWithSGD() +model.setInitialWeights([0.0, 0.0, 0.0]) +{% endhighlight %} + +Now we register the streams for training and testing and start the job. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! +
    +
    -## Implementation (developer) + +# Implementation (developer) Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the MLlib - Old Migration Guides +description: MLlib migration guides from before Spark SPARK_VERSION_SHORT +--- + +The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). + +## From 1.2 to 1.3 + +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. + +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. + +## From 1.1 to 1.2 + +The only API changes in MLlib v1.2 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.2: + +1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number +of classes. In MLlib v1.1, this argument was called `numClasses` in Python and +`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. +This `numClasses` parameter is specified either via +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Breaking change)* The API for +[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. +This should generally not affect user code, unless the user manually constructs decision trees +(instead of using the `trainClassifier` or `trainRegressor` methods). +The tree `Node` now includes more information, including the probability of the predicted label +(for classification). + +3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. + +Examples in the Spark distribution and examples in the +[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. + +## From 1.0 to 1.1 + +The only API changes in MLlib v1.1 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.1: + +1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match +the implementations of trees in +[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) +and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). +In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. +In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. +This depth is specified by the `maxDepth` parameter in +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` +methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +rather than using the old parameter class `Strategy`. These new training methods explicitly +separate classification and regression, and they replace specialized parameter types with +simple `String` types. + +Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +[Decision Trees Guide](mllib-decision-tree.html#examples). + +## From 0.9 to 1.0 + +In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few +breaking changes. If your data is sparse, please store it in a sparse format instead of dense to +take advantage of sparsity in both storage and computation. Details are described below. + diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index d5b044d94fdd..e73bd30f3a90 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -13,12 +13,14 @@ compute the conditional probability distribution of label given an observation and use it for prediction. MLlib supports [multinomial naive -Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), -which is typically used for [document -classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. -Feature values must be nonnegative to represent term frequencies. +feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). +Feature values must be nonnegative. The model type is selected with an optional parameter +"multinomial" or "bernoulli" with "multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -32,12 +34,12 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, and output a +smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. {% highlight scala %} -import org.apache.spark.mllib.classification.NaiveBayes +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -51,10 +53,14 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0) +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} @@ -68,6 +74,8 @@ optionally smoothing parameter `lambda` as input, and output a can be used for evaluation and prediction. {% highlight java %} +import scala.Tuple2; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; @@ -75,7 +83,6 @@ import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.regression.LabeledPoint; -import scala.Tuple2; JavaRDD training = ... // training set JavaRDD test = ... // test set @@ -93,34 +100,50 @@ double accuracy = predictionAndLabel.filter(new Function, return pl._1().equals(pl._2()); } }).count() / (double) test.count(); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    -[NaiveBayes](api/python/pyspark.mllib.classification.NaiveBayes-class.html) implements multinomial +[NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial naive Bayes. It takes an RDD of -[LabeledPoint](api/python/pyspark.mllib.regression.LabeledPoint-class.html) and an optionally +[LabeledPoint](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint) and an optionally smoothing parameter `lambda` as input, and output a -[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be +[NaiveBayesModel](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. - +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel +from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.classification import NaiveBayes -# an RDD of LabeledPoint -data = sc.parallelize([ - LabeledPoint(0.0, [0.0, 0.0]) - ... # more labeled points -]) +def parseLine(line): + parts = line.split(',') + label = float(parts[0]) + features = Vectors.dense([float(x) for x in parts[1].split(' ')]) + return LabeledPoint(label, features) + +data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) + +# Split data aproximately into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 0) # Train a naive Bayes model. -model = NaiveBayes.train(data, 1.0) +model = NaiveBayes.train(training, 1.0) + +# Make prediction and test accuracy. +predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) +accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() -# Make prediction. -prediction = model.predict([0.0, 0.0]) +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 4d101afca2c9..6cabc1610a15 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -203,6 +203,10 @@ regularization, as well as L2 regularizer. recommended. * `maxNumIterations` is the maximal number of iterations that L-BFGS can be run. * `regParam` is the regularization parameter when using regularization. +* `convergenceTol` controls how much relative change is still allowed when L-BFGS +is considered to converge. This must be nonnegative. Lower values are less tolerant and +therefore generally cause more iterations to be run. This value looks at both average +improvement and the norm of gradient inside [Breeze LBFGS](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGS.scala). The `return` is a tuple containing two elements. The first element is a column matrix containing weights for every feature, and the second element is an array containing diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md new file mode 100644 index 000000000000..42ea2ca81f80 --- /dev/null +++ b/docs/mllib-pmml-model-export.md @@ -0,0 +1,86 @@ +--- +layout: global +title: PMML model export - MLlib +displayTitle: MLlib - PMML model export +--- + +* Table of contents +{:toc} + +## MLlib supported models + +MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). + +The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. + + + + + + + + + + + + + + + + + + + + + + + + + +
    MLlib modelPMML model
    KMeansModelClusteringModel
    LinearRegressionModelRegressionModel (functionName="regression")
    RidgeRegressionModelRegressionModel (functionName="regression")
    LassoModelRegressionModel (functionName="regression")
    SVMModelRegressionModel (functionName="classification" normalizationMethod="none")
    Binary LogisticRegressionModelRegressionModel (functionName="classification" normalizationMethod="logit")
    + +## Examples +
    + +
    +To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. + +Here a complete example of building a KMeansModel and print it out in PMML format: +{% highlight scala %} +import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/kmeans_data.txt") +val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using KMeans +val numClusters = 2 +val numIterations = 20 +val clusters = KMeans.train(parsedData, numClusters, numIterations) + +// Export to PMML +println("PMML Model:\n" + clusters.toPMML) +{% endhighlight %} + +As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats: + +{% highlight scala %} +// Export the model to a String in PMML format +clusters.toPMML + +// Export the model to a local file in PMML format +clusters.toPMML("/tmp/kmeans.xml") + +// Export the model to a directory on a distributed file system in PMML format +clusters.toPMML(sc,"/tmp/kmeans") + +// Export the model to the OutputStream in PMML format +clusters.toPMML(System.out) +{% endhighlight %} + +For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown. + +
    + +
    diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ca8c29218f52..6acfc71d7b01 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -81,8 +81,8 @@ System.out.println(summary.numNonzeros()); // number of nonzeros in each column
    -[`colStats()`](api/python/pyspark.mllib.stat.Statistics-class.html#colStats) returns an instance of -[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.stat.MultivariateStatisticalSummary-class.html), +[`colStats()`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics.colStats) returns an instance of +[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary), which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. @@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors # Compute column summary statistics. summary = Statistics.colStats(mat) -print summary.mean() -print summary.variance() -print summary.numNonzeros() +print(summary.mean()) +print(summary.variance()) +print(summary.numNonzeros()) {% endhighlight %}
    @@ -169,7 +169,7 @@ Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson");
    -[`Statistics`](api/python/pyspark.mllib.stat.Statistics-class.html) provides methods to +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. @@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a # method is not specified, Pearson's method will be used by default. -print Statistics.corr(seriesX, seriesY, method="pearson") +print(Statistics.corr(seriesX, seriesY, method="pearson")) data = ... # an RDD of Vectors # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. # If a method is not specified, Pearson's method will be used by default. -print Statistics.corr(data, method="pearson") +print(Statistics.corr(data, method="pearson")) {% endhighlight %}
    @@ -258,7 +258,7 @@ JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); {% endhighlight %}
    -[`sampleByKey()`](api/python/pyspark.rdd.RDD-class.html#sampleByKey) allows users to +[`sampleByKey()`](api/python/pyspark.html#pyspark.RDD.sampleByKey) allows users to sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of keys. @@ -283,7 +283,7 @@ approxSample = data.sampleByKey(False, fractions); Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically significant, whether this result occurred by chance or not. MLlib currently supports Pearson's -chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine +chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. @@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events # compute the goodness of fit. If a second vector to test against is not supplied as a parameter, # the test runs against a uniform distribution. goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. +print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. mat = Matrices.dense(...) # a contingency matrix # conduct Pearson's independence test on the input contingency matrix independenceTestResult = Statistics.chiSqTest(mat) -print independenceTestResult # summary of the test including the p-value, degrees of freedom... +print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... obs = sc.parallelize(...) # LabeledPoint(feature, label) . @@ -415,13 +415,91 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) . featureTestResults = Statistics.chiSqTest(obs) for i, result in enumerate(featureTestResults): - print "Column $d:" % (i + 1) - print result + print("Column $d:" % (i + 1)) + print(result) {% endhighlight %}
    +Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +for equality of probability distributions. By providing the name of a theoretical distribution +(currently solely supported for the normal distribution) and its parameters, or a function to +calculate the cumulative distribution according to a given theoretical distribution, the user can +test the null hypothesis that their sample is drawn from that distribution. In the case that the +user tests against the normal distribution (`distName="norm"`), but does not provide distribution +parameters, the test initializes to the standard normal distribution and logs an appropriate +message. + +
    +
    +[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight scala %} +import org.apache.spark.mllib.stat.Statistics + +val data: RDD[Double] = ... // an RDD of sample data + +// run a KS test for the sample versus a standard normal distribution +val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) +println(testResult) // summary of the test including the p-value, test statistic, + // and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis + +// perform a KS test using a cumulative distribution function of our making +val myCDF: Double => Double = ... +val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) +{% endhighlight %} +
    + +
    +[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; + +JavaSparkContext jsc = ... +JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, ...)); +KolmogorovSmirnovTestResult testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0); +// summary of the test including the p-value, test statistic, +// and null hypothesis +// if our p-value indicates significance, we can reject the null hypothesis +System.out.println(testResult); +{% endhighlight %} +
    + +
    +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight python %} +from pyspark.mllib.stat import Statistics + +parallelData = sc.parallelize([1.0, 2.0, ... ]) + +# run a KS test for the sample versus a standard normal distribution +testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0, 1) +print(testResult) # summary of the test including the p-value, test statistic, + # and null hypothesis + # if our p-value indicates significance, we can reject the null hypothesis +# Note that the Scala functionality of calling Statistics.kolmogorovSmirnovTest with +# a lambda to calculate the CDF is not made available in the Python API +{% endhighlight %} +
    +
    + + ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. @@ -476,7 +554,7 @@ JavaDoubleRDD v = u.map(
    -[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +[`RandomRDDs`](api/python/pyspark.mllib.html#pyspark.mllib.random.RandomRDDs) provides factory methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. @@ -493,5 +571,82 @@ u = RandomRDDs.uniformRDD(sc, 1000000L, 10) v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %}
    + + +## Kernel density estimation + +[Kernel density estimation](https://en.wikipedia.org/wiki/Kernel_density_estimation) is a technique +useful for visualizing empirical probability distributions without requiring assumptions about the +particular distribution that the observed samples are drawn from. It computes an estimate of the +probability density function of a random variables, evaluated at a given set of points. It achieves +this estimate by expressing the PDF of the empirical distribution at a particular point as the the +mean of PDFs of normal distributions centered around each of the samples. + +
    + +
    +[`KernelDensity`](api/scala/index.html#org.apache.spark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight scala %} +import org.apache.spark.mllib.stat.KernelDensity +import org.apache.spark.rdd.RDD + +val data: RDD[Double] = ... // an RDD of sample data + +// Construct the density estimator with the sample data and a standard deviation for the Gaussian +// kernels +val kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0) + +// Find density estimates for the given values +val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) +{% endhighlight %} +
    + +
    +[`KernelDensity`](api/java/index.html#org.apache.spark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight java %} +import org.apache.spark.mllib.stat.KernelDensity; +import org.apache.spark.rdd.RDD; + +RDD data = ... // an RDD of sample data + +// Construct the density estimator with the sample data and a standard deviation for the Gaussian +// kernels +KernelDensity kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0); + +// Find density estimates for the given values +double[] densities = kd.estimate(new double[] {-1.0, 2.0, 5.0}); +{% endhighlight %} +
    + +
    +[`KernelDensity`](api/python/pyspark.mllib.html#pyspark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight python %} +from pyspark.mllib.stat import KernelDensity + +data = ... # an RDD of sample data + +# Construct the density estimator with the sample data and a standard deviation for the Gaussian +# kernels +kd = KernelDensity() +kd.setSample(data) +kd.setBandwidth(3.0) + +# Find density estimates for the given values +densities = kd.estimate([-1.0, 2.0, 5.0]) +{% endhighlight %} +
    diff --git a/docs/monitoring.md b/docs/monitoring.md index f32cdef240d3..cedceb295802 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -1,6 +1,7 @@ --- layout: global title: Monitoring and Instrumentation +description: Monitoring, metrics, and instrumentation guide for Spark SPARK_VERSION_SHORT --- There are several ways to monitor Spark applications: web UIs, metrics, and external instrumentation. @@ -47,7 +48,7 @@ follows: Environment VariableMeaning SPARK_DAEMON_MEMORY - Memory to allocate to the history server (default: 512m). + Memory to allocate to the history server (default: 1g). SPARK_DAEMON_JAVA_OPTS @@ -85,10 +86,10 @@ follows: - spark.history.fs.updateInterval - 10 + spark.history.fs.update.interval + 10s - The period, in seconds, at which information displayed by this history server is updated. + The period at which information displayed by this history server is updated. Each update checks for any changes made to the event logs in persisted storage. @@ -144,11 +145,117 @@ follows: If disabled, no access control checks are made. + + spark.history.fs.cleaner.enabled + false + + Specifies whether the History Server should periodically clean up event logs from storage. + + + + spark.history.fs.cleaner.interval + 1d + + How often the job history cleaner checks for files to delete. + Files are only deleted if they are older than spark.history.fs.cleaner.maxAge. + + + + spark.history.fs.cleaner.maxAge + 7d + + Job history files older than this will be deleted when the history cleaner runs. + + Note that in all of these UIs, the tables are sortable by clicking their headers, making it easy to identify slow tasks, data skew, etc. +Note that the history server only displays completed Spark jobs. One way to signal the completion of a Spark job is to stop the Spark Context explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` to handle the Spark Context setup and tear down, and still show the job history on the UI. + +## REST API + +In addition to viewing the metrics in the UI, they are also available as JSON. This gives developers +an easy way to create new visualizations and monitoring tools for Spark. The JSON is available for +both running applications, and in the history server. The endpoints are mounted at `/api/v1`. Eg., +for the history server, they would typically be accessible at `http://:18080/api/v1`, and +for a running application, at `http://localhost:4040/api/v1`. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    EndpointMeaning
    /applicationsA list of all applications
    /applications/[app-id]/jobsA list of all jobs for a given application
    /applications/[app-id]/jobs/[job-id]Details for the given job
    /applications/[app-id]/stagesA list of all stages for a given application
    /applications/[app-id]/stages/[stage-id]A list of all attempts for the given stage
    /applications/[app-id]/stages/[stage-id]/[stage-attempt-id]Details for the given stage attempt
    /applications/[app-id]/stages/[stage-id]/[stage-attempt-id]/taskSummarySummary metrics of all tasks in the given stage attempt
    /applications/[app-id]/stages/[stage-id]/[stage-attempt-id]/taskListA list of all tasks for the given stage attempt
    /applications/[app-id]/executorsA list of all executors for the given application
    /applications/[app-id]/storage/rddA list of stored RDDs for the given application
    /applications/[app-id]/storage/rdd/[rdd-id]Details for the storage status of a given RDD
    /applications/[app-id]/logsDownload the event logs for all attempts of the given application as a zip file
    /applications/[app-id]/[attempt-id]/logsDownload the event logs for the specified attempt of the given application as a zip file
    + +When running on Yarn, each application has multiple attempts, so `[app-id]` is actually +`[app-id]/[attempt-id]` in all cases. + +These endpoints have been strongly versioned to make it easier to develop applications on top. + In particular, Spark guarantees: + +* Endpoints will never be removed from one version +* Individual fields will never be removed for any given endpoint +* New endpoints may be added +* New fields may be added to existing endpoints +* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. +* Api versions may be dropped, but only after at least one minor release of co-existing with a new api version + +Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is +still required, though there is only one application available. Eg. to see the list of jobs for the +running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to +keep the paths consistent in both modes. + # Metrics Spark has a configurable metrics system based on the @@ -175,6 +282,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `JmxSink`: Registers metrics for viewing in a JMX console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. * `GraphiteSink`: Sends metrics to a Graphite node. +* `Slf4jSink`: Sends metrics to slf4j as log entries. Spark also supports a Ganglia sink which is not included in the default build due to licensing restrictions: diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 6486614e7135..4cf83bb39263 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1,6 +1,7 @@ --- layout: global title: Spark Programming Guide +description: Spark SPARK_VERSION_SHORT programming guide in Java, Scala and Python --- * This will become a table of contents (this text will be scraped). @@ -40,19 +41,20 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency artifactId = hadoop-client version = -Finally, you need to import some Spark classes and implicit conversions into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following lines: {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.SparkConf {% endhighlight %} +(Before Spark 1.3.0, you need to explicitly `import org.apache.spark.SparkContext._` to enable essential implicit conversions.) +
    -Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports +Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports [lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html) for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. @@ -83,8 +85,8 @@ import org.apache.spark.SparkConf
    -Spark {{site.SPARK_VERSION}} works with Python 2.6 or higher (but not Python 3). It uses the standard CPython interpreter, -so C libraries like NumPy can be used. +Spark {{site.SPARK_VERSION}} works with Python 2.6+ or Python 3.4+. It can use the standard CPython interpreter, +so C libraries like NumPy can be used. It also works with PyPy 2.3+. To run Spark applications in Python, use the `bin/spark-submit` script located in the Spark directory. This script will load Spark's Java/Scala libraries and allow you to submit applications to a cluster. @@ -96,12 +98,20 @@ to your version of HDFS. Some common HDFS version tags are listed on the [Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. -Finally, you need to import some Spark classes into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following line: -{% highlight scala %} +{% highlight python %} from pyspark import SparkContext, SparkConf {% endhighlight %} +PySpark requires the same minor version of Python in both driver and workers. It uses the default python version in PATH, +you can specify which version of Python you want to use by `PYSPARK_PYTHON`, for example: + +{% highlight bash %} +$ PYSPARK_PYTHON=python3.4 bin/pyspark +$ PYSPARK_PYTHON=/opt/pypy-2.5/bin/pypy bin/spark-submit examples/src/main/python/pi.py +{% endhighlight %} +
    @@ -141,8 +151,8 @@ JavaSparkContext sc = new JavaSparkContext(conf);
    -The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.context.SparkContext-class.html) object, which tells Spark -how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.conf.SparkConf-class.html) object +The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object, which tells Spark +how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.html#pyspark.SparkConf) object that contains information about your application. {% highlight python %} @@ -172,8 +182,11 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. -For example, to run `bin/spark-shell` on exactly four cores, use: +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly +four cores, use: {% highlight bash %} $ ./bin/spark-shell --master local[4] @@ -185,6 +198,12 @@ Or, to also add `code.jar` to its classpath, use: $ ./bin/spark-shell --master local[4] --jars code.jar {% endhighlight %} +To include a dependency using maven coordinates: + +{% highlight bash %} +$ ./bin/spark-shell --master local[4] --packages "org.example:example:0.1" +{% endhighlight %} + For a complete list of options, run `spark-shell --help`. Behind the scenes, `spark-shell` invokes the more general [`spark-submit` script](submitting-applications.html). @@ -195,7 +214,11 @@ For a complete list of options, run `spark-shell --help`. Behind the scenes, In the PySpark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add Python .zip, .egg or .py files -to the runtime path by passing a comma-separated list to `--py-files`. +to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: {% highlight bash %} @@ -223,9 +246,13 @@ You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +your notebook before you start to try Spark from the IPython notebook. +
    @@ -321,7 +348,7 @@ Apart from text files, Spark's Scala API also supports several other data format * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). In addition, Spark allows you to specify native types for a few common Writables; for example, `sequenceFile[Int, String]` will automatically read IntWritables and Texts. -* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `RDD.saveAsObjectFile` and `SparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -353,7 +380,7 @@ Apart from text files, Spark's Java API also supports several other data formats * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). -* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `JavaRDD.saveAsObjectFile` and `JavaSparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -459,7 +486,6 @@ the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters. - ## RDD Operations @@ -523,7 +549,7 @@ returning only its answer to the driver program. If we also wanted to use `lineLengths` again later, we could add: {% highlight java %} -lineLengths.persist(); +lineLengths.persist(StorageLevel.MEMORY_ONLY()); {% endhighlight %} before the `reduce`, which would cause `lineLengths` to be saved in memory after the first time it is computed. @@ -711,7 +737,7 @@ class MyClass(object): def __init__(self): self.field = "Hello" def doStuff(self, rdd): - return rdd.map(lambda s: self.field + x) + return rdd.map(lambda s: self.field + s) {% endhighlight %} To avoid this issue, the simplest way is to copy `field` into a local variable instead @@ -720,13 +746,76 @@ of accessing it externally: {% highlight python %} def doStuff(self, rdd): field = self.field - return rdd.map(lambda s: field + x) + return rdd.map(lambda s: field + s) +{% endhighlight %} + + + + + +### Understanding closures +One of the harder things about Spark is understanding the scope and life cycle of variables and methods when executing code across a cluster. RDD operations that modify variables outside of their scope can be a frequent source of confusion. In the example below we'll look at code that uses `foreach()` to increment a counter, but similar issues can occur for other operations as well. + +#### Example + +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): + +
    + +
    +{% highlight scala %} +var counter = 0 +var rdd = sc.parallelize(data) + +// Wrong: Don't do this!! +rdd.foreach(x => counter += x) + +println("Counter value: " + counter) +{% endhighlight %} +
    + +
    +{% highlight java %} +int counter = 0; +JavaRDD rdd = sc.parallelize(data); + +// Wrong: Don't do this!! +rdd.foreach(x -> counter += x); + +println("Counter value: " + counter); {% endhighlight %} +
    + +
    +{% highlight python %} +counter = 0 +rdd = sc.parallelize(data) + +# Wrong: Don't do this!! +rdd.foreach(lambda x: counter += x) +print("Counter value: " + counter) + +{% endhighlight %}
    +#### Local vs. cluster modes + +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. + +However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. + +What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only sees the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. + +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#AccumLink). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. + +In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. + +#### Printing elements of an RDD +Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. + ### Working with Key-Value Pairs
    @@ -740,11 +829,9 @@ by a key. In Scala, these operations are automatically available on RDDs containing [Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) objects -(the built-in tuples in the language, created by simply writing `(a, b)`), as long as you -import `org.apache.spark.SparkContext._` in your program to enable Spark's implicit -conversions. The key-value pair operations are available in the +(the built-in tuples in the language, created by simply writing `(a, b)`). The key-value pair operations are available in the [PairRDDFunctions](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions) class, -which automatically wraps around an RDD of tuples if you import the conversions. +which automatically wraps around an RDD of tuples. For example, the following code uses the `reduceByKey` operation on key-value pairs to count how many times each line of text occurs in a file: @@ -835,7 +922,8 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -856,7 +944,7 @@ for details. Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item). - mapPartitions(func) + mapPartitions(func) Similar to map, but runs separately on each partition (block) of the RDD, so func must be of type Iterator<T> => Iterator<U> when running on an RDD of type T. @@ -883,7 +971,7 @@ for details. Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better @@ -894,25 +982,25 @@ for details. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. @@ -925,17 +1013,17 @@ for details. process's stdin and lines output to its stdout are returned as an RDD of strings. - coalesce(numPartitions) + coalesce(numPartitions) Decrease the number of partitions in the RDD to numPartitions. Useful for running operations more efficiently after filtering down a large dataset. repartition(numPartitions) Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. - This always shuffles all data over the network. + This always shuffles all data over the network. - repartitionAndSortWithinPartitions(partitioner) + repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -948,7 +1036,9 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -974,7 +1064,7 @@ for details. take(n) - Return an array with the first n elements of the dataset. Note that this is currently not executed in parallel. Instead, the driver program computes all the elements. + Return an array with the first n elements of the dataset. takeSample(withReplacement, num, [seed]) @@ -990,7 +1080,7 @@ for details. saveAsSequenceFile(path)
    (Java and Scala) - Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that either implement Hadoop's Writable interface. In Scala, it is also + Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that implement Hadoop's Writable interface. In Scala, it is also available on types that are implicitly convertible to Writable (Spark includes conversions for basic types like Int, Double, String, etc). @@ -999,15 +1089,79 @@ for details. SparkContext.objectFile(). - countByKey() + countByKey() Only available on RDDs of type (K, V). Returns a hashmap of (K, Int) pairs with the count of each key. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an accumulator variable (see below) or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. +
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. +### Shuffle operations + +Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's +mechanism for re-distributing data so that it's grouped differently across partitions. This typically +involves copying data across executors and machines, making the shuffle a complex and +costly operation. + +#### Background + +To understand what happens during the shuffle we can consider the example of the +[`reduceByKey`](#ReduceByLink) operation. The `reduceByKey` operation generates a new RDD where all +values for a single key are combined into a tuple - the key and the result of executing a reduce +function against all values associated with that key. The challenge is that not all values for a +single key necessarily reside on the same partition, or even the same machine, but they must be +co-located to compute the result. + +In Spark, data is generally not distributed across partitions to be in the necessary place for a +specific operation. During computations, a single task will operate on a single partition - thus, to +organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - +this is called the **shuffle**. + +Although the set of elements in each partition of newly shuffled data will be deterministic, and so +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: + +* `mapPartitions` to sort each partition using, for example, `.sorted` +* `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning +* `sortBy` to make a globally ordered RDD + +Operations which can cause a shuffle include **repartition** operations like +[`repartition`](#RepartitionLink) and [`coalesce`](#CoalesceLink), **'ByKey** operations +(except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and +**join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). + +#### Performance Impact +The **Shuffle** is an expensive operation since it involves disk I/O, data serialization, and +network I/O. To organize data for the shuffle, Spark generates sets of tasks - *map* tasks to +organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from +MapReduce and does not directly relate to Spark's `map` and `reduce` operations. + +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks +read the relevant sorted blocks. + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables +to disk, incurring the additional overhead of disk I/O and increased garbage collection. + +Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +consume a large amount of disk space. The temporary storage directory is specified by the +`spark.local.dir` configuration parameter when configuring the Spark context. + +Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). + ## RDD Persistence One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory @@ -1027,7 +1181,7 @@ replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-proj These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), -[Python](api/python/pyspark.storagelevel.StorageLevel-class.html)) +[Python](api/python/pyspark.html#pyspark.StorageLevel)) to `persist()`. The `cache()` method is a shorthand for using the default storage level, which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The full set of storage levels is: @@ -1070,9 +1224,11 @@ storage levels is: Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. + from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon + out-of-the-box. Please refer to this page + for the suggested version pairings. @@ -1129,6 +1285,12 @@ than shipping a copy of it with tasks. They can be used, for example, to give ev large input dataset in an efficient manner. Spark also attempts to distribute broadcast variables using efficient broadcast algorithms to reduce communication cost. +Spark actions are executed through a set of stages, separated by distributed "shuffle" operations. +Spark automatically broadcasts the common data needed by tasks within each stage. The data +broadcasted this way is cached in serialized form and deserialized before running each task. This +means that explicitly creating broadcast variables is only useful when tasks across multiple stages +need the same data or when caching the data in deserialized form is important. + Broadcast variables are created from a variable `v` by calling `SparkContext.broadcast(v)`. The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the `value` method. The code below shows this: @@ -1177,7 +1339,7 @@ run on the cluster so that `v` is not shipped to the nodes more than once. In ad `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later). -## Accumulators +## Accumulators Accumulators are variables that are only "added" to through an associative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in @@ -1290,7 +1452,7 @@ scala> accum.value {% endhighlight %} While this code used the built-in support for accumulators of type Int, programmers can also -create their own types by subclassing [AccumulatorParam](api/python/pyspark.accumulators.AccumulatorParam-class.html). +create their own types by subclassing [AccumulatorParam](api/python/pyspark.html#pyspark.AccumulatorParam). The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class representing mathematical vectors, we could write: @@ -1322,25 +1484,28 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being
    {% highlight scala %} -val acc = sc.accumulator(0) -data.map(x => acc += x; f(x)) -// Here, acc is still 0 because no actions have cause the `map` to be computed. +val accum = sc.accumulator(0) +data.map { x => accum += x; f(x) } +// Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %}
    {% highlight java %} Accumulator accum = sc.accumulator(0); -data.map(x -> accum.add(x); f(x);); -// Here, accum is still 0 because no actions have cause the `map` to be computed. +data.map(x -> { accum.add(x); return f(x); }); +// Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %}
    {% highlight python %} accum = sc.accumulator(0) -data.map(lambda x => acc.add(x); f(x)) -# Here, acc is still 0 because no actions have cause the `map` to be computed. +def g(x): + accum.add(x) + return f(x) +data.map(g) +# Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %}
    @@ -1352,6 +1517,11 @@ The [application submission guide](submitting-applications.html) describes how t In short, once you package your application into a JAR (for Java/Scala) or a set of `.py` or `.zip` files (for Python), the `bin/spark-submit` script lets you submit it to any supported cluster manager. +# Launching Spark jobs from Java / Scala + +The [org.apache.spark.launcher](api/java/index.html?org/apache/spark/launcher/package-summary.html) +package provides classes for launching Spark jobs as child processes using a simple Java API. + # Unit Testing Spark is friendly to unit testing with any popular unit test framework. @@ -1409,7 +1579,8 @@ You can see some [example Spark programs](http://spark.apache.org/examples.html) In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run Java and Scala examples by passing the class name to Spark's `bin/run-example` script; for instance: ./bin/run-example SparkPi @@ -1418,6 +1589,10 @@ For Python examples, use `spark-submit` instead: ./bin/spark-submit examples/src/main/python/pi.py +For R examples, use `spark-submit` instead: + + ./bin/spark-submit examples/src/main/r/dataframe.R + For help on optimizing your programs, the [configuration](configuration.html) and [tuning](tuning.html) guides provide information on best practices. They are especially important for making sure that your data is stored in memory in an efficient format. @@ -1425,4 +1600,4 @@ For help on deploying, the [cluster mode overview](cluster-overview.html) descri in distributed operation and supported cluster managers. Finally, full API documentation is available in -[Scala](api/scala/#org.apache.spark.package), [Java](api/java/) and [Python](api/python/). +[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). diff --git a/docs/quick-start.md b/docs/quick-start.md index bf643bb70e15..ce2cc9d2169c 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -1,6 +1,7 @@ --- layout: global title: Quick Start +description: Quick start tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -183,10 +184,10 @@ scala> linesWithSpark.cache() res7: spark.RDD[String] = spark.FilteredRDD@17e51082 scala> linesWithSpark.count() -res8: Long = 15 +res8: Long = 19 scala> linesWithSpark.count() -res9: Long = 15 +res9: Long = 19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -201,10 +202,10 @@ a cluster, as described in the [programming guide](programming-guide.html#initia >>> linesWithSpark.cache() >>> linesWithSpark.count() -15 +19 >>> linesWithSpark.count() -15 +19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -405,7 +406,7 @@ logData = sc.textFile(logFile).cache() numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() -print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) {% endhighlight %} @@ -422,14 +423,14 @@ dependencies to `spark-submit` through its `--py-files` argument by packaging th We can run this application using the `bin/spark-submit` script: -{% highlight python %} +{% highlight bash %} # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --master local[4] \ SimpleApp.py ... Lines with a: 46, Lines with b: 23 -{% endhighlight python %} +{% endhighlight %}
    @@ -443,7 +444,8 @@ Congratulations on running your first Spark application! * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run them as follows: {% highlight bash %} @@ -452,4 +454,7 @@ You can run them as follows: # For Python examples, use spark-submit directly: ./bin/spark-submit examples/src/main/python/pi.py + +# For R examples, use spark-submit directly: +./bin/spark-submit examples/src/main/r/dataframe.R {% endhighlight %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 78358499fd01..cfd219ab02e2 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -78,6 +78,9 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. +Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure +`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -107,10 +110,14 @@ the `make-distribution.sh` script included in a Spark source tarball/checkout. The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. -The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: +## Client Mode + +In client mode, a Spark Mesos framework is launched directly on the client machine and waits for the driver output. + +The driver needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: - * `export MESOS_NATIVE_LIBRARY=`. This path is typically + * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically `/lib/libmesos.so` where the prefix is `/usr/local` by default. See Mesos installation instructions above. On Mac OS X, the library is called `libmesos.dylib` instead of `libmesos.so`. @@ -129,8 +136,7 @@ val sc = new SparkContext(conf) {% endhighlight %} (You can also use [`spark-submit`](submitting-applications.html) and configure `spark.executor.uri` -in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file. Note -that `spark-submit` currently only supports deploying the Spark driver in `client` mode for Mesos.) +in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file.) When running a shell, the `spark.executor.uri` parameter is inherited from `SPARK_EXECUTOR_URI`, so it does not need to be redundantly passed in as a system property. @@ -139,6 +145,17 @@ it does not need to be redundantly passed in as a system property. ./bin/spark-shell --master mesos://host:5050 {% endhighlight %} +## Cluster mode + +Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client +can find the results of the driver from the Mesos Web UI. + +To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master url (e.g: mesos://host:5050). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url +to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +Spark cluster Web UI. # Mesos Run Modes @@ -167,8 +184,23 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). -# Known issues -- When using the "fine-grained" mode, make sure that your executors always leave 32 MB free on the slaves. Otherwise it can happen that your Spark job does not proceed anymore. Currently, Apache Mesos only offers resources if there are at least 32 MB memory allocatable. But as Spark allocates memory only for the executor and cpu only for tasks, it can happen on high slave memory usage that no new tasks will be started anymore. More details can be found in [MESOS-1688](https://issues.apache.org/jira/browse/MESOS-1688). Alternatively use the "coarse-gained" mode, which is not affected by this issue. +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + +# Mesos Docker Support + +Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` +in your [SparkConf](configuration.html#spark-properties). + +The Docker image used must have an appropriate version of Spark already part of the image, or you can +have Mesos download Spark via the usual methods. + +Requires Mesos version 0.20.1 or later. # Running Alongside Hadoop @@ -184,6 +216,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop). In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos. +# Dynamic Resource Allocation with Mesos + +Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics +of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down +since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler +can scale back up to the same amount of executors when Spark signals more executors are needed. + +Users that like to utilize this feature should launch the Mesos Shuffle Service that +provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh +scripts accordingly. + +The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos +is to launch the Shuffle Service with Marathon with a unique host constraint. # Configuration @@ -197,7 +243,11 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - Set the run mode for Spark on Mesos. For more information about the run mode, refer to #Mesos Run Mode section above. + If set to "true", runs over Mesos clusters in + "coarse-grained" sharing mode, + where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per + Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use + for the whole duration of the Spark job. @@ -209,21 +259,109 @@ See the [configuration page](configuration.html) for information on Spark config Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + + spark.mesos.mesosExecutor.cores + 1.0 + + (Fine-grained mode only) Number of cores to give each Mesos executor. This does not + include the cores used to run the Spark tasks. In other words, even if no Spark task + is being run, each Mesos executor will occupy the number of cores configured here. + The value can be a floating point number. + + + + spark.mesos.executor.docker.image + (none) + + Set the name of the docker image that the Spark executors will run in. The selected + image must have Spark installed, as well as a compatible version of the Mesos library. + The installed path of Spark in the image can be specified with spark.mesos.executor.home; + the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY. + + + + spark.mesos.executor.docker.volumes + (none) + + Set the list of volumes which will be mounted into the Docker image, which was set using + spark.mesos.executor.docker.image. The format of this property is a comma-separated list of + mappings following the form passed to docker run -v. That is they take the form: + +
    [host_path:]container_path[:ro|:rw]
    + + + + spark.mesos.executor.docker.portmaps + (none) + + Set the list of incoming ports exposed by the Docker image, which was set using + spark.mesos.executor.docker.image. The format of this property is a comma-separated list of + mappings which take the form: + +
    host_port:container_port[:tcp|:udp]
    + + spark.mesos.executor.home - SPARK_HOME + driver side SPARK_HOME - The location where the mesos executor will look for Spark binaries to execute, and uses the SPARK_HOME setting on default. - This variable is only used when no spark.executor.uri is provided, and assumes Spark is installed on the specified location - on each slave. + Set the directory in which Spark is installed on the executors in Mesos. By default, the + executors will simply use the driver's Spark home directory, which may not be visible to + them. Note that this is only relevant if a Spark binary package is not specified through + spark.executor.uri. spark.mesos.executor.memoryOverhead - 384 + executor memory * 0.10, with minimum of 384 + + The amount of additional memory, specified in MB, to be allocated per executor. By default, + the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the final overhead will be this value. + + + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. + This applies to both coarse-grain and fine-grain mode. + + + + spark.mesos.principal + Framework principal to authenticate to Mesos + + Set the principal with which Spark framework will use to authenticate with Mesos. + + + + spark.mesos.secret + Framework secret to authenticate to Mesos + + Set the secret with which Spark framework will use to authenticate with Mesos. + + + + spark.mesos.role + Role for the Spark framework + + Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations + and resource weight sharing. + + + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. - The amount of memory that Mesos executor will request for the task to account for the overhead of running the executor itself. - The final total amount of memory allocated is the maximum value between executor memory plus memoryOverhead, and overhead fraction (1.07) plus the executor memory. + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
      +
    • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
    • +
    • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
    • +
    • Set constraints are matched with "subset of" semantics i.e. value in the constraint must be a subset of the resource offer's value.
    • +
    • Text constraints are metched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
    • +
    • In case there is no value present as a part of the constraint any offer with the corresponding attribute will be accepted (without value check).
    • +
    diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68ab127bcf08..5159ef9e3394 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,6 +7,51 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +To launch a Spark application in `yarn-cluster` mode: + + `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. @@ -17,6 +62,38 @@ To build Spark yourself, refer to [Building Spark](building-spark.html). Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. + +To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. + #### Spark Properties @@ -48,10 +125,10 @@ Most of the configs are the same for Spark on YARN as for other deployment modes - + @@ -71,9 +148,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes - + + + + + + @@ -87,7 +177,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -104,9 +195,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Comma-separated list of files to be placed in the working directory of each executor. + + + + + - + @@ -125,6 +223,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + + + + + @@ -161,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -178,7 +283,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes + + + + + @@ -189,84 +301,91 @@ Most of the configs are the same for Spark on YARN as for other deployment modes It should be no larger than the global number of max attempts in the YARN configuration. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    spark.yarn.am.waitTime100000100s - In yarn-cluster mode, time in milliseconds for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
    spark.yarn.scheduler.heartbeat.interval-ms50003000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. + The value is capped at half the value of YARN's configuration for the expiry interval + (yarn.am.liveness-monitor.expiry-interval-ms). +
    spark.yarn.scheduler.initial-allocation.interval200ms + The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager + when there are pending container allocation requests. It should be no larger than + spark.yarn.scheduler.heartbeat.interval-ms. The allocation interval will doubled on + successive eager heartbeats if pending containers still exist, until + spark.yarn.scheduler.heartbeat.interval-ms is reached.
    spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`.
    spark.executor.instances2 + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. +
    spark.yarn.executor.memoryOverheadexecutorMemory * 0.07, with minimum of 384 executorMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%).
    spark.yarn.am.port(random) + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. +
    spark.yarn.queue default Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
    (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead. +
    spark.yarn.am.extraLibraryPath(none) + Set a special library path to use when launching the application master in client mode.
    spark.yarn.submit.waitAppCompletiontrue + In YARN cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive reporting the application's status. + Otherwise, the client process will exit after submission. +
    spark.yarn.executor.nodeLabelExpression(none) + A YARN node label expression that restricts the set of nodes executors will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. +
    spark.yarn.tags(none) + Comma-separated list of strings to pass through as YARN application tags appearing + in YARN ApplicationReports, which can be used for filtering when querying YARN apps. +
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure HDFS. +
    spark.yarn.config.gatewayPath(none) + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

    + The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

    + For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. +

    spark.yarn.config.replacementPath(none) + See spark.yarn.config.gatewayPath. +
    spark.yarn.security.tokens.${service}.enabledtrue + Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. + By default, delegation tokens for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run. +

    + Currently supported services are: hive, hbase +

    -# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. +- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/security.md b/docs/security.md index 6e0a54fbc4ad..d4ffa60e59a3 100644 --- a/docs/security.md +++ b/docs/security.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Security +displayTitle: Spark Security +title: Security --- Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: @@ -31,6 +32,8 @@ SSL must be configured on each node and configured for each component involved i ### YARN mode The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. +For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. + ### Standalone mode The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 5c6084fb4625..2fe9ec3542b2 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./bin/spark-class org.apache.spark.deploy.worker.Worker spark://IP:PORT + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -77,10 +77,11 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. -Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: +Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`: - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `sbin/start-slave.sh` - Starts a slave instance on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of slaves as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. - `sbin/stop-slaves.sh` - Stops all slave instances on the machines specified in the `conf/slaves` file. @@ -151,7 +152,7 @@ You can optionally configure the cluster further by setting environment variable SPARK_DAEMON_MEMORY - Memory to allocate to the Spark master and worker daemons themselves (default: 512m). + Memory to allocate to the Spark master and worker daemons themselves (default: 1g). SPARK_DAEMON_JAVA_OPTS @@ -222,8 +223,7 @@ SPARK_WORKER_OPTS supports the following system properties: false Enable periodic cleanup of worker / application directories. Note that this only affects standalone - mode, as YARN works differently. Applications directories are cleaned up regardless of whether - the application is still running. + mode, as YARN works differently. Only the directories of stopped applications are cleaned up. diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 000000000000..7139d16b4a06 --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,267 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
    +The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. +You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name +, any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, +which can be created from the SparkContext. If you are working from the SparkR shell, the +`SQLContext` and `SparkContext` should already be created for you. + +{% highlight r %} +sc <- sparkR.init() +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} + +
    + +## Creating DataFrames +With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). + +### From local data frames +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. + +
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +specifying `--packages` with `spark-submit` or `sparkR` commands, or if creating context through `init` +you can specify the packages with the `packages` argument. + +
    +{% highlight r %} +sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3") +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
    + +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- sql(hiveContext, "FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
    + +### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
    + +## Running SQL Queries from SparkR +A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    + +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a linear model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) 2.2513930 +##Sepal_Width 0.8035609 +##Species_versicolor 1.4587432 +##Species_virginica 1.9468169 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index be8c5c2c1522..33e7893d7bd0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark SQL Programming Guide +displayTitle: Spark SQL and DataFrame Guide +title: Spark SQL and DataFrames --- * This will become a table of contents (this text will be scraped). @@ -8,168 +9,437 @@ title: Spark SQL Programming Guide # Overview +Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. + +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. + +# DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. + +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). + +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. + + +## Starting Point: SQLContext +
    -Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using -Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +The entry point into all functionality in Spark SQL is the +[`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`. +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ +{% endhighlight %}
    -
    -Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using -Spark. At the core of this component is a new type of RDD, -[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with -a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table -in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +
    + +The entry point into all functionality in Spark SQL is the +[`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} +
    -Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using -Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of -[Row](api/python/pyspark.sql.Row-class.html) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +The entry point into all relational functionality in Spark is the +[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) +{% endhighlight %} + +
    + +
    + +The entry point into all relational functionality in Spark is the +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} -All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
    -**Spark SQL is currently an alpha component. While we will minimize API changes, some APIs may change in future releases.** +In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a +superset of the functionality provided by the basic `SQLContext`. Additional features include +the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +existing Hive setup, and all of the data sources available to a `SQLContext` are still available. +`HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +to feature parity with a `HiveContext`. + +The specific variant of SQL that is used to parse queries can also be selected using the +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +this is recommended for most use cases. + + +## Creating DataFrames -*************************************************************************************************** +With a `SQLContext`, applications can create `DataFrame`s from an existing `RDD`, from a Hive table, or from data sources. -# Getting Started +As an example, the following creates a `DataFrame` based on the content of a JSON file:
    +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +val df = sqlContext.read.json("examples/src/main/resources/people.json") + +// Displays the content of the DataFrame to stdout +df.show() +{% endhighlight %} + +
    + +
    +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); + +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); + +// Displays the content of the DataFrame to stdout +df.show(); +{% endhighlight %} + +
    + +
    +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + +df = sqlContext.read.json("examples/src/main/resources/people.json") + +# Displays the content of the DataFrame to stdout +df.show() +{% endhighlight %} + +
    + +
    +{% highlight r %} +sqlContext <- SQLContext(sc) + +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Displays the content of the DataFrame to stdout +showDF(df) +{% endhighlight %} + +
    + +
    -The entry point into all relational functionality in Spark is the -[SQLContext](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic SQLContext, all you need is a SparkContext. +## DataFrame Operations + +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +Here we include some basic examples of structured data processing using DataFrames: + +
    +
    {% highlight scala %} val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD -{% endhighlight %} +// Create the DataFrame +val df = sqlContext.read.json("examples/src/main/resources/people.json") -In addition to the basic SQLContext, you can also create a HiveContext, which provides a -superset of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +// Show the content of the DataFrame +df.show() +// age name +// null Michael +// 30 Andy +// 19 Justin + +// Print the schema in a tree format +df.printSchema() +// root +// |-- age: long (nullable = true) +// |-- name: string (nullable = true) + +// Select only the "name" column +df.select("name").show() +// name +// Michael +// Andy +// Justin + +// Select everybody, but increment the age by 1 +df.select(df("name"), df("age") + 1).show() +// name (age + 1) +// Michael null +// Andy 31 +// Justin 20 + +// Select people older than 21 +df.filter(df("age") > 21).show() +// age name +// 30 Andy + +// Count people by age +df.groupBy("age").count().show() +// age count +// null 1 +// 19 1 +// 30 1 +{% endhighlight %}
    +{% highlight java %} +JavaSparkContext sc // An existing SparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) -The entry point into all relational functionality in Spark is the -[JavaSQLContext](api/scala/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one -of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. +// Create the DataFrame +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); -{% highlight java %} -JavaSparkContext sc = ...; // An existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); -{% endhighlight %} +// Show the content of the DataFrame +df.show(); +// age name +// null Michael +// 30 Andy +// 19 Justin -In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict -super set of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +// Print the schema in a tree format +df.printSchema(); +// root +// |-- age: long (nullable = true) +// |-- name: string (nullable = true) + +// Select only the "name" column +df.select("name").show(); +// name +// Michael +// Andy +// Justin + +// Select everybody, but increment the age by 1 +df.select(df.col("name"), df.col("age").plus(1)).show(); +// name (age + 1) +// Michael null +// Andy 31 +// Justin 20 + +// Select people older than 21 +df.filter(df.col("age").gt(21)).show(); +// age name +// 30 Andy + +// Count people by age +df.groupBy("age").count().show(); +// age count +// null 1 +// 19 1 +// 30 1 +{% endhighlight %}
    - -The entry point into all relational functionality in Spark is the -[SQLContext](api/python/pyspark.sql.SQLContext-class.html) class, or one -of its decedents. To create a basic SQLContext, all you need is a SparkContext. +In Python it's possible to access a DataFrame's columns either by attribute +(`df.age`) or by indexing (`df['age']`). While the former is convenient for +interactive data exploration, users are highly encouraged to use the +latter form, which is future proof and won't break with column names that +are also attributes on the DataFrame class. {% highlight python %} from pyspark.sql import SQLContext sqlContext = SQLContext(sc) + +# Create the DataFrame +df = sqlContext.read.json("examples/src/main/resources/people.json") + +# Show the content of the DataFrame +df.show() +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +df.printSchema() +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +df.select("name").show() +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +df.select(df['name'], df['age'] + 1).show() +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +df.filter(df['age'] > 21).show() +## age name +## 30 Andy + +# Count people by age +df.groupBy("age").count().show() +## age count +## null 1 +## 19 1 +## 30 1 + {% endhighlight %} -In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict -super set of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +
    + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) + +# Create the DataFrame +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Show the content of the DataFrame +showDF(df) +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +printSchema(df) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +showDF(select(df, "name")) +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +showDF(select(df, df$name, df$age + 1)) +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +showDF(where(df, df$age > 21)) +## age name +## 30 Andy + +# Count people by age +showDF(count(groupBy(df, "age"))) +## age count +## null 1 +## 19 1 +## 30 1 + +{% endhighlight %}
    -The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, - this is recommended for most use cases. -# Data Sources +## Running SQL Queries Programmatically + +The `sql` function on a `SQLContext` enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +
    +{% highlight scala %} +val sqlContext = ... // An existing SQLContext +val df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
    + +
    +{% highlight java %} +SQLContext sqlContext = ... // An existing SQLContext +DataFrame df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) +df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
    + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +df <- sql(sqlContext, "SELECT * FROM table") +{% endhighlight %} +
    + +
    -Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section -describes the various methods for loading data into a SchemaRDD. -## RDDs +## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. -The second method for creating SchemaRDDs is through a programmatic interface that allows you to +The second method for creating DataFrames is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows -you to construct SchemaRDDs when the columns and their types are not known until runtime. +you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection
    -The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes -to a SchemaRDD. The case class +The Scala interface for Spark SQL supports automatically converting an RDD containing case classes +to a DataFrame. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex -types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be +types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -177,15 +447,22 @@ import sqlContext.createSchemaRDD case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. -val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) +val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)).toDF() people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. -val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +val teenagers = sqlContext.sql("SELECT name, age FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. -// The columns of a row in the result can be accessed by ordinal. +// The results of SQL queries are DataFrames and support all the normal RDD operations. +// The columns of a row in the result can be accessed by field index: teenagers.map(t => "Name: " + t(0)).collect().foreach(println) + +// or by field name: +teenagers.map(t => "Name: " + t.getAs[String]("name")).collect().foreach(println) + +// row.getValuesMap[T] retrieves multiple columns at once into a Map[String, T] +teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(println) +// Map("name" -> "Justin", "age" -> 19) {% endhighlight %}
    @@ -193,7 +470,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -224,12 +501,12 @@ public static class Person implements Serializable { {% endhighlight %} -A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object +A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -246,15 +523,15 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); +DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List teenagerNames = teenagers.map(new Function() { +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -266,7 +543,7 @@ List teenagerNames = teenagers.map(new Function() {
    -Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we @@ -283,17 +560,17 @@ lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) -# Infer the schema, and register the SchemaRDD as a table. -schemaPeople = sqlContext.inferSchema(people) +# Infer the schema, and register the DataFrame as a table. +schemaPeople = sqlContext.createDataFrame(people) schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %}
    @@ -309,12 +586,12 @@ for teenName in teenNames.collect(): When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided by `SQLContext`. For example: @@ -328,8 +605,11 @@ val people = sc.textFile("examples/src/main/resources/people.txt") // The schema is encoded in a string val schemaString = "name age" -// Import Spark SQL data types and Row. -import org.apache.spark.sql._ +// Import Row. +import org.apache.spark.sql.Row; + +// Import Spark SQL data types +import org.apache.spark.sql.types.{StructType,StructField,StringType}; // Generate the schema based on the string of schema val schema = @@ -340,16 +620,16 @@ val schema = val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) // Apply the schema to the RDD. -val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) +val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema) -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people") +// Register the DataFrames as a table. +peopleDataFrame.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val results = sqlContext.sql("SELECT name FROM people") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. -// The columns of a row in the result can be accessed by ordinal. +// The results of SQL queries are DataFrames and support all the normal RDD operations. +// The columns of a row in the result can be accessed by field index or by field name. results.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -361,26 +641,29 @@ results.map(t => "Name: " + t(0)).collect().foreach(println) When JavaBean classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided -by `JavaSQLContext`. +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided +by `SQLContext`. For example: {% highlight java %} -// Import factory methods provided by DataType. -import org.apache.spark.sql.api.java.DataType +import org.apache.spark.api.java.function.Function; +// Import factory methods provided by DataTypes. +import org.apache.spark.sql.types.DataTypes; // Import StructType and StructField -import org.apache.spark.sql.api.java.StructType -import org.apache.spark.sql.api.java.StructField +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.StructField; // Import Row. -import org.apache.spark.sql.api.java.Row +import org.apache.spark.sql.Row; +// Import RowFactory. +import org.apache.spark.sql.RowFactory; // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); @@ -391,31 +674,31 @@ String schemaString = "name age"; // Generate the schema based on the string of schema List fields = new ArrayList(); for (String fieldName: schemaString.split(" ")) { - fields.add(DataType.createStructField(fieldName, DataType.StringType, true)); + fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true)); } -StructType schema = DataType.createStructType(fields); +StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows. JavaRDD rowRDD = people.map( new Function() { public Row call(String record) throws Exception { String[] fields = record.split(","); - return Row.create(fields[0], fields[1].trim()); + return RowFactory.create(fields[0], fields[1].trim()); } }); // Apply the schema to the RDD. -JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); +DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema); -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people"); +// Register the DataFrame as a table. +peopleDataFrame.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); +DataFrame results = sqlContext.sql("SELECT name FROM people"); -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List names = results.map(new Function() { +List names = results.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -430,17 +713,18 @@ List names = results.map(new Function() { When a dictionary of kwargs cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of tuples or lists in the RDD created in the step 1. -3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. +3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`. For example: {% highlight python %} # Import SQLContext and data types -from pyspark.sql import * +from pyspark.sql import SQLContext +from pyspark.sql.types import * # sc is an existing SparkContext. sqlContext = SQLContext(sc) @@ -457,24 +741,189 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt schema = StructType(fields) # Apply the schema to the RDD. -schemaPeople = sqlContext.applySchema(people, schema) +schemaPeople = sqlContext.createDataFrame(people, schema) -# Register the SchemaRDD as a table. +# Register the DataFrame as a table. schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. names = results.map(lambda p: "Name: " + p.name) for name in names.collect(): - print name + print(name) +{% endhighlight %} + +
    + +
    + + +# Data Sources + +Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. +A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a DataFrame as a table allows you to run SQL queries over its data. This section +describes the general methods for loading and saving data using the Spark Data Sources and then +goes into specific options that are available for the built-in data sources. + +## Generic Load/Save Functions + +In the simplest form, the default data source (`parquet` unless otherwise configured by +`spark.sql.sources.default`) will be used for all operations. + +
    +
    + +{% highlight scala %} +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") +{% endhighlight %} + +
    + +
    + +{% highlight java %} + +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); + +{% endhighlight %} + +
    + +
    + +{% highlight python %} + +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") + +{% endhighlight %} + +
    + +
    + +{% highlight r %} +df <- loadDF(sqlContext, "people.parquet") +saveDF(select(df, "name", "age"), "namesAndAges.parquet") +{% endhighlight %} + +
    +
    + +### Manually Specifying Options + +You can also manually specify the data source that will be used along with any extra options +that you would like to pass to the data source. Data sources are specified by their fully qualified +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +using this syntax. + +
    +
    + +{% highlight scala %} +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") +{% endhighlight %} + +
    + +
    + +{% highlight java %} + +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); + {% endhighlight %} +
    + +
    + +{% highlight python %} + +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") + +{% endhighlight %}
    +
    + +{% highlight r %} + +df <- loadDF(sqlContext, "people.json", "json") +saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") + +{% endhighlight %}
    +
    + +### Save Modes + +Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. +Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +new data. + + + + + + + + + + + + + + + + + + + + + + + +
    Scala/JavaPythonMeaning
    SaveMode.ErrorIfExists (default)"error" (default) + When saving a DataFrame to a data source, if data already exists, + an exception is expected to be thrown. +
    SaveMode.Append"append" + When saving a DataFrame to a data source, if data/table already exists, + contents of the DataFrame are expected to be appended to existing data. +
    SaveMode.Overwrite"overwrite" + Overwrite mode means that when saving a DataFrame to a data source, + if data/table already exists, existing data is expected to be overwritten by the contents of + the DataFrame. +
    SaveMode.Ignore"ignore" + Ignore mode means that when saving a DataFrame to a data source, if data already exists, + the save operation is expected to not save the contents of the DataFrame and to not + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. +
    + +### Saving to Persistent Tables + +When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +will still exist even after your Spark program has restarted, as long as you maintain your connection +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +method on a `SQLContext` with the name of the table. + +By default `saveAsTable` will create a "managed table", meaning that the location of the data will +be controlled by the metastore. Managed tables will also have their data deleted automatically +when a table is dropped. ## Parquet Files @@ -492,17 +941,17 @@ Using the data from the above example: {% highlight scala %} // sqlContext from the previous example is used in this example. -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a SchemaRDD. -val parquetFile = sqlContext.parquetFile("people.parquet") +// The result of loading a Parquet file is also a DataFrame. +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -517,19 +966,19 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% highlight java %} // sqlContext from the previous example is used in this example. -JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. +DataFrame schemaPeople = ... // The DataFrame from the previous example. -// JavaSchemaRDDs can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +// DataFrames can be saved as Parquet files, maintaining the schema information. +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); +// The result of loading a parquet file is also a DataFrame. +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); -//Parquet files can also be registered as tables and then used in SQL statements. +// Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); -List teenagerNames = teenagers.map(new Function() { +DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -543,48 +992,334 @@ List teenagerNames = teenagers.map(new Function() { {% highlight python %} # sqlContext from the previous example is used in this example. -schemaPeople # The SchemaRDD from the previous example. +schemaPeople # The DataFrame from the previous example. -# SchemaRDDs can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +# DataFrames can be saved as Parquet files, maintaining the schema information. +schemaPeople.write.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a SchemaRDD. -parquetFile = sqlContext.parquetFile("people.parquet") +# The result of loading a parquet file is also a DataFrame. +parquetFile = sqlContext.read.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %}
    +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +schemaPeople # The DataFrame from the previous example. + +# DataFrames can be saved as Parquet files, maintaining the schema information. +saveAsParquetFile(schemaPeople, "people.parquet") + +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# The result of loading a parquet file is also a DataFrame. +parquetFile <- parquetFile(sqlContext, "people.parquet") + +# Parquet files can also be registered as tables and then used in SQL statements. +registerTempTable(parquetFile, "parquetFile"); +teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) +for (teenName in collect(teenNames)) { + cat(teenName, "\n") +} +{% endhighlight %} +
    -### Configuration +
    -Configuration of Parquet can be done using the `setConf` method on SQLContext or by running -`SET key=value` commands using SQL. +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} - - - - - - - @@ -844,20 +843,22 @@ Some of the common ones are as follows.
    Property NameDefaultMeaning
    spark.sql.parquet.binaryAsStringfalse + + +
    + +{% highlight sql %} + +CREATE TEMPORARY TABLE parquetTable +USING org.apache.spark.sql.parquet +OPTIONS ( + path "examples/src/main/resources/people.parquet" +) + +SELECT * FROM parquetTable + +{% endhighlight %} + +
    + + + +### Partition Discovery + +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +table, data are usually stored in different directories, with partitioning column values encoded in +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For example, we can store all our previously used +population data into a partitioned table using the following directory structure, with two extra +columns, `gender` and `country` as partitioning columns: + +{% highlight text %} + +path +└── to + └── table + ├── gender=male + │   ├── ... + │   │ + │   ├── country=US + │   │   └── data.parquet + │   ├── country=CN + │   │   └── data.parquet + │   └── ... + └── gender=female +    ├── ... +    │ +    ├── country=US +    │   └── data.parquet +    ├── country=CN +    │   └── data.parquet +    └── ... + +{% endhighlight %} + +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: + +{% highlight text %} + +root +|-- name: string (nullable = true) +|-- age: long (nullable = true) +|-- gender: string (nullable = true) +|-- country: string (nullable = true) + +{% endhighlight %} + +Notice that the data types of the partitioning columns are automatically inferred. Currently, +numeric data types and string type are supported. Sometimes users may not want to automatically +infer the data types of the partitioning columns. For these use cases, the automatic type inference +can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to +`true`. When type inference is disabled, string type will be used for the partitioning columns. + + +### Schema Merging + +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +source is now able to automatically detect this case and merge schemas of all these files. + +
    + +
    + +{% highlight scala %} +// sqlContext from the previous example is used in this example. +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ + +// Create a simple DataFrame, stored into a partition directory +val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +df1.write.parquet("data/test_table/key=1") + +// Create another DataFrame in a new partition directory, +// adding a new column and dropping an existing column +val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +df2.write.parquet("data/test_table/key=2") + +// Read the partitioned table +val df3 = sqlContext.read.parquet("data/test_table") +df3.printSchema() + +// The final schema consists of all 3 columns in the Parquet files together +// with the partitioning column appeared in the partition directory paths. +// root +// |-- single: int (nullable = true) +// |-- double: int (nullable = true) +// |-- triple: int (nullable = true) +// |-- key : int (nullable = true) +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ + .map(lambda i: Row(single=i, double=i * 2))) +df1.save("data/test_table/key=1", "parquet") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) + .map(lambda i: Row(single=i, triple=i * 3))) +df2.save("data/test_table/key=2", "parquet") + +# Read the partitioned table +df3 = sqlContext.load("data/test_table", "parquet") +df3.printSchema() + +# The final schema consists of all 3 columns in the Parquet files together +# with the partitioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + +
    + +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") + +# Read the partitioned table +df3 <- loadDF(sqlContext, "data/test_table", "parquet") +printSchema(df3) + +# The final schema consists of all 3 columns in the Parquet files together +# with the partitioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + +
    + +
    + +### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schema must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fileds that only appear in the Hive metastore schema are added as nullable field in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. + +
    + +
    + +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
    + +
    + +### Configuration + +Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running +`SET key=value` commands using SQL. + + + + + + + + + + + + @@ -597,13 +1332,8 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or - - + + @@ -613,19 +1343,45 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or support. + + + + +
    Property NameDefaultMeaning
    spark.sql.parquet.binaryAsStringfalse Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do - not differentiate between binary data and strings when writing out the Parquet schema. This + not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems.
    spark.sql.parquet.int96AsTimestamptrue + Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also + store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. +
    spark.sql.parquet.cacheMetadata true - Turns on caching of Parquet schema metadata. Can speed up querying of static data. + Turns on caching of Parquet schema metadata. Can speed up querying of static data.
    spark.sql.parquet.filterPushdownfalse - Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Paruet 1.6.0rc3 (PARQUET-136). - However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn - this feature on. - trueEnables Parquet filter push-down optimization when set to true.
    spark.sql.hive.convertMetastoreParquet
    spark.sql.parquet.output.committer.classorg.apache.parquet.hadoop.
    ParquetOutputCommitter
    +

    + The output committer class used by Parquet. The specified class needs to be a subclass of + org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a + subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. +

    +

    + Note: +

      +
    • + This option must be set via Hadoop Configuration rather than Spark + SQLConf. +
    • +
    • + This option overrides spark.sql.sources.
      outputCommitterClass
      . +
    • +
    +

    +

    + Spark SQL comes with a builtin + org.apache.spark.sql.
    parquet.DirectParquetOutputCommitter
    , which can be more + efficient then the default Parquet output committer when writing data to S3. +

    +
    ## JSON Datasets
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. -This conversion can be done using one of two methods in a SQLContext: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -636,8 +1392,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a SchemaRDD from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -645,41 +1400,37 @@ people.printSchema() // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this SchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. -This conversion can be done using one of two methods in a JavaSQLContext : +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a JavaSchemaRDD from the file(s) pointed to by path -JavaSchemaRDD people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -687,29 +1438,26 @@ people.printSchema(); // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this JavaSchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); -// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. -This conversion can be done using one of two methods in a SQLContext: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using `SQLContext.read.json` on a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -720,9 +1468,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a SchemaRDD from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -730,13 +1476,13 @@ people.printSchema() # |-- age: integer (nullable = true) # |-- name: string (nullable = true) -# Register this SchemaRDD as a table. +# Register this DataFrame as a table. people.registerTempTable("people") -# SQL statements can be run by using the sql methods provided by sqlContext. +# SQL statements can be run by using the sql methods provided by `sqlContext`. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# Alternatively, a DataFrame can be created for a JSON dataset represented by # an RDD[String] storing one JSON object per string. anotherPeopleRDD = sc.parallelize([ '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) @@ -744,6 +1490,55 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %}
    +
    +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. + +Note that the file that is offered as _a json file_ is not a typical JSON file. Each +line must contain a separate, self-contained valid JSON object. As a consequence, +a regular multi-line JSON file will most often fail. + +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRSQL.init(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path <- "examples/src/main/resources/people.json" +# Create a DataFrame from the file(s) pointed to by path +people <- jsonFile(sqlContext, path) + +# The inferred schema can be visualized using the printSchema() method. +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql methods provided by `sqlContext`. +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +{% endhighlight %} +
    + +
    + +{% highlight sql %} + +CREATE TEMPORARY TABLE jsonTable +USING org.apache.spark.sql.json +OPTIONS ( + path "examples/src/main/resources/people.json" +) + +SELECT * FROM jsonTable + +{% endhighlight %} + +
    +
    ## Hive Tables @@ -755,75 +1550,287 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the +YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the +`spark-submit` command. + + +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the +hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current +directory. + +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc) + +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +// Queries are expressed in HiveQL +sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) +{% endhighlight %} + +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be +expressed in HiveQL. + +{% highlight java %} +// sc is an existing JavaSparkContext. +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc); + +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); + +// Queries are expressed in HiveQL. +Row[] results = sqlContext.sql("FROM src SELECT key, value").collect(); + +{% endhighlight %} + +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. +{% highlight python %} +# sc is an existing SparkContext. +from pyspark.sql import HiveContext +sqlContext = HiveContext(sc) + +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results = sqlContext.sql("FROM src SELECT key, value").collect() + +{% endhighlight %} + +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRHive.init(sc) + +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- collect(sql(sqlContext, "FROM src SELECT key, value")) + +{% endhighlight %} + +
    +
    + +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. + +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.hive.metastore.version0.13.1 + Version of the Hive metastore. Available + options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. +
    spark.sql.hive.metastore.jarsbuiltin + Location of the jars that should be used to instantiate the HiveMetastoreClient. This + property can be one of three options: +
      +
    1. builtin
    2. + Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + enabled. When this option is chosen, spark.sql.hive.metastore.version must be + either 0.13.1 or not defined. +
    3. maven
    4. + Use Hive jars of specified version downloaded from Maven repositories. +
    5. A classpath in the standard format for both Hive and Hadoop.
    6. +
    +
    spark.sql.hive.metastore.sharedPrefixescom.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc
    +

    + A comma separated list of class prefixes that should be loaded using the classloader that is + shared between Spark SQL and a specific version of Hive. An example of classes that should + be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + to be shared are those that interact with classes that are already shared. For example, + custom appenders that are used by log4j. +

    +
    spark.sql.hive.metastore.barrierPrefixes(empty) +

    + A comma separated list of class prefixes that should explicitly be reloaded for each version + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + prefix that typically would be shared (i.e. org.apache.spark.*). +

    +
    + + +## JDBC To Other Databases + +Spark SQL also includes a data source that can read data from other databases using JDBC. This +functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). +This is because the results are returned +as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. +The JDBC data source is also easier to use from Java or Python as it does not require the user to +provide a ClassTag. +(Note that this is different than the Spark SQL JDBC server, which allows other applications to +run queries using Spark SQL). + +To get started you will need to include the JDBC driver for you particular database on the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the +following command: + +{% highlight bash %} +SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell +{% endhighlight %} + +Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using +the Data Sources API. The following options are supported: + + + + + + + + + + + + + + + + + + + + +
    Property NameMeaning
    url + The JDBC URL to connect to. +
    dbtable + The JDBC table that should be read. Note that anything that is valid in a FROM clause of + a SQL query can be used. For example, instead of a full table you could also use a + subquery in parentheses. +
    driver + The class name of the JDBC driver needed to connect to this URL. This class will be loaded + on the master and workers before running an JDBC commands to allow the driver to + register itself with the JDBC subsystem. +
    partitionColumn, lowerBound, upperBound, numPartitions + These options must all be specified if any of them is specified. They describe how to + partition the table when reading in parallel from multiple workers. + partitionColumn must be a numeric column from the table in question. Notice + that lowerBound and upperBound are just used to decide the + partition stride, not for filtering the rows in table. So all rows in the table will be + partitioned and returned. +
    -When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a HiveContext. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current -directory. - {% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc) - -sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") - -// Queries are expressed in HiveQL -sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) +val jdbcDF = sqlContext.read.format("jdbc").options( + Map("url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")).load() {% endhighlight %}
    -When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be -expressed in HiveQL. - {% highlight java %} -// sc is an existing JavaSparkContext. -JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); - -sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); -// Queries are expressed in HiveQL. -Row[] results = sqlContext.sql("FROM src SELECT key, value").collect(); +Map options = new HashMap(); +options.put("url", "jdbc:postgresql:dbserver"); +options.put("dbtable", "schema.tablename"); +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} +
    -When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be -expressed in HiveQL. - {% highlight python %} -# sc is an existing SparkContext. -from pyspark.sql import HiveContext -sqlContext = HiveContext(sc) -sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() -# Queries can be expressed in HiveQL. -results = sqlContext.sql("FROM src SELECT key, value").collect() +{% endhighlight %} + +
    + +
    + +{% highlight r %} + +df <- loadDF(sqlContext, source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") + +{% endhighlight %} + +
    + +
    + +{% highlight sql %} + +CREATE TEMPORARY TABLE jdbcTable +USING org.apache.spark.sql.jdbc +OPTIONS ( + url "jdbc:postgresql:dbserver", + dbtable "schema.tablename" +) {% endhighlight %}
    +## Troubleshooting + + * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. + * Some databases, such as H2, convert all names to upper case. You'll need to use upper case to refer to those names in Spark SQL. + + # Performance Tuning For some workloads it is possible to improve performance by either caching data in memory, or by @@ -831,11 +1838,11 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. -Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +Configuration of in-memory caching can be done using the `setConf` method on `SQLContext` or by running `SET key=value` commands using SQL. @@ -873,16 +1880,15 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. - - + + @@ -892,17 +1898,25 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. + + + + +
    spark.sql.codegenfalsespark.sql.tungsten.enabledtrue - When true, code will be dynamically generated at runtime for expression evaluation in a specific - query. For some queries with complicated expression this option can lead to significant speed-ups. - However, for simple queries this can actually slow down query execution. + When true, use the optimized Tungsten physical execution backend which explicitly manages memory + and dynamically generates bytecode for expression evaluation.
    spark.sql.planner.externalSorttrue + When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. +
    -# Other SQL Interfaces +# Distributed SQL Engine -Spark SQL also supports interfaces for running SQL queries directly without the need to write any -code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. +in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -911,7 +1925,7 @@ To start the JDBC/ODBC server, run the following in the Spark directory: This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. By default, the server listens on localhost:10000. You may override this -bahaviour via either environment variables, i.e.: +behaviour via either environment variables, i.e.: {% highlight bash %} export HIVE_SERVER2_THRIFT_PORT= @@ -947,10 +1961,10 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script that comes with Hive. -Thrift JDBC server also supports sending thrift RPC messages over HTTP transport. -Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: +Thrift JDBC server also supports sending thrift RPC messages over HTTP transport. +Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: - hive.server2.transport.mode - Set this to value: http + hive.server2.transport.mode - Set this to value: http hive.server2.thrift.http.port - HTTP port number fo listen on; default is 10001 hive.server2.http.endpoint - HTTP endpoint; default is cliservice @@ -972,7 +1986,165 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. -# Compatibility with Other Systems +# Migration Guide + +## Upgrading from Spark SQL 1.3 to 1.4 + +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) more information. + + +#### DataFrame.groupBy retains grouping columns + +Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`. + +
    +
    +{% highlight scala %} + +// In 1.3.x, in order for the grouping column "department" to show up, +// it must be included explicitly as part of the agg function call. +df.groupBy("department").agg($"department", max("age"), sum("expense")) + +// In 1.4+, grouping column "department" is included automatically. +df.groupBy("department").agg(max("age"), sum("expense")) + +// Revert to 1.3 behavior (not retaining grouping column) by: +sqlContext.setConf("spark.sql.retainGroupColumns", "false") + +{% endhighlight %} +
    + +
    +{% highlight java %} + +// In 1.3.x, in order for the grouping column "department" to show up, +// it must be included explicitly as part of the agg function call. +df.groupBy("department").agg(col("department"), max("age"), sum("expense")); + +// In 1.4+, grouping column "department" is included automatically. +df.groupBy("department").agg(max("age"), sum("expense")); + +// Revert to 1.3 behavior (not retaining grouping column) by: +sqlContext.setConf("spark.sql.retainGroupColumns", "false"); + +{% endhighlight %} +
    + +
    +{% highlight python %} + +import pyspark.sql.functions as func + +# In 1.3.x, in order for the grouping column "department" to show up, +# it must be included explicitly as part of the agg function call. +df.groupBy("department").agg("department"), func.max("age"), func.sum("expense")) + +# In 1.4+, grouping column "department" is included automatically. +df.groupBy("department").agg(func.max("age"), func.sum("expense")) + +# Revert to 1.3.x behavior (not retaining grouping column) by: +sqlContext.setConf("spark.sql.retainGroupColumns", "false") + +{% endhighlight %} +
    + +
    + + +## Upgrading from Spark SQL 1.0-1.2 to 1.3 + +In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +as unstable (i.e., DeveloperAPI or Experimental). + +#### Rename of SchemaRDD to DataFrame + +The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +directly, but instead provide most of the functionality that RDDs provide though their own +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. + +In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +some use cases. It is still recommended that users update their code to use `DataFrame` instead. +Java and Python users will need to update their code. + +#### Unification of the Java and Scala APIs + +Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +use types that are usable from both languages (i.e. `Array` instead of language specific collections). +In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading +is used instead. + +Additionally the Java specific types API has been removed. Users of both Scala and Java should +use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. + + +#### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) + +Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. +Users should now write `import sqlContext.implicits._`. + +Additionally, the implicit conversions now only augment RDDs that are composed of `Product`s (i.e., +case classes or tuples) with a method `toDF`, instead of applying automatically. + +When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`import org.apache.spark.sql.functions._`. + +#### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) + +Spark 1.3 removes the type aliases that were present in the base sql package for `DataType`. Users +should instead import the classes in `org.apache.spark.sql.types` + +#### UDF Registration Moved to `sqlContext.udf` (Java & Scala) + +Functions that are used to register UDFs, either for use in the DataFrame DSL or SQL, have been +moved into the udf object in `SQLContext`. + +
    +
    +{% highlight scala %} + +sqlContext.udf.register("strLen", (s: String) => s.length()) + +{% endhighlight %} +
    + +
    +{% highlight java %} + +sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType); + +{% endhighlight %} +
    + +
    + +Python UDF registration is unchanged. + +#### Python DataTypes No Longer Singletons + +When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of +referencing a singleton. ## Migration Guide for Shark User @@ -1092,15 +2264,11 @@ in Hive deployments. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. + **Esoteric Hive Features** -* Tables with partitions using different input formats: In Spark SQL, all table partitions need to - have the same input format. -* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions - (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. -* `UNION` type and `DATE` type +* `UNION` type * Unique join -* Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at the moment and only supports populating the sizeInBytes field of the hive metastore. @@ -1116,9 +2284,6 @@ less important due to Spark SQL's in-memory computational model. Others are slot releases of Spark SQL. * Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically convert a join to map join: For joining a large table with multiple small tables, - Hive automatically converts the join into a map join. We are adding this auto conversion in the - next release. * Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". * Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still @@ -1129,33 +2294,10 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. -# Writing Language-Integrated Relational Queries - -**Language-Integrated queries are experimental and currently only supported in Scala.** - -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: - -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. - -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +# Data Types - - -# Spark SQL DataType Reference +Spark SQL and DataFrames support the following data types: * Numeric types - `ByteType`: Represents 1-byte signed integer numbers. @@ -1198,10 +2340,10 @@ evaluated by the SQL execution engine. A full list of the functions supported c
    -All data types of Spark SQL are located in the package `org.apache.spark.sql`. +All data types of Spark SQL are located in the package `org.apache.spark.sql.types`. You can access them by doing {% highlight scala %} -import org.apache.spark.sql._ +import org.apache.spark.sql.types._ {% endhighlight %} @@ -1253,7 +2395,7 @@ import org.apache.spark.sql._ - + @@ -1447,7 +2589,7 @@ please use factory methods provided in - +
    DecimalType scala.math.BigDecimal java.math.BigDecimal DecimalType
    StructType org.apache.spark.sql.api.java.Row org.apache.spark.sql.Row DataTypes.createStructType(fields)
    Note: fields is a List or an array of StructFields. @@ -1468,10 +2610,10 @@ please use factory methods provided in
    -All data types of Spark SQL are located in the package of `pyspark.sql`. +All data types of Spark SQL are located in the package of `pyspark.sql.types`. You can access them by doing {% highlight python %} -from pyspark.sql import * +from pyspark.sql.types import * {% endhighlight %} @@ -1618,5 +2760,151 @@ from pyspark.sql import * +
    + +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Data typeValue type in RAPI to access or create a data type
    ByteType + integer
    + Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
    + "byte" +
    ShortType + integer
    + Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
    + "short" +
    IntegerType integer + "integer" +
    LongType + integer
    + Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
    + "long" +
    FloatType + numeric
    + Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
    + "float" +
    DoubleType numeric + "double" +
    DecimalType Not supported + Not supported +
    StringType character + "string" +
    BinaryType raw + "binary" +
    BooleanType logical + "bool" +
    TimestampType POSIXct + "timestamp" +
    DateType Date + "date" +
    ArrayType vector or list + list(type="array", elementType=elementType, containsNull=[containsNull])
    + Note: The default value of containsNull is True. +
    MapType environment + list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
    + Note: The default value of valueContainsNull is True. +
    StructType named list + list(type="struct", fields=fields)
    + Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
    StructField The value type in R of the data type of this field + (For example, integer for a StructField with the data type IntegerType) + list(name=name, type=dataType, nullable=nullable) +
    + +
    + diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 6a2048121f8b..a75587a92adc 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). +the ones for which it has built-in support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. Note that custom receivers can be implemented @@ -21,15 +21,15 @@ A custom receiver must extend this abstract class by implementing two methods - `onStop()`: Things to do to stop receiving data. Both `onStart()` and `onStop()` must not block indefinitely. Typically, `onStart()` would start the threads -that responsible for receiving the data and `onStop()` would ensure that the receiving by those threads +that are responsible for receiving the data, and `onStop()` would ensure that these threads receiving the data are stopped. The receiving threads can also use `isStopped()`, a `Receiver` method, to check whether they should stop receiving data. Once the data is received, that data can be stored inside Spark by calling `store(data)`, which is a method provided by the Receiver class. -There are number of flavours of `store()` which allow you store the received data -record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavour of -`store()` used to implemented a receiver affects its reliability and fault-tolerance semantics. +There are a number of flavors of `store()` which allow one to store the received data +record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavor of +`store()` used to implement a receiver affects its reliability and fault-tolerance semantics. This is discussed [later](#receiver-reliability) in more detail. Any exception in the receiving threads should be caught and handled properly to avoid silent @@ -60,7 +60,7 @@ class CustomReceiver(host: String, port: Int) def onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -123,7 +123,7 @@ public class JavaCustomReceiver extends Receiver { public void onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -167,7 +167,7 @@ public class JavaCustomReceiver extends Receiver { The custom receiver can be used in a Spark Streaming application by using `streamingContext.receiverStream()`. This will create -input DStream using data received by the instance of custom receiver, as shown below +an input DStream using data received by the instance of custom receiver, as shown below:
    @@ -206,22 +206,20 @@ there are two kinds of receivers based on their reliability and fault-tolerance and stored in Spark reliably (that is, replicated successfully). Usually, implementing this receiver involves careful consideration of the semantics of source acknowledgements. -1. *Unreliable Receiver* - These are receivers for unreliable sources that do not support - acknowledging. Even for reliable sources, one may implement an unreliable receiver that - do not go into the complexity of acknowledging correctly. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgement to a source. This can be used for sources that do not support acknowledgement, or even for reliable sources when one does not want or need to go into the complexity of acknowledgement. To implement a *reliable receiver*, you have to use `store(multiple-records)` to store data. -This flavour of `store` is a blocking call which returns only after all the given records have +This flavor of `store` is a blocking call which returns only after all the given records have been stored inside Spark. If the receiver's configured storage level uses replication (enabled by default), then this call returns after replication has completed. Thus it ensures that the data is reliably stored, and the receiver can now acknowledge the -source appropriately. This ensures that no data is caused when the receiver fails in the middle +source appropriately. This ensures that no data is lost when the receiver fails in the middle of replicating data -- the buffered data will not be acknowledged and hence will be later resent by the source. An *unreliable receiver* does not have to implement any of this logic. It can simply receive records from the source and insert them one-at-a-time using `store(single-record)`. While it does -not get the reliability guarantees of `store(multiple-records)`, it has the following advantages. +not get the reliability guarantees of `store(multiple-records)`, it has the following advantages: - The system takes care of chunking that data into appropriate sized blocks (look for block interval in the [Spark Streaming Programming Guide](streaming-programming-guide.html)). diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index ac01dd3d8019..de0461010dae 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +Python API Flume is not yet available in the Python API. + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -56,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java).
    +
    + from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
    Note that the hostname should be the same as the one used by the resource manager in the @@ -64,7 +75,7 @@ configuring Flume agents. 3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). -## Approach 2 (Experimental): Pull-based Approach using a Custom Sink +## Approach 2: Pull-based Approach using a Custom Sink Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. - Flume pushes data into the sink, and the data stays buffered. @@ -97,6 +108,12 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + + groupId = org.apache.commons + artifactId = commons-lang3 + version = 3.3.2 + 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. agent.sinks = spark @@ -127,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
    + from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
    See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 77c0abbbacbd..7571e22575ef 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,58 +2,193 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +## Approach 1: Receiver-based Approach +This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. + +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} -2. **Programming:** In the streaming application code, import `KafkaUtils` and create input DStream as follows. + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows.
    import org.apache.spark.streaming.kafka._ - val kafkaStream = KafkaUtils.createStream( - streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]) + val kafkaStream = KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; - JavaPairReceiverInputDStream kafkaStream = KafkaUtils.createStream( - streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); + JavaPairReceiverInputDStream kafkaStream = + KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); + + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). - See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). +
    +
    + from pyspark.streaming.kafka import KafkaUtils + + kafkaStream = KafkaUtils.createStream(streamingContext, \ + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/kafka_wordcount.py).
    - *Points to remember:* + **Points to remember:** - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. -3. **Deploying:** Package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - -Note that the Kafka receiver used by default is an -[*unreliable* receiver](streaming-programming-guide.html#receiver-reliability) section in the -programming guide). In Spark 1.2, we have added an experimental *reliable* Kafka receiver that -provides stronger -[fault-tolerance guarantees](streaming-programming-guide.html#fault-tolerance-semantics) of zero -data loss on failures. This receiver is automatically used when the write ahead log -(also introduced in Spark 1.2) is enabled -(see [Deployment](#deploying-applications.html) section in the programming guide). This -may reduce the receiving throughput of individual Kafka receivers compared to the unreliable -receivers, but this can be corrected by running -[more receivers in parallel](streaming-programming-guide.html#level-of-parallelism-in-data-receiving) -to increase aggregate throughput. Additionally, it is recommended that the replication of the -received data within Spark be disabled when the write ahead log is enabled as the log is already stored -in a replicated storage system. This can be done by setting the storage level for the input -stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). + +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + +## Approach 2: Direct Approach (No Receivers) +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. + +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). + +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. + +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. + +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). + +Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
    +
    + import org.apache.spark.streaming.kafka._ + + val directKafkaStream = KafkaUtils.createDirectStream[ + [key class], [value class], [key decoder class], [value decoder class] ]( + streamingContext, [map of Kafka parameters], [set of topics to consume]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). +
    +
    + import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream directKafkaStream = + KafkaUtils.createDirectStream(streamingContext, + [key class], [value class], [key decoder class], [value decoder class], + [map of Kafka parameters], [set of topics to consume]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + +
    +
    + from pyspark.streaming.kafka import KafkaUtils + directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py). +
    +
    + + In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. + By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. + + You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. + +
    +
    + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... + } +
    +
    + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference<>(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; + } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; + } + } + ); +
    +
    + Not supported yet +
    +
    + + You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. + + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. + +3. **Deploying:** This is same as the first approach, for Scala, Java and Python. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 379eb513d521..a7bcaec6fcd8 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -32,7 +32,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream val kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]) + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. @@ -44,29 +45,45 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]); + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + +
    + from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + + kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) + + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. +
    - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream - - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from - - The application name used in the streaming context becomes the Kinesis application name + - `[Kineiss app name]`: The application name that will be used to checkpoint the Kinesis + sequence numbers in DynamoDB table. - The application name must be unique for a given account and region. - - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization. - - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table. + - If the table exists but has incorrect checkpoint information (for a different stream, or + old expired sequenced numbers), then there may be temporary errors. + - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region). + - `[region name]`: Valid Kinesis region names can be found [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html). + - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + In other versions of the API, you can also specify the AWS access key and secret key directly. 3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). @@ -122,12 +139,20 @@ To run the example,
    - bin/run-example streaming.KinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    - bin/run-example streaming.JavaKinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + +
    +
    + + bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name]
    @@ -136,7 +161,7 @@ To run the example, - To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer. - bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10 + bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10 This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e37a2bb37b9a..118ced298f4b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Streaming Programming Guide +displayTitle: Spark Streaming Programming Guide +title: Spark Streaming +description: Spark Streaming programming guide and tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -9,7 +11,7 @@ title: Spark Streaming Programming Guide # Overview Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ, Kinesis or TCP sockets can be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's @@ -50,7 +52,7 @@ different languages. **Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream transformations and almost all the output operations available in Scala and Java interfaces. -However, it has only support for basic sources like text files and text data over sockets. +However, it only has support for basic sources like text files and text data over sockets. APIs for additional sources, like Kafka and Flume, will be available in the future. Further information about available features in the Python API are mentioned throughout this document; look out for the tag @@ -67,15 +69,15 @@ do is as follows.
    -First, we import the names of the Spark Streaming classes, and some implicit -conversions from StreamingContext into our environment, to add useful methods to +First, we import the names of the Spark Streaming classes and some implicit +conversions from StreamingContext into our environment in order to add useful methods to other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the -main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. +main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second. {% highlight scala %} import org.apache.spark._ import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. // The master requires 2 cores to prevent from a starvation scenario. @@ -94,7 +96,7 @@ val lines = ssc.socketTextStream("localhost", 9999) This `lines` DStream represents the stream of data that will be received from the data server. Each record in this DStream is a line of text. Next, we want to split the lines by -space into words. +space characters into words. {% highlight scala %} // Split each line into words @@ -107,7 +109,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Next, we want to count these words. {% highlight scala %} -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Count each word in each batch val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) @@ -187,15 +189,15 @@ Next, we want to count these words. {% highlight java %} // Count each word in each batch -JavaPairDStream pairs = words.map( +JavaPairDStream pairs = words.mapToPair( new PairFunction() { - @Override public Tuple2 call(String s) throws Exception { + @Override public Tuple2 call(String s) { return new Tuple2(s, 1); } }); JavaPairDStream wordCounts = pairs.reduceByKey( new Function2() { - @Override public Integer call(Integer i1, Integer i2) throws Exception { + @Override public Integer call(Integer i1, Integer i2) { return i1 + i2; } }); @@ -430,7 +432,7 @@ some of the common ones are as follows.
    For an up-to-date list, please refer to the -[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) +[Maven repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) for the full list of supported sources and artifacts. *** @@ -461,7 +463,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `StreamingContext` object can also be created from an existing `SparkContext` object. @@ -496,7 +498,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`. @@ -529,7 +531,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details.
    @@ -547,7 +549,7 @@ After a context is defined, you have to do the following. - Once a context has been started, no new streaming computations can be set up or added to it. - Once a context has been stopped, it cannot be restarted. - Only one StreamingContext can be active in a JVM at the same time. -- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false. +- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set the optional parameter of `stop()` called `stopSparkContext` to false. - A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created. *** @@ -581,7 +583,7 @@ the `flatMap` operation is applied on each RDD in the `lines` DStream to generat These underlying RDD transformations are computed by the Spark engine. The DStream operations -hide most of these details and provide the developer with higher-level API for convenience. +hide most of these details and provide the developer with a higher-level API for convenience. These operations are discussed in detail in later sections. *** @@ -598,7 +600,7 @@ data from a source and stores it in Spark's memory for processing. Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. - Example: file systems, socket connections, and Akka actors. + Examples: file systems, socket connections, and Akka actors. - *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -608,11 +610,11 @@ We are going to discuss some of the sources present in each category later in th Note that, if you want to receive multiple streams of data in parallel in your streaming application, you can create multiple input DStreams (discussed further in the [Performance Tuning](#level-of-parallelism-in-data-receiving) section). This will -create multiple receivers which will simultaneously receive multiple data streams. But note that -Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the -Spark Streaming application. Hence, it is important to remember that Spark Streaming application +create multiple receivers which will simultaneously receive multiple data streams. But note that a +Spark worker/executor is a long-running task, hence it occupies one of the cores allocated to the +Spark Streaming application. Therefore, it is important to remember that a Spark Streaming application needs to be allocated enough cores (or threads, if running locally) to process the received data, -as well as, to run the receiver(s). +as well as to run the receiver(s). ##### Points to remember {:.no_toc} @@ -621,13 +623,13 @@ as well as, to run the receiver(s). Either of these means that only one thread will be used for running tasks locally. If you are using a input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will be used to run the receiver, leaving no thread for processing the received data. Hence, when - running locally, always use "local[*n*]" as the master URL where *n* > number of receivers to run - (see [Spark Properties](configuration.html#spark-properties.html) for information on how to set + running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run + (see [Spark Properties](configuration.html#spark-properties) for information on how to set the master). - Extending the logic to running on a cluster, the number of cores allocated to the Spark Streaming - application must be more than the number of receivers. Otherwise the system will receive data, but - not be able to process them. + application must be more than the number of receivers. Otherwise the system will receive data, but + not be able to process it. ### Basic Sources {:.no_toc} @@ -637,7 +639,7 @@ which creates a DStream from text data received over a TCP socket connection. Besides sockets, the StreamingContext API provides methods for creating DStreams from files and Akka actors as input sources. -- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as +- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as:
    @@ -660,8 +662,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. - Python API As of Spark 1.2, - `fileStream` is not available in the Python API, only `textFileStream` is available. + Python API `fileStream` is not available in the Python API, only `textFileStream` is available. - **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver @@ -680,14 +681,15 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea ### Advanced Sources {:.no_toc} -Python API As of Spark 1.2, -these sources are not available in the Python API. + +Python API As of Spark {{site.SPARK_VERSION_SHORT}}, +out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts -of dependencies, the functionality to create DStreams from these sources have been moved to separate -libraries, that can be [linked](#linking) to explicitly when necessary. For example, if you want to -create a DStream using data from Twitter's stream of tweets, you have to do the following. +of dependencies, the functionality to create DStreams from these sources has been moved to separate +libraries that can be [linked](#linking) to explicitly when necessary. For example, if you want to +create a DStream using data from Twitter's stream of tweets, you have to do the following: 1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the SBT/Maven project dependencies. @@ -702,7 +704,7 @@ create a DStream using data from Twitter's stream of tweets, you have to do the {% highlight scala %} import org.apache.spark.streaming.twitter._ -TwitterUtils.createStream(ssc) +TwitterUtils.createStream(ssc, None) {% endhighlight %}
    @@ -717,10 +719,16 @@ TwitterUtils.createStream(jssc); Note that these advanced sources are not available in the Spark shell, hence applications based on these advanced sources cannot be tested in the shell. If you really want to use them in the Spark shell you will have to download the corresponding Maven artifact's JAR along with its dependencies -and it in the classpath. +and add it to the classpath. Some of these advanced sources are as follows. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. + +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. + +- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. + - **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by @@ -730,19 +738,12 @@ Some of these advanced sources are as follows. ([TwitterPopularTags]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala) and [TwitterAlgebirdCMS]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala)). -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can received data from Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. - -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can receive data from Kafka 0.8.0. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - -- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. - ### Custom Sources {:.no_toc} -Python API As of Spark 1.2, -these sources are not available in the Python API. +Python API This is not yet supported in Python. -Input DStreams can also be created out of custom data sources. All you have to do is implement an +Input DStreams can also be created out of custom data sources. All you have to do is implement a user-defined **receiver** (see next section to understand what that is) that can receive data from the custom sources and push it into Spark. See the [Custom Receiver Guide](streaming-custom-receivers.html) for details. @@ -752,14 +753,12 @@ Guide](streaming-custom-receivers.html) for details. There can be two kinds of data sources based on their *reliability*. Sources (like Kafka and Flume) allow the transferred data to be acknowledged. If the system receiving -data from these *reliable* sources acknowledge the received data correctly, it can be ensured -that no data gets lost due to any kind of failure. This leads to two kinds of receivers. +data from these *reliable* sources acknowledges the received data correctly, it can be ensured +that no data will be lost due to any kind of failure. This leads to two kinds of receivers: -1. *Reliable Receiver* - A *reliable receiver* correctly acknowledges a reliable - source that the data has been received and stored in Spark with replication. -1. *Unreliable Receiver* - These are receivers for sources that do not support acknowledging. Even - for reliable sources, one may implement an unreliable receiver that do not go into the complexity - of acknowledging correctly. +1. *Reliable Receiver* - A *reliable receiver* correctly sends acknowledgment to a reliable + source when the data has been received and stored in Spark with replication. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgment to a source. This can be used for sources that do not support acknowledgment, or even for reliable sources when one does not want or need to go into the complexity of acknowledgment. The details of how to write a reliable receiver are discussed in the [Custom Receiver Guide](streaming-custom-receivers.html). @@ -827,7 +826,7 @@ Some of the common ones are as follows.
    cogroup(otherStream, [numTasks]) When called on DStream of (K, V) and (K, W) pairs, return a new DStream of + When called on a DStream of (K, V) and (K, W) pairs, return a new DStream of (K, Seq[V], Seq[W]) tuples.
    -The last two transformations are worth highlighting again. +A few of these transformations are worth discussing in more detail. #### UpdateStateByKey Operation {:.no_toc} The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating it with new information. To use this, you will have to do two steps. -1. Define the state - The state can be of arbitrary data type. +1. Define the state - The state can be an arbitrary data type. 1. Define the state update function - Specify with a function how to update the state using the -previous state and the new values from input stream. +previous state and the new values from an input stream. + +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We -define the update function as +define the update function as:
    @@ -876,6 +877,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi val runningCounts = pairs.updateStateByKey[Int](updateFunction _) {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Scala code, take a look at the example +[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache +/spark/examples/streaming/StatefulNetworkWordCount.scala). +
    @@ -897,6 +904,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Java code, take a look at the example +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming +/JavaStatefulNetworkWordCount.java). +
    @@ -914,14 +927,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi runningCounts = pairs.updateStateByKey(updateFunction) {% endhighlight %} -
    -
    - The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete -Scala code, take a look at the example +Python code, take a look at the example [stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +
    + + Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is discussed in detail in the [checkpointing](#checkpointing) section. @@ -934,7 +947,7 @@ operation that is not exposed in the DStream API. For example, the functionality of joining every batch in a data stream with another dataset is not directly exposed in the DStream API. However, you can easily use `transform` to do this. This enables very powerful possibilities. For example, -if you want to do real-time data cleaning by joining the input data stream with precomputed +one can do real-time data cleaning by joining the input data stream with precomputed spam information (maybe generated with Spark as well) and then filtering based on it.
    @@ -978,13 +991,14 @@ cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(.
    -In fact, you can also use [machine learning](mllib-guide.html) and -[graph computation](graphx-programming-guide.html) algorithms in the `transform` method. +Note that the supplied function gets called in every batch interval. This allows you to do +time-varying RDD operations, that is, RDD operations, number of partitions, broadcast variables, +etc. can be changed between batches. #### Window Operations {:.no_toc} -Finally, Spark Streaming also provides *windowed computations*, which allow you to apply -transformations over a sliding window of data. This following figure illustrates this sliding +Spark Streaming also provides *windowed computations*, which allow you to apply +transformations over a sliding window of data. The following figure illustrates this sliding window.

    @@ -996,11 +1010,11 @@ window. As shown in the figure, every time the window *slides* over a source DStream, the source RDDs that fall within the window are combined and operated upon to produce the -RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time +RDDs of the windowed DStream. In this specific case, the operation is applied over the last 3 time units of data, and slides by 2 time units. This shows that any window operation needs to specify two parameters. - * window length - The duration of the window (3 in the figure) + * window length - The duration of the window (3 in the figure). * sliding interval - The interval at which the window operation is performed (2 in the figure). @@ -1008,7 +1022,7 @@ These two parameters must be multiples of the batch interval of the source DStre figure). Let's illustrate the window operations with an example. Say, you want to extend the -[earlier example](#a-quick-example) by generating word counts over last 30 seconds of data, +[earlier example](#a-quick-example) by generating word counts over the last 30 seconds of data, every 10 seconds. To do this, we have to apply the `reduceByKey` operation on the `pairs` DStream of `(word, 1)` pairs over the last 30 seconds of data. This is done using the operation `reduceByKeyAndWindow`. @@ -1027,7 +1041,7 @@ val windowedWordCounts = pairs.reduceByKeyAndWindow((a:Int,b:Int) => (a + b), Se {% highlight java %} // Reduce function adding two integers, defined separately for clarity Function2 reduceFunc = new Function2() { - @Override public Integer call(Integer i1, Integer i2) throws Exception { + @Override public Integer call(Integer i1, Integer i2) { return i1 + i2; } }; @@ -1083,13 +1097,13 @@ said two parameters - windowLength and slideInterval. reduceByKeyAndWindow(func, invFunc, windowLength, slideInterval, [numTasks]) - A more efficient version of the above reduceByKeyAndWindow() where the reduce + A more efficient version of the above reduceByKeyAndWindow() where the reduce value of each window is calculated incrementally using the reduce values of the previous window. - This is done by reducing the new data that enter the sliding window, and "inverse reducing" the - old data that leave the window. An example would be that of "adding" and "subtracting" counts - of keys as the window slides. However, it is applicable to only "invertible reduce functions", + This is done by reducing the new data that enters the sliding window, and "inverse reducing" the + old data that leaves the window. An example would be that of "adding" and "subtracting" counts + of keys as the window slides. However, it is applicable only to "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as - parameter invFunc. Like in reduceByKeyAndWindow, the number of reduce tasks + parameter invFunc). Like in reduceByKeyAndWindow, the number of reduce tasks is configurable through an optional argument. Note that [checkpointing](#checkpointing) must be enabled for using this operation. @@ -1106,6 +1120,100 @@ said two parameters - windowLength and slideInterval. +#### Join Operations +{:.no_toc} +Finally, its worth highlighting how easily you can perform different kinds of joins in Spark Streaming. + + +##### Stream-stream joins +{:.no_toc} +Streams can be very easily joined with other streams. + +

    +
    +{% highlight scala %} +val stream1: DStream[String, String] = ... +val stream2: DStream[String, String] = ... +val joinedStream = stream1.join(stream2) +{% endhighlight %} +
    +
    +{% highlight java %} +JavaPairDStream stream1 = ... +JavaPairDStream stream2 = ... +JavaPairDStream> joinedStream = stream1.join(stream2); +{% endhighlight %} +
    +
    +{% highlight python %} +stream1 = ... +stream2 = ... +joinedStream = stream1.join(stream2) +{% endhighlight %} +
    +
    +Here, in each batch interval, the RDD generated by `stream1` will be joined with the RDD generated by `stream2`. You can also do `leftOuterJoin`, `rightOuterJoin`, `fullOuterJoin`. Furthermore, it is often very useful to do joins over windows of the streams. That is pretty easy as well. + +
    +
    +{% highlight scala %} +val windowedStream1 = stream1.window(Seconds(20)) +val windowedStream2 = stream2.window(Minutes(1)) +val joinedStream = windowedStream1.join(windowedStream2) +{% endhighlight %} +
    +
    +{% highlight java %} +JavaPairDStream windowedStream1 = stream1.window(Durations.seconds(20)); +JavaPairDStream windowedStream2 = stream2.window(Durations.minutes(1)); +JavaPairDStream> joinedStream = windowedStream1.join(windowedStream2); +{% endhighlight %} +
    +
    +{% highlight python %} +windowedStream1 = stream1.window(20) +windowedStream2 = stream2.window(60) +joinedStream = windowedStream1.join(windowedStream2) +{% endhighlight %} +
    +
    + +##### Stream-dataset joins +{:.no_toc} +This has already been shown earlier while explain `DStream.transform` operation. Here is yet another example of joining a windowed stream with a dataset. + +
    +
    +{% highlight scala %} +val dataset: RDD[String, String] = ... +val windowedStream = stream.window(Seconds(20))... +val joinedStream = windowedStream.transform { rdd => rdd.join(dataset) } +{% endhighlight %} +
    +
    +{% highlight java %} +JavaPairRDD dataset = ... +JavaPairDStream windowedStream = stream.window(Durations.seconds(20)); +JavaPairDStream joinedStream = windowedStream.transform( + new Function>, JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) { + return rdd.join(dataset); + } + } +); +{% endhighlight %} +
    +
    +{% highlight python %} +dataset = ... # some RDD +windowedStream = stream.window(20) +joinedStream = windowedStream.transform(lambda rdd: rdd.join(dataset)) +{% endhighlight %} +
    +
    + +In fact, you can also dynamically change the dataset you want to join against. The function provided to `transform` is evaluated every batch interval and therefore will use the current dataset that `dataset` reference points to. The complete list of DStream transformations is available in the API documentation. For the Scala API, see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) @@ -1117,7 +1225,7 @@ For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.stre *** ## Output Operations on DStreams -Output operations allow DStream's data to be pushed out external systems like a database or a file systems. +Output operations allow DStream's data to be pushed out to external systems like a database or a file systems. Since the output operations actually allow the transformed data to be consumed by external systems, they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). Currently, the following output operations are defined: @@ -1126,7 +1234,7 @@ Currently, the following output operations are defined: Output OperationMeaning print() - Prints first ten elements of every batch of data in a DStream on the driver node running + Prints the first ten elements of every batch of data in a DStream on the driver node running the streaming application. This is useful for development and debugging.
    Python API This is called @@ -1135,12 +1243,12 @@ Currently, the following output operations are defined: saveAsTextFiles(prefix, [suffix]) - Save this DStream's contents as a text files. The file name at each batch interval is + Save this DStream's contents as text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". saveAsObjectFiles(prefix, [suffix]) - Save this DStream's contents as a SequenceFile of serialized Java objects. The file + Save this DStream's contents as SequenceFiles of serialized Java objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    @@ -1150,7 +1258,7 @@ Currently, the following output operations are defined: saveAsHadoopFiles(prefix, [suffix]) - Save this DStream's contents as a Hadoop file. The file name at each batch interval is + Save this DStream's contents as Hadoop files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    Python API This is not available in @@ -1160,7 +1268,7 @@ Currently, the following output operations are defined: foreachRDD(func) The most generic output operator that applies a function, func, to each RDD generated from - the stream. This function should push the data in each RDD to a external system, like saving the RDD to + the stream. This function should push the data in each RDD to an external system, such as saving the RDD to files, or writing it over the network to a database. Note that the function func is executed in the driver process running the streaming application, and will usually have RDD actions in it that will force the computation of the streaming RDDs. @@ -1170,14 +1278,14 @@ Currently, the following output operations are defined: ### Design Patterns for using foreachRDD {:.no_toc} -`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems. +`dstream.foreachRDD` is a powerful primitive that allows data to be sent out to external systems. However, it is important to understand how to use this primitive correctly and efficiently. Some of the common mistakes to avoid are as follows. Often writing data to external system requires creating a connection object (e.g. TCP connection to a remote server) and using it to send data to a remote system. For this purpose, a developer may inadvertently try creating a connection object at -the Spark driver, but try to use it in a Spark worker to save records in the RDDs. +the Spark driver, and then try to use it in a Spark worker to save records in the RDDs. For example (in Scala),
    @@ -1239,7 +1347,7 @@ dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use -`rdd.foreachPartition` - create a single connection object and send all the records in a RDD +`rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
    @@ -1313,9 +1421,150 @@ Note that the connections in the pool should be lazily created on demand and tim *** +## DataFrame and SQL Operations +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. + +
    +
    +{% highlight scala %} + +/** DataFrame operations inside your streaming program */ + +val words: DStream[String] = ... + +words.foreachRDD { rdd => + + // Get the singleton instance of SQLContext + val sqlContext = SQLContext.getOrCreate(rdd.sparkContext) + import sqlContext.implicits._ + + // Convert RDD[String] to DataFrame + val wordsDataFrame = rdd.toDF("word") + + // Register as table + wordsDataFrame.registerTempTable("words") + + // Do word count on DataFrame using SQL and print it + val wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala). +
    +
    +{% highlight java %} + +/** Java Bean class for converting RDD to DataFrame */ +public class JavaRow implements java.io.Serializable { + private String word; + + public String getWord() { + return word; + } + + public void setWord(String word) { + this.word = word; + } +} + +... + +/** DataFrame operations inside your streaming program */ + +JavaDStream words = ... + +words.foreachRDD( + new Function2, Time, Void>() { + @Override + public Void call(JavaRDD rdd, Time time) { + + // Get the singleton instance of SQLContext + SQLContext sqlContext = SQLContext.getOrCreate(rdd.context()); + + // Convert RDD[String] to RDD[case class] to DataFrame + JavaRDD rowRDD = rdd.map(new Function() { + public JavaRow call(String word) { + JavaRow record = new JavaRow(); + record.setWord(word); + return record; + } + }); + DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRow.class); + + // Register as table + wordsDataFrame.registerTempTable("words"); + + // Do word count on table using SQL and print it + DataFrame wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word"); + wordCountsDataFrame.show(); + return null; + } + } +); +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java). +
    +
    +{% highlight python %} + +# Lazily instantiated global instance of SQLContext +def getSqlContextInstance(sparkContext): + if ('sqlContextSingletonInstance' not in globals()): + globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) + return globals()['sqlContextSingletonInstance'] + +... + +# DataFrame operations inside your streaming program + +words = ... # DStream of strings + +def process(time, rdd): + print("========= %s =========" % str(time)) + try: + # Get the singleton instance of SQLContext + sqlContext = getSqlContextInstance(rdd.context) + + # Convert RDD[String] to RDD[Row] to DataFrame + rowRdd = rdd.map(lambda w: Row(word=w)) + wordsDataFrame = sqlContext.createDataFrame(rowRdd) + + # Register as table + wordsDataFrame.registerTempTable("words") + + # Do word count on table using SQL and print it + wordCountsDataFrame = sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() + except: + pass + +words.foreachRDD(process) +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/sql_network_wordcount.py). + +
    +
    + +You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember a sufficient amount of streaming data such that the query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). + +See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames. + +*** + +## MLlib Operations +You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. [Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. + +*** + ## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, -using `persist()` method on a DStream will automatically persist every RDD of that DStream in +using the `persist()` method on a DStream will automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. @@ -1327,28 +1576,27 @@ default persistence level is set to replicate the data to two nodes for fault-to Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More -information on different persistence levels can be found in -[Spark Programming Guide](programming-guide.html#rdd-persistence). +information on different persistence levels can be found in the [Spark Programming Guide](programming-guide.html#rdd-persistence). *** ## Checkpointing A streaming application must operate 24/7 and hence must be resilient to failures unrelated to the application logic (e.g., system failures, JVM crashes, etc.). For this to be possible, -Spark Streaming needs to *checkpoints* enough information to a fault- +Spark Streaming needs to *checkpoint* enough information to a fault- tolerant storage system such that it can recover from failures. There are two types of data that are checkpointed. - *Metadata checkpointing* - Saving of the information defining the streaming computation to fault-tolerant storage like HDFS. This is used to recover from failure of the node running the driver of the streaming application (discussed in detail later). Metadata includes: - + *Configuration* - The configuration that were used to create the streaming application. + + *Configuration* - The configuration that was used to create the streaming application. + *DStream operations* - The set of DStream operations that define the streaming application. + *Incomplete batches* - Batches whose jobs are queued but have not completed yet. - *Data checkpointing* - Saving of the generated RDDs to reliable storage. This is necessary in some *stateful* transformations that combine data across multiple batches. In such - transformations, the generated RDDs depends on RDDs of previous batches, which causes the length - of the dependency chain to keep increasing with time. To avoid such unbounded increase in recovery + transformations, the generated RDDs depend on RDDs of previous batches, which causes the length + of the dependency chain to keep increasing with time. To avoid such unbounded increases in recovery time (proportional to dependency chain), intermediate RDDs of stateful transformations are periodically *checkpointed* to reliable storage (e.g. HDFS) to cut off the dependency chains. @@ -1362,10 +1610,10 @@ transformations are used. Checkpointing must be enabled for applications with any of the following requirements: - *Usage of stateful transformations* - If either `updateStateByKey` or `reduceByKeyAndWindow` (with - inverse function) is used in the application, then the checkpoint directory must be provided for - allowing periodic RDD checkpointing. + inverse function) is used in the application, then the checkpoint directory must be provided to + allow for periodic RDD checkpointing. - *Recovering from failures of the driver running the application* - Metadata checkpoints are used - for to recover with progress information. + to recover with progress information. Note that simple streaming applications without the aforementioned stateful transformations can be run without enabling checkpointing. The recovery from driver failures will also be partial in @@ -1380,7 +1628,7 @@ Checkpointing can be enabled by setting a directory in a fault-tolerant, reliable file system (e.g., HDFS, S3, etc.) to which the checkpoint information will be saved. This is done by using `streamingContext.checkpoint(checkpointDirectory)`. This will allow you to use the aforementioned stateful transformations. Additionally, -if you want make the application recover from driver failures, you should rewrite your +if you want to make the application recover from driver failures, you should rewrite your streaming application to have the following behavior. + When the program is being started for the first time, it will create a new StreamingContext, @@ -1454,7 +1702,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example +context and set up the DStreams. See the Java example [JavaRecoverableNetworkWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). This example appends the word counts of network data into a file. @@ -1501,18 +1749,17 @@ You can also explicitly create a `StreamingContext` from the checkpoint data and In addition to using `getOrCreate` one also needs to ensure that the driver process gets restarted automatically on failure. This can only be done by the deployment infrastructure that is used to run the application. This is further discussed in the -[Deployment](#deploying-applications.html) section. +[Deployment](#deploying-applications) section. Note that checkpointing of RDDs incurs the cost of saving to reliable storage. This may cause an increase in the processing time of those batches where RDDs get checkpointed. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too infrequently -causes the lineage and task sizes to grow which may have detrimental effects. For stateful +causes the lineage and task sizes to grow, which may have detrimental effects. For stateful transformations that require RDD checkpointing, the default interval is a multiple of the batch interval that is at least 10 seconds. It can be set by using -`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 times of -sliding interval of a DStream is good setting to try. +`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 sliding intervals of a DStream is a good setting to try. *** @@ -1566,9 +1813,8 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. - -- *[Experimental in Spark 1.2] Configuring write ahead logs* - In Spark 1.2, - we have introduced a new experimental feature of write ahead logs for achieving strong +- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, + we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the @@ -1586,17 +1832,17 @@ To run a Spark Streaming applications, you need to have the following. {:.no_toc} If a running Spark Streaming application needs to be upgraded with new -application code, then there are two possible mechanism. +application code, then there are two possible mechanisms. - The upgraded Spark Streaming application is started and run in parallel to the existing application. -Once the new one (receiving the same data as the old one) has been warmed up and ready +Once the new one (receiving the same data as the old one) has been warmed up and is ready for prime time, the old one be can be brought down. Note that this can be done for data sources that support sending the data to two destinations (i.e., the earlier and upgraded applications). - The existing application is shutdown gracefully (see [`StreamingContext.stop(...)`](api/scala/index.html#org.apache.spark.streaming.StreamingContext) or [`JavaStreamingContext.stop(...)`](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) -for graceful shutdown options) which ensure data that have been received is completely +for graceful shutdown options) which ensure data that has been received is completely processed before shutdown. Then the upgraded application can be started, which will start processing from the same point where the earlier application left off. Note that this can be done only with input sources that support source-side buffering @@ -1631,13 +1877,13 @@ The following two metrics in web UI are particularly important: to finish. If the batch processing time is consistently more than the batch interval and/or the queueing -delay keeps increasing, then it indicates the system is -not able to process the batches as fast they are being generated and falling behind. +delay keeps increasing, then it indicates that the system is +not able to process the batches as fast they are being generated and is falling behind. In that case, consider -[reducing](#reducing-the-processing-time-of-each-batch) the batch processing time. +[reducing](#reducing-the-batch-processing-times) the batch processing time. The progress of a Spark Streaming program can also be monitored using the -[StreamingListener](api/scala/index.html#org.apache.spark.scheduler.StreamingListener) interface, +[StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface, which allows you to get receiver status and processing times. Note that this is a developer API and it is likely to be improved upon (i.e., more information reported) in the future. @@ -1645,8 +1891,8 @@ and it is likely to be improved upon (i.e., more information reported) in the fu *************************************************************************************************** # Performance Tuning -Getting the best performance of a Spark Streaming application on a cluster requires a bit of -tuning. This section explains a number of the parameters and configurations that can tuned to +Getting the best performance out of a Spark Streaming application on a cluster requires a bit of +tuning. This section explains a number of the parameters and configurations that can be tuned to improve the performance of you application. At a high level, you need to consider two things: 1. Reducing the processing time of each batch of data by efficiently using cluster resources. @@ -1654,24 +1900,24 @@ improve the performance of you application. At a high level, you need to conside 2. Setting the right batch size such that the batches of data can be processed as fast as they are received (that is, data processing keeps up with the data ingestion). -## Reducing the Processing Time of each Batch +## Reducing the Batch Processing Times There are a number of optimizations that can be done in Spark to minimize the processing time of -each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section +each batch. These have been discussed in detail in the [Tuning Guide](tuning.html). This section highlights some of the most important ones. ### Level of Parallelism in Data Receiving {:.no_toc} -Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized +Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to be deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, thus increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
    @@ -1693,16 +1939,24 @@ JavaPairDStream unifiedStream = streamingContext.union(kafkaStre unifiedStream.print(); {% endhighlight %}
    +
    +{% highlight python %} +numStreams = 5 +kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] +unifiedStream = streamingContext.union(kafkaStreams) +unifiedStream.print() +{% endhighlight %} +
    Another parameter that should be considered is the receiver's blocking interval, which is determined by the [configuration parameter](configuration.html#spark-streaming) `spark.streaming.blockInterval`. For most receivers, the received data is coalesced together into blocks of data before storing inside Spark's memory. The number of blocks in each batch -determines the number of tasks that will be used to process those +determines the number of tasks that will be used to process the received data in a map-like transformation. The number of tasks per receiver per batch will be approximately (batch interval / block interval). For example, block interval of 200 ms will -create 10 tasks per 2 second batches. Too low the number of tasks (that is, less than the number +create 10 tasks per 2 second batches. If the number of tasks is too low (that is, less than the number of cores per machine), then it will be inefficient as all available cores will not be used to process the data. To increase the number of tasks for a given batch interval, reduce the block interval. However, the recommended minimum value of block interval is about 50 ms, @@ -1710,7 +1964,7 @@ below which the task launching overheads may be a problem. An alternative to receiving data with multiple input streams / receivers is to explicitly repartition the input data stream (using `inputStream.repartition()`). -This distributes the received batches of data across specified number of machines in the cluster +This distributes the received batches of data across the specified number of machines in the cluster before further processing. ### Level of Parallelism in Data Processing @@ -1718,7 +1972,7 @@ before further processing. Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is controlled by -the`spark.default.parallelism` [configuration property](configuration.html#spark-properties). You +the `spark.default.parallelism` [configuration property](configuration.html#spark-properties). You can pass the level of parallelism as an argument (see [`PairDStreamFunctions`](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions) documentation), or set the `spark.default.parallelism` @@ -1726,21 +1980,20 @@ documentation), or set the `spark.default.parallelism` ### Data Serialization {:.no_toc} -The overhead of data serialization can be significant, especially when sub-second batch sizes are - to be achieved. There are two aspects to it. +The overheads of data serialization can be reduced by tuning the serialization formats. In the case of streaming, there are two types of data that are being serialized. + +* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is insufficient to hold all of the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. -* **Serialization of RDD data in Spark**: Please refer to the detailed discussion on data - serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default - RDDs are persisted as serialized byte arrays to minimize pauses related to GC. +* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operations persist data in memory as they would be processed multiple times. However, unlike the Spark Core default of [StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$), persisted RDDs generated by streaming computations are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) by default to minimize GC overheads. -* **Serialization of input data**: To ingest external data into Spark, data received as bytes - (say, from the network) needs to deserialized from bytes and re-serialized into Spark's - serialization format. Hence, the deserialization overhead of input data may be a bottleneck. +In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization) for more details. For Kryo, consider registering custom classes, and disabling object reference tracking (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). + +In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of a few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. ### Task Launching Overheads {:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead -of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second +of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: * **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task @@ -1755,11 +2008,11 @@ thus allowing sub-second batch size to be viable. *** -## Setting the Right Batch Size +## Setting the Right Batch Interval For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed as fast as they are being generated. Whether this is true for an application can be found by -[monitoring](#monitoring) the processing times in the streaming web UI, where the batch +[monitoring](#monitoring-applications) the processing times in the streaming web UI, where the batch processing time should be less than the batch interval. Depending on the nature of the streaming @@ -1772,55 +2025,55 @@ production can be sustained. A good approach to figure out the right batch size for your application is to test it with a conservative batch interval (say, 5-10 seconds) and a low data rate. To verify whether the system -is able to keep up with data rate, you can check the value of the end-to-end delay experienced +is able to keep up with the data rate, you can check the value of the end-to-end delay experienced by each processed batch (either look for "Total delay" in Spark driver log4j logs, or use the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface). If the delay is maintained to be comparable to the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the -data rate and/or reducing the batch size. Note that momentary increase in the delay due to -temporary data rate increases maybe fine as long as the delay reduces back to a low value +data rate and/or reducing the batch size. Note that a momentary increase in the delay due to +temporary data rate increases may be fine as long as the delay reduces back to a low value (i.e., less than batch size). *** ## Memory Tuning -Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail -in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, -we highlight a few customizations that are strongly recommended to minimize GC related pauses -in Spark Streaming applications and achieving more consistent batch processing times. - -* **Default persistence level of DStreams**: Unlike RDDs, the default persistence level of DStreams -serializes the data in memory (that is, -[StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) for -DStream compared to -[StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$) for RDDs). -Even though keeping the data serialized incurs higher serialization/deserialization overheads, -it significantly reduces GC pauses. - -* **Clearing persistent RDDs**: By default, all persistent RDDs generated by Spark Streaming will - be cleared from memory based on Spark's built-in policy (LRU). If `spark.cleaner.ttl` is set, - then persistent RDDs that are older than that value are periodically cleared. As mentioned - [earlier](#operation), this needs to be careful set based on operations used in the Spark - Streaming program. However, a smarter unpersisting of RDDs can be enabled by setting the - [configuration property](configuration.html#spark-properties) `spark.streaming.unpersist` to - `true`. This makes the system to figure out which RDDs are not necessary to be kept around and - unpersists them. This is likely to reduce - the RDD memory usage of Spark, potentially improving GC behavior as well. - -* **Concurrent garbage collector**: Using the concurrent mark-and-sweep GC further -minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the +Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail +in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications. + +The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on the last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then the necessary memory will be low. + +In general, since the data received through receivers is stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. + +Another aspect of memory tuning is garbage collection. For a streaming application that requires low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. + +There are a few parameters that can help you tune the memory usage and GC overheads: + +* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. + +* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using a window operation of 10 minutes, then Spark Streaming will keep around the last 10 minutes of data, and actively throw away older data. +Data can be retained for a longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. + +* **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more -consistent batch processing times. +consistent batch processing times. Make sure you set the CMS GC on both the driver (using `--driver-java-options` in `spark-submit`) and the executors (using [Spark configuration](configuration.html#runtime-environment) `spark.executor.extraJavaOptions`). + +* **Other tips**: To further reduce GC overheads, here are some more tips to try. + - Use Tachyon for off-heap storage of persisted RDDs. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). + - Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap. + *************************************************************************************************** *************************************************************************************************** # Fault-tolerance Semantics In this section, we will discuss the behavior of Spark Streaming applications in the event -of node failures. To understand this, let us remember the basic fault-tolerance semantics of -Spark's RDDs. +of failures. + +## Background +{:.no_toc} +To understand the semantics provided by Spark Streaming, let us remember the basic fault-tolerance semantics of Spark's RDDs. 1. An RDD is an immutable, deterministically re-computable, distributed dataset. Each RDD remembers the lineage of deterministic operations that were used on a fault-tolerant input @@ -1830,18 +2083,18 @@ re-computed from the original fault-tolerant dataset using the lineage of operat 1. Assuming that all of the RDD transformations are deterministic, the data in the final transformed RDD will always be the same irrespective of failures in the Spark cluster. -Spark operates on data on fault-tolerant file systems like HDFS or S3. Hence, +Spark operates on data in fault-tolerant file systems like HDFS or S3. Hence, all of the RDDs generated from the fault-tolerant data are also fault-tolerant. However, this is not the case for Spark Streaming as the data in most cases is received over the network (except when `fileStream` is used). To achieve the same fault-tolerance properties for all of the generated RDDs, the received data is replicated among multiple Spark executors in worker nodes in the cluster (default replication factor is 2). This leads to two kinds of data in the -system that needs to recovered in the event of failures: +system that need to recovered in the event of failures: 1. *Data received and replicated* - This data survives failure of a single worker node as a copy - of it exists on one of the nodes. + of it exists on one of the other nodes. 1. *Data received but buffered for replication* - Since this is not replicated, - the only way to recover that data is to get it again from the source. + the only way to recover this data is to get it again from the source. Furthermore, there are two kinds of failures that we should be concerned about: @@ -1854,35 +2107,64 @@ Furthermore, there are two kinds of failures that we should be concerned about: With this basic knowledge, let us understand the fault-tolerance semantics of Spark Streaming. -## Semantics with files as input source +## Definitions +{:.no_toc} +The semantics of streaming systems are often captured in terms of how many times each record can be processed by the system. There are three types of guarantees that a system can provide under all possible operating conditions (despite failures, etc.) + +1. *At most once*: Each record will be either processed once or not processed at all. +2. *At least once*: Each record will be processed one or more times. This is stronger than *at-most once* as it ensure that no data will be lost. But there may be duplicates. +3. *Exactly once*: Each record will be processed exactly once - no data will be lost and no data will be processed multiple times. This is obviously the strongest guarantee of the three. + +## Basic Semantics +{:.no_toc} +In any stream processing system, broadly speaking, there are three steps in processing the data. + +1. *Receiving the data*: The data is received from sources using Receivers or otherwise. + +1. *Transforming the data*: The received data is transformed using DStream and RDD transformations. + +1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc. + +If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide an exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. + +1. *Receiving the data*: Different input sources provide different guarantees. This is discussed in detail in the next subsection. + +1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents. + +1. *Pushing out the data*: Output operations by default ensure _at-least once_ semantics because it depends on the type of output operation (idempotent, or not) and the semantics of the downstream system (supports transactions or not). But users can implement their own transaction mechanisms to achieve _exactly-once_ semantics. This is discussed in more details later in the section. + +## Semantics of Received Data +{:.no_toc} +Different input sources provide different guarantees, ranging from _at-least once_ to _exactly once_. Read for more details. + +### With Files {:.no_toc} -If all of the input data is already present in a fault-tolerant files system like -HDFS, Spark Streaming can always recover from any failure and process all the data. This gives -*exactly-once* semantics, that all the data will be processed exactly once no matter what fails. +If all of the input data is already present in a fault-tolerant file system like +HDFS, Spark Streaming can always recover from any failure and process all of the data. This gives +*exactly-once* semantics, meaning all of the data will be processed exactly once no matter what fails. -## Semantics with input sources based on receivers +### With Receiver-based Sources {:.no_toc} For input sources based on receivers, the fault-tolerance semantics depend on both the failure scenario and the type of receiver. As we discussed [earlier](#receiver-reliability), there are two types of receivers: 1. *Reliable Receiver* - These receivers acknowledge reliable sources only after ensuring that - the received data has been replicated. If such a receiver fails, - the buffered (unreplicated) data does not get acknowledged to the source. If the receiver is - restarted, the source will resend the data, and therefore no data will be lost due to the failure. -1. *Unreliable Receiver* - Such receivers can lose data when they fail due to worker - or driver failures. + the received data has been replicated. If such a receiver fails, the source will not receive + acknowledgment for the buffered (unreplicated) data. Therefore, if the receiver is + restarted, the source will resend the data, and no data will be lost due to the failure. +1. *Unreliable Receiver* - Such receivers do *not* send acknowledgment and therefore *can* lose + data when they fail due to worker or driver failures. Depending on what type of receivers are used we achieve the following semantics. If a worker node fails, then there is no data loss with reliable receivers. With unreliable receivers, data received but not replicated can get lost. If the driver node fails, -then besides these losses, all the past data that was received and replicated in memory will be +then besides these losses, all of the past data that was received and replicated in memory will be lost. This will affect the results of the stateful transformations. -To avoid this loss of past received data, Spark 1.2 introduces an experimental feature of _write -ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs -enabled](#deploying-applications) and reliable receivers, there is zero data loss and -exactly-once semantics. +To avoid this loss of past received data, Spark 1.2 introduced _write +ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -1894,23 +2176,30 @@ The following table summarizes the semantics under failures: - Spark 1.1 or earlier, or
    - Spark 1.2 without write ahead log + Spark 1.1 or earlier, OR
    + Spark 1.2 or later without write ahead logs Buffered data lost with unreliable receivers
    - Zero data loss with reliable receivers and files
    + Zero data loss with reliable receivers
    + At-least once semantics Buffered data lost with unreliable receivers
    Past data lost with all receivers
    - Zero data loss with files - + Undefined semantics + - Spark 1.2 with write ahead log - Zero data loss with reliable receivers and files - Zero data loss with reliable receivers and files + Spark 1.2 or later with write ahead logs + + Zero data loss with reliable receivers
    + At-least once semantics + + + Zero data loss with reliable receivers and files
    + At-least once semantics + @@ -1919,18 +2208,32 @@ The following table summarizes the semantics under failures: +### With Kafka Direct API +{:.no_toc} +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). + ## Semantics of output operations {:.no_toc} -Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation - always leads to the same result. As a result, all DStream transformations are guaranteed to have - _exactly-once_ semantics. That is, the final transformed result will be same even if there were - was a worker node failure. However, output operations (like `foreachRDD`) have _at-least once_ - semantics, that is, the transformed data may get written to an external entity more than once in - the event of a worker failure. While this is acceptable for saving to HDFS using the - `saveAs***Files` operations (as the file will simply get over-written by the same data), - additional transactions-like mechanisms may be necessary to achieve exactly-once semantics - for output operations. +Output operations (like `foreachRDD`) have _at-least once_ semantics, that is, +the transformed data may get written to an external entity more than once in +the event of a worker failure. While this is acceptable for saving to file systems using the +`saveAs***Files` operations (as the file will simply get overwritten with the same data), +additional effort may be necessary to achieve exactly-once semantics. There are two approaches. + +- *Idempotent updates*: Multiple attempts always write the same data. For example, `saveAs***Files` always writes the same data to the generated files. + +- *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following. + - Use the batch time (available in `foreachRDD`) and the partition index of the RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. + - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else, if this was already committed, skip the update. + + dstream.foreachRDD { (rdd, time) => + rdd.foreachPartition { partitionIterator => + val partitionId = TaskContext.get.partitionId() + val uniqueId = generateUniqueId(time.milliseconds, partitionId) + // use this uniqueId to transactionally commit the data in partitionIterator + } + } *************************************************************************************************** *************************************************************************************************** @@ -1987,7 +2290,11 @@ package and renamed for better clarity. *************************************************************************************************** # Where to Go from Here - +* Additional guides + - [Kafka Integration Guide](streaming-kafka-integration.html) + - [Flume Integration Guide](streaming-flume-integration.html) + - [Kinesis Integration Guide](streaming-kinesis-integration.html) + - [Custom Receiver Guide](streaming-custom-receivers.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and @@ -2001,7 +2308,7 @@ package and renamed for better clarity. - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and - [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html) + [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) @@ -2009,8 +2316,8 @@ package and renamed for better clarity. [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) - Python docs - * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) - * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) + * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) + * [KafkaUtils](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) * More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming) and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 14a87f843698..e58645274e52 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -59,7 +59,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between the drivers and the executors. Note that `cluster` mode is currently not supported for -Mesos clusters or Python applications. +Mesos clusters. Currently only YARN supports cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. @@ -133,10 +133,10 @@ The master URL passed to Spark can be in one of the following formats: Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... yarn-client Connect to a YARN cluster in -client mode. The cluster location will be found based on the HADOOP_CONF_DIR variable. +client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. yarn-cluster Connect to a YARN cluster in -cluster mode. The cluster location will be found based on HADOOP_CONF_DIR. +cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. @@ -174,6 +174,11 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. + For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries to executors. diff --git a/docs/tuning.md b/docs/tuning.md index efaac9d3d405..6936912a6be5 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -1,6 +1,8 @@ --- layout: global -title: Tuning Spark +displayTitle: Tuning Spark +title: Tuning +description: Tuning and performance optimization guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -58,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. @@ -92,11 +94,13 @@ We will then cover tuning Spark's cache size and the Java garbage collector. ## Determining Memory Consumption -The best way to size the amount of memory consumption your dataset will require is to create an RDD, put it into cache, and look at the SparkContext logs on your driver program. The logs will tell you how much memory each partition is consuming, which you can aggregate to get the total size of the RDD. You will see messages like this: +The best way to size the amount of memory consumption a dataset will require is to create an RDD, put it +into cache, and look at the "Storage" page in the web UI. The page will tell you how much memory the RDD +is occupying. - INFO BlockManagerMasterActor: Added rdd_0_1 in memory on mbk.local:50311 (size: 717.5 KB, free: 332.3 MB) - -This means that partition 1 of RDD 0 consumed 717.5 KB. +To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method +This is useful for experimenting with different data layouts to trim memory usage, as well as +determining the amount of space a broadcast variable will occupy on each executor heap. ## Tuning Data Structures @@ -236,7 +240,7 @@ worth optimizing. ## Data Locality Data locality can have a major impact on the performance of Spark jobs. If data and the code that -operates on it are together than computation tends to be fast. But if code and data are separated, +operates on it are together then computation tends to be fast. But if code and data are separated, one must move to the other. Typically it is faster to ship serialized code from place to place than a chunk of data because code size is much smaller than data. Spark builds its scheduling around this general principle of data locality. diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh index 740c267fd986..4f3e8da809f7 100644 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh @@ -25,10 +25,10 @@ export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" export MODULES="{{modules}}" export SPARK_VERSION="{{spark_version}}" -export SHARK_VERSION="{{shark_version}}" +export TACHYON_VERSION="{{tachyon_version}}" export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" export SWAP_MB="{{swap}}" export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" export SPARK_MASTER_OPTS="{{spark_master_opts}}" export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" -export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" \ No newline at end of file +export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index abab209a05ba..11fd7ee0ec8d 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,26 +19,41 @@ # limitations under the License. # -from __future__ import with_statement +from __future__ import division, print_function, with_statement +import codecs import hashlib +import itertools import logging import os +import os.path import pipes import random import shutil import string +from stat import S_IRUSR import subprocess import sys import tarfile import tempfile +import textwrap import time -import urllib2 import warnings from datetime import datetime from optparse import OptionParser from sys import stderr +if sys.version < "3": + from urllib2 import urlopen, Request, HTTPError +else: + from urllib.request import urlopen, Request + from urllib.error import HTTPError + raw_input = input + xrange = range + +SPARK_EC2_VERSION = "1.4.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) + VALID_SPARK_VERSIONS = set([ "0.7.3", "0.8.0", @@ -52,45 +67,87 @@ "1.1.0", "1.1.1", "1.2.0", + "1.2.1", + "1.3.0", + "1.3.1", + "1.4.0", ]) -DEFAULT_SPARK_VERSION = "1.2.0" +SPARK_TACHYON_MAP = { + "1.0.0": "0.4.1", + "1.0.1": "0.4.1", + "1.0.2": "0.4.1", + "1.1.0": "0.5.0", + "1.1.1": "0.5.0", + "1.2.0": "0.5.0", + "1.2.1": "0.5.0", + "1.3.0": "0.5.0", + "1.3.1": "0.5.0", + "1.4.0": "0.6.4", +} + +DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) -MESOS_SPARK_EC2_BRANCH = "branch-1.3" - -# A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) - - -def setup_boto(): - # Download Boto if it's not already present in the SPARK_EC2_DIR/lib folder: - version = "boto-2.34.0" - md5 = "5556223d2d0cc4d06dd4829e671dcecd" - url = "https://pypi.python.org/packages/source/b/boto/%s.tar.gz" % version - lib_dir = os.path.join(SPARK_EC2_DIR, "lib") - if not os.path.exists(lib_dir): - os.mkdir(lib_dir) - boto_lib_dir = os.path.join(lib_dir, version) - if not os.path.isdir(boto_lib_dir): - tgz_file_path = os.path.join(lib_dir, "%s.tar.gz" % version) - print "Downloading Boto from PyPi" - download_stream = urllib2.urlopen(url) - with open(tgz_file_path, "wb") as tgz_file: - tgz_file.write(download_stream.read()) - with open(tgz_file_path) as tar: - if hashlib.md5(tar.read()).hexdigest() != md5: - print >> stderr, "ERROR: Got wrong md5sum for Boto" - sys.exit(1) - tar = tarfile.open(tgz_file_path) - tar.extractall(path=lib_dir) - tar.close() - os.remove(tgz_file_path) - print "Finished downloading Boto" - sys.path.insert(0, boto_lib_dir) +# Default location to get the spark-ec2 scripts (and ami-list) from +DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" + + +def setup_external_libs(libs): + """ + Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. + """ + PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" + SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") + + if not os.path.exists(SPARK_EC2_LIB_DIR): + print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( + path=SPARK_EC2_LIB_DIR + )) + print("This should be a one-time operation.") + os.mkdir(SPARK_EC2_LIB_DIR) + + for lib in libs: + versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) + lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) + + if not os.path.isdir(lib_dir): + tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") + print(" - Downloading {lib}...".format(lib=lib["name"])) + download_stream = urlopen( + "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( + prefix=PYPI_URL_PREFIX, + first_letter=lib["name"][:1], + lib_name=lib["name"], + lib_version=lib["version"] + ) + ) + with open(tgz_file_path, "wb") as tgz_file: + tgz_file.write(download_stream.read()) + with open(tgz_file_path, "rb") as tar: + if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: + print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) + sys.exit(1) + tar = tarfile.open(tgz_file_path) + tar.extractall(path=SPARK_EC2_LIB_DIR) + tar.close() + os.remove(tgz_file_path) + print(" - Finished downloading {lib}.".format(lib=lib["name"])) + sys.path.insert(1, lib_dir) + + +# Only PyPI libraries are supported. +external_libs = [ + { + "name": "boto", + "version": "2.34.0", + "md5": "5556223d2d0cc4d06dd4829e671dcecd" + } +] + +setup_external_libs(external_libs) -setup_boto() import boto from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType from boto import ec2 @@ -103,12 +160,11 @@ class UsageError(Exception): # Configure and parse our command-line arguments def parse_args(): parser = OptionParser( - usage="spark-ec2 [options] " - + "\n\n can be: launch, destroy, login, stop, start, get-master, reboot-slaves", - add_help_option=False) - parser.add_option( - "-h", "--help", action="help", - help="Show this help message and exit") + prog="spark-ec2", + version="%prog {v}".format(v=SPARK_EC2_VERSION), + usage="%prog [options] \n\n" + + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") + parser.add_option( "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") @@ -130,13 +186,15 @@ def parse_args(): help="Master instance type (leave empty for same as instance-type)") parser.add_option( "-r", "--region", default="us-east-1", - help="EC2 region zone to launch instances in") + help="EC2 region used to launch instances in, or to find them in (default: %default)") parser.add_option( "-z", "--zone", default="", help="Availability zone to launch instances in, or 'all' to spread " + "slaves across multiple (an additional $0.01/Gb for bandwidth" + "between zones applies) (default: a single zone chosen at random)") - parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") + parser.add_option( + "-a", "--ami", + help="Amazon Machine Image ID to use") parser.add_option( "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") @@ -144,9 +202,27 @@ def parse_args(): "--spark-git-repo", default=DEFAULT_SPARK_GITHUB_REPO, help="Github repo from which to checkout supplied commit hash (default: %default)") + parser.add_option( + "--spark-ec2-git-repo", + default=DEFAULT_SPARK_EC2_GITHUB_REPO, + help="Github repo from which to checkout spark-ec2 (default: %default)") + parser.add_option( + "--spark-ec2-git-branch", + default=DEFAULT_SPARK_EC2_BRANCH, + help="Github repo branch of spark-ec2 to use (default: %default)") + parser.add_option( + "--deploy-root-dir", + default=None, + help="A directory to copy into / on the first master. " + + "Must be absolute. Note that a trailing slash is handled as per rsync: " + + "If you omit it, the last directory of the --deploy-root-dir path will be created " + + "in / before copying its contents. If you append the trailing slash, " + + "the directory is not created and its contents are copied directly into /. " + + "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") + help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + + "(Hadoop 2.4.0) (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -166,12 +242,13 @@ def parse_args(): help="Number of EBS volumes to attach to each node as /vol[x]. " + "The volumes will be deleted when the instances terminate. " + "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." + + "EBS volumes are only attached if --ebs-vol-size > 0. " + "Only support up to 8 EBS volumes.") - parser.add_option("--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") + parser.add_option( + "--placement-group", type="string", default=None, + help="Which placement group to try and launch " + + "instances into. Assumes placement group is already " + + "created.") parser.add_option( "--swap", metavar="SWAP", type="int", default=1024, help="Swap space to set up per node, in MB (default: %default)") @@ -197,27 +274,45 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + + "is used as Hadoop major version (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + "(e.g -Dspark.worker.timeout=180)") parser.add_option( "--user-data", type="string", default="", - help="Path to a user-data file (most AMI's interpret this as an initialization script)") + help="Path to a user-data file (most AMIs interpret this as an initialization script)") parser.add_option( "--authorized-address", type="string", default="0.0.0.0/0", help="Address to authorize on created security groups (default: %default)") parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") parser.add_option( - "--subnet-id", default=None, help="VPC subnet to launch instances in") + "--subnet-id", default=None, + help="VPC subnet to launch instances in") parser.add_option( - "--vpc-id", default=None, help="VPC to launch instances in") + "--vpc-id", default=None, + help="VPC to launch instances in") + parser.add_option( + "--private-ips", action="store_true", default=False, + help="Use private IPs for instances rather than public if VPC/subnet " + + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="stop", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -230,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " + - "must be set") - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " + - "must be set") - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) @@ -248,7 +345,7 @@ def get_or_make_group(conn, name, vpc_id): if len(group) > 0: return group[0] else: - print "Creating security group " + name + print("Creating security group " + name) return conn.create_security_group(name, "Spark EC2 group", vpc_id) @@ -256,88 +353,110 @@ def get_validate_spark_version(version, repo): if "." in version: version = version.replace("v", "") if version not in VALID_SPARK_VERSIONS: - print >> stderr, "Don't know about Spark version: {v}".format(v=version) + print("Don't know about Spark version: {v}".format(v=version), file=stderr) sys.exit(1) return version else: github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = urllib2.Request(github_commit_url) + request = Request(github_commit_url) request.get_method = lambda: 'HEAD' try: - response = urllib2.urlopen(request) - except urllib2.HTTPError, e: - print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url) - print >> stderr, "Received HTTP response code of {code}.".format(code=e.code) + response = urlopen(request) + except HTTPError as e: + print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), + file=stderr) + print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) sys.exit(1) return version -# Check whether a given EC2 instance object is in a state we consider active, -# i.e. not terminating or terminated. We count both stopping and stopped as -# active since we can restart stopped clusters. -def is_active(instance): - return (instance.state in ['pending', 'running', 'stopping', 'stopped']) +# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Last Updated: 2015-06-19 +# For easy maintainability, please keep this manually-inputted dictionary sorted by key. +EC2_INSTANCE_TYPES = { + "c1.medium": "pvm", + "c1.xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", + "c3.2xlarge": "pvm", + "c3.4xlarge": "pvm", + "c3.8xlarge": "pvm", + "c4.large": "hvm", + "c4.xlarge": "hvm", + "c4.2xlarge": "hvm", + "c4.4xlarge": "hvm", + "c4.8xlarge": "hvm", + "cc1.4xlarge": "hvm", + "cc2.8xlarge": "hvm", + "cg1.4xlarge": "hvm", + "cr1.8xlarge": "hvm", + "d2.xlarge": "hvm", + "d2.2xlarge": "hvm", + "d2.4xlarge": "hvm", + "d2.8xlarge": "hvm", + "g2.2xlarge": "hvm", + "g2.8xlarge": "hvm", + "hi1.4xlarge": "pvm", + "hs1.8xlarge": "pvm", + "i2.xlarge": "hvm", + "i2.2xlarge": "hvm", + "i2.4xlarge": "hvm", + "i2.8xlarge": "hvm", + "m1.small": "pvm", + "m1.medium": "pvm", + "m1.large": "pvm", + "m1.xlarge": "pvm", + "m2.xlarge": "pvm", + "m2.2xlarge": "pvm", + "m2.4xlarge": "pvm", + "m3.medium": "hvm", + "m3.large": "hvm", + "m3.xlarge": "hvm", + "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "r3.2xlarge": "hvm", + "r3.4xlarge": "hvm", + "r3.8xlarge": "hvm", + "t1.micro": "pvm", + "t2.micro": "hvm", + "t2.small": "hvm", + "t2.medium": "hvm", + "t2.large": "hvm", +} + + +def get_tachyon_version(spark_version): + return SPARK_TACHYON_MAP.get(spark_version, "") # Attempt to resolve an appropriate AMI given the architecture and region of the request. -# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2014-06-20 -# For easy maintainability, please keep this manually-inputted dictionary sorted by key. def get_spark_ami(opts): - instance_types = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "i2.xlarge": "hvm", - "m1.large": "pvm", - "m1.medium": "pvm", - "m1.small": "pvm", - "m1.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m2.xlarge": "pvm", - "m3.2xlarge": "hvm", - "m3.large": "hvm", - "m3.medium": "hvm", - "m3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "t1.micro": "pvm", - "t2.medium": "hvm", - "t2.micro": "hvm", - "t2.small": "hvm", - } - if opts.instance_type in instance_types: - instance_type = instance_types[opts.instance_type] + if opts.instance_type in EC2_INSTANCE_TYPES: + instance_type = EC2_INSTANCE_TYPES[opts.instance_type] else: instance_type = "pvm" - print >> stderr,\ - "Don't recognize %s, assuming type is pvm" % opts.instance_type + print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) + + # URL prefix from which to fetch AMI information + ami_prefix = "{r}/{b}/ami-list".format( + r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), + b=opts.spark_ec2_git_branch) - ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type) + ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urllib2.urlopen(ami_path).read().strip() - print "Spark AMI: " + ami + ami = reader(urlopen(ami_path)).read().strip() except: - print >> stderr, "Could not resolve AMI at: " + ami_path + print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami @@ -347,10 +466,11 @@ def get_spark_ami(opts): # Fails if there already instances running in the cluster's groups. def launch_cluster(conn, opts, cluster_name): if opts.identity_file is None: - print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." + print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) sys.exit(1) + if opts.key_pair is None: - print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." + print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) sys.exit(1) user_data_content = None @@ -358,7 +478,7 @@ def launch_cluster(conn, opts, cluster_name): with open(opts.user_data) as user_data_file: user_data_content = user_data_file.read() - print "Setting up security groups..." + print("Setting up security groups...") master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) authorized_address = opts.authorized_address @@ -387,6 +507,17 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) + # HDFS NFS gateway requires 111,2049,4242 for tcp & udp + master_group.authorize('tcp', 111, 111, authorized_address) + master_group.authorize('udp', 111, 111, authorized_address) + master_group.authorize('tcp', 2049, 2049, authorized_address) + master_group.authorize('udp', 2049, 2049, authorized_address) + master_group.authorize('tcp', 4242, 4242, authorized_address) + master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -417,8 +548,8 @@ def launch_cluster(conn, opts, cluster_name): existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print("ERROR: There are already instances running in group %s or %s" % + (master_group.name, slave_group.name), file=stderr) sys.exit(1) # Figure out Spark AMI @@ -431,12 +562,12 @@ def launch_cluster(conn, opts, cluster_name): additional_group_ids = [sg.id for sg in conn.get_all_security_groups() if opts.additional_security_group in (sg.name, sg.id)] - print "Launching instances..." + print("Launching instances...") try: image = conn.get_all_images(image_ids=[opts.ami])[0] except: - print >> stderr, "Could not find AMI " + opts.ami + print("Could not find AMI " + opts.ami, file=stderr) sys.exit(1) # Create block device mapping so that we can add EBS volumes if asked to. @@ -462,8 +593,8 @@ def launch_cluster(conn, opts, cluster_name): # Launch slaves if opts.spot_price is not None: # Launch spot instances with the requested price - print ("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) + print("Requesting %d slaves as spot instances with price $%.3f" % + (opts.slaves, opts.spot_price)) zones = get_zones(conn, opts) num_zones = len(zones) i = 0 @@ -482,11 +613,12 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 - print "Waiting for spot instances to be granted..." + print("Waiting for spot instances to be granted...") try: while True: time.sleep(10) @@ -499,24 +631,24 @@ def launch_cluster(conn, opts, cluster_name): if i in id_to_req and id_to_req[i].state == "active": active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: - print "All %d slaves granted" % opts.slaves + print("All %d slaves granted" % opts.slaves) reservations = conn.get_all_reservations(active_instance_ids) slave_nodes = [] for r in reservations: slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print("%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves)) except: - print "Canceling spot instance requests" + print("Canceling spot instance requests") conn.cancel_spot_instance_requests(my_req_ids) # Log a warning if any of these requests actually launched instances: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) running = len(master_nodes) + len(slave_nodes) if running: - print >> stderr, ("WARNING: %d instances are still running" % running) + print(("WARNING: %d instances are still running" % running), file=stderr) sys.exit(0) else: # Launch non-spot instances @@ -527,24 +659,30 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances - print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, - zone, slave_res.id) + print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( + s=num_slaves_this_zone, + plural_s=('' if num_slaves_this_zone == 1 else 's'), + z=zone, + r=slave_res.id)) i += 1 # Launch or resume masters if existing_masters: - print "Starting master..." + print("Starting master...") for inst in existing_masters: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -555,71 +693,92 @@ def launch_cluster(conn, opts, cluster_name): master_type = opts.instance_type if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + master_res = image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + print("Launched master in %s, regid = %s" % (zone, master_res.id)) + + # This wait time corresponds to SPARK-4983 + print("Waiting for AWS to propagate instance metadata...") + time.sleep(15) + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) - # Give the instances descriptive names for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) -# Get the EC2 instances in an existing cluster if available. -# Returns a tuple of lists of EC2 instance objects for the masters and slaves +def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): + """ + Get the EC2 instances in an existing cluster if available. + Returns a tuple of lists of EC2 instance objects for the masters and slaves. + """ + print("Searching for existing cluster {c} in region {r}...".format( + c=cluster_name, r=opts.region)) + def get_instances(group_names): + """ + Get all non-terminated instances that belong to any of the provided security groups. -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - print "Searching for existing cluster " + cluster_name + "..." - reservations = conn.get_all_reservations() - master_nodes = [] - slave_nodes = [] - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for inst in active: - group_names = [g.name for g in inst.groups] - if (cluster_name + "-master") in group_names: - master_nodes.append(inst) - elif (cluster_name + "-slaves") in group_names: - slave_nodes.append(inst) - if any((master_nodes, slave_nodes)): - print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) - if master_nodes != [] or not die_on_error: - return (master_nodes, slave_nodes) - else: - if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" - else: - print >> sys.stderr, "ERROR: Could not find any existing cluster" + EC2 reservation filters and instance states are documented here: + http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options + """ + reservations = conn.get_all_reservations( + filters={"instance.group-name": group_names}) + instances = itertools.chain.from_iterable(r.instances for r in reservations) + return [i for i in instances if i.state not in ["shutting-down", "terminated"]] + + master_instances = get_instances([cluster_name + "-master"]) + slave_instances = get_instances([cluster_name + "-slaves"]) + + if any((master_instances, slave_instances)): + print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( + m=len(master_instances), + plural_m=('' if len(master_instances) == 1 else 's'), + s=len(slave_instances), + plural_s=('' if len(slave_instances) == 1 else 's'))) + + if not master_instances and die_on_error: + print("ERROR: Could not find a master for cluster {c} in region {r}.".format( + c=cluster_name, r=opts.region), file=sys.stderr) sys.exit(1) + return (master_instances, slave_instances) + # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. - - def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name + master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: - print "Generating cluster's SSH key on master..." + print("Generating cluster's SSH key on master...") key_setup = """ [ -f ~/.ssh/id_rsa ] || (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && @@ -627,31 +786,39 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): """ ssh(master, opts, key_setup) dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print "Transferring cluster's SSH key to slaves..." + print("Transferring cluster's SSH key to slaves...") for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) + slave_address = get_dns_name(slave, opts.private_ips) + print(slave_address) + ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] + 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') + # Clear SPARK_WORKER_INSTANCES if running on YARN + if opts.hadoop_major_version == "yarn": + opts.worker_instances = "" + # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten + print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) ssh( host=master, opts=opts, command="rm -rf spark-ec2" + " && " - + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, + b=opts.spark_ec2_git_branch) ) - print "Deploying files to master..." + print("Deploying files to master...") deploy_files( conn=conn, root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", @@ -661,35 +828,54 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): modules=modules ) - print "Running setup on master..." + if opts.deploy_root_dir is not None: + print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) + deploy_user_files( + root_dir=opts.deploy_root_dir, + opts=opts, + master_nodes=master_nodes + ) + + print("Running setup on master...") setup_spark_cluster(master, opts) - print "Done!" + print("Done!") def setup_spark_cluster(master, opts): ssh(master, opts, "chmod u+x spark-ec2/setup.sh") ssh(master, opts, "spark-ec2/setup.sh") - print "Spark standalone cluster started at http://%s:8080" % master + print("Spark standalone cluster started at http://%s:8080" % master) if opts.ganglia: - print "Ganglia started at http://%s:5080/ganglia" % master + print("Ganglia started at http://%s:5080/ganglia" % master) -def is_ssh_available(host, opts): +def is_ssh_available(host, opts, print_ssh_output=True): """ Check if SSH is available on a host. """ - try: - with open(os.devnull, 'w') as devnull: - ret = subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=devnull, - stderr=devnull - ) - return ret == 0 - except subprocess.CalledProcessError as e: - return False + s = subprocess.Popen( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order + ) + cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout + + if s.returncode != 0 and print_ssh_output: + # extra leading newline is for spacing in wait_for_cluster_state() + print(textwrap.dedent("""\n + Warning: SSH connection error. (This could be temporary.) + Host: {h} + SSH return code: {r} + SSH output: {o} + """).format( + h=host, + r=s.returncode, + o=cmd_output.strip() + )) + + return s.returncode == 0 def is_cluster_ssh_available(cluster_instances, opts): @@ -697,7 +883,8 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + dns_name = get_dns_name(i, opts.private_ips) + if not is_ssh_available(host=dns_name, opts=opts): return False else: return True @@ -727,7 +914,11 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): for i in cluster_instances: i.update() - statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) + max_batch = 100 + statuses = [] + for j in xrange(0, len(cluster_instances), max_batch): + batch = [i.id for i in cluster_instances[j:j + max_batch]] + statuses.extend(conn.get_all_instance_status(instance_ids=batch)) if cluster_state == 'ssh-ready': if all(i.state == 'running' for i in cluster_instances) and \ @@ -747,59 +938,78 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): sys.stdout.write("\n") end_time = datetime.now() - print "Cluster is now in '{s}' state. Waited {t} seconds.".format( + print("Cluster is now in '{s}' state. Waited {t} seconds.".format( s=cluster_state, t=(end_time - start_time).seconds - ) + )) # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2014-06-20 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, "c1.xlarge": 4, + "c3.large": 2, + "c3.xlarge": 2, "c3.2xlarge": 2, "c3.4xlarge": 2, "c3.8xlarge": 2, - "c3.large": 2, - "c3.xlarge": 2, + "c4.large": 0, + "c4.xlarge": 0, + "c4.2xlarge": 0, + "c4.4xlarge": 0, + "c4.8xlarge": 0, "cc1.4xlarge": 2, "cc2.8xlarge": 4, "cg1.4xlarge": 2, "cr1.8xlarge": 2, + "d2.xlarge": 3, + "d2.2xlarge": 6, + "d2.4xlarge": 12, + "d2.8xlarge": 24, "g2.2xlarge": 1, + "g2.8xlarge": 2, "hi1.4xlarge": 2, "hs1.8xlarge": 24, + "i2.xlarge": 1, "i2.2xlarge": 2, "i2.4xlarge": 4, "i2.8xlarge": 8, - "i2.xlarge": 1, - "m1.large": 2, - "m1.medium": 1, "m1.small": 1, + "m1.medium": 1, + "m1.large": 2, "m1.xlarge": 4, + "m2.xlarge": 1, "m2.2xlarge": 1, "m2.4xlarge": 2, - "m2.xlarge": 1, - "m3.2xlarge": 2, - "m3.large": 1, "m3.medium": 1, + "m3.large": 1, "m3.xlarge": 2, + "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, + "r3.large": 1, + "r3.xlarge": 1, "r3.2xlarge": 1, "r3.4xlarge": 1, "r3.8xlarge": 2, - "r3.large": 1, - "r3.xlarge": 1, "t1.micro": 0, + "t2.micro": 0, + "t2.small": 0, + "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] else: - print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type) + print("WARNING: Don't know number of disks on instance type %s; assuming 1" + % instance_type, file=stderr) return 1 @@ -811,7 +1021,7 @@ def get_num_disks(instance_type): # # root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) num_disks = get_num_disks(opts.instance_type) hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" @@ -828,14 +1038,21 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): if "." in opts.spark_version: # Pre-built Spark deploy spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) + tachyon_v = get_tachyon_version(spark_v) else: # Spark-only custom deploy spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) + tachyon_v = "" + print("Deploying Spark via git hash; Tachyon won't be set up") + modules = filter(lambda x: x != "tachyon", modules) + master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] + slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] + worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), + "master_list": '\n'.join(master_addresses), "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), + "slave_list": '\n'.join(slave_addresses), "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, @@ -843,8 +1060,9 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "swap": str(opts.swap), "modules": '\n'.join(modules), "spark_version": spark_v, + "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, + "spark_worker_instances": worker_instances_str, "spark_master_opts": opts.master_opts } @@ -887,6 +1105,23 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): shutil.rmtree(tmp_dir) +# Deploy a given local directory to a cluster, WITHOUT parameter substitution. +# Note that unlike deploy_files, this works for binary files. +# Also, it is up to the user to add (or not) the trailing slash in root_dir. +# Files are only deployed to the first master instance in the cluster. +# +# root_dir should be an absolute path. +def deploy_user_files(root_dir, opts, master_nodes): + active_master = get_dns_name(master_nodes[0], opts.private_ips) + command = [ + 'rsync', '-rv', + '-e', stringify_command(ssh_command(opts)), + "%s" % root_dir, + "%s@%s:/" % (opts.user, active_master) + ] + subprocess.check_call(command) + + def stringify_command(parts): if isinstance(parts, str): return parts @@ -896,6 +1131,7 @@ def stringify_command(parts): def ssh_args(opts): parts = ['-o', 'StrictHostKeyChecking=no'] + parts += ['-o', 'UserKnownHostsFile=/dev/null'] if opts.identity_file is not None: parts += ['-i', opts.identity_file] return parts @@ -919,13 +1155,13 @@ def ssh(host, opts, command): # If this was an ssh failure, provide the user with hints. if e.returncode == 255: raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + + "Failed to SSH to remote host {0}.\n" + "Please check that you have provided the correct --identity-file and " "--key-pair parameters and try again.".format(host)) else: raise e - print >> stderr, \ - "Error executing remote command, retrying after 30 seconds: {0}".format(e) + print("Error executing remote command, retrying after 30 seconds: {0}".format(e), + file=stderr) time.sleep(30) tries = tries + 1 @@ -964,8 +1200,8 @@ def ssh_write(host, opts, command, arguments): elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: - print >> stderr, \ - "Error {0} while executing remote command, retrying after 30 seconds".format(status) + print("Error {0} while executing remote command, retrying after 30 seconds". + format(status), file=stderr) time.sleep(30) tries = tries + 1 @@ -981,12 +1217,26 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone +# Gets the IP address, taking into account the --private-ips flag +def get_ip_address(instance, private_ips=False): + ip = instance.ip_address if not private_ips else \ + instance.private_ip_address + return ip + + +# Gets the DNS name, taking into account the --private-ips flag +def get_dns_name(instance, private_ips=False): + dns = instance.public_dns_name if not private_ips else \ + instance.private_ip_address + return dns + + def real_main(): (opts, action, cluster_name) = parse_args() @@ -1003,14 +1253,67 @@ def real_main(): DeprecationWarning ) + if opts.identity_file is not None: + if not os.path.exists(opts.identity_file): + print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), + file=stderr) + sys.exit(1) + + file_mode = os.stat(opts.identity_file).st_mode + if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': + print("ERROR: The identity file must be accessible only by you.", file=stderr) + print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), + file=stderr) + sys.exit(1) + + if opts.instance_type not in EC2_INSTANCE_TYPES: + print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( + t=opts.instance_type), file=stderr) + + if opts.master_instance_type != "": + if opts.master_instance_type not in EC2_INSTANCE_TYPES: + print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( + t=opts.master_instance_type), file=stderr) + # Since we try instance types even if we can't resolve them, we check if they resolve first + # and, if they do, see if they resolve to the same virtualization type. + if opts.instance_type in EC2_INSTANCE_TYPES and \ + opts.master_instance_type in EC2_INSTANCE_TYPES: + if EC2_INSTANCE_TYPES[opts.instance_type] != \ + EC2_INSTANCE_TYPES[opts.master_instance_type]: + print("Error: spark-ec2 currently does not support having a master and slaves " + "with different AMI virtualization types.", file=stderr) + print("master instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) + print("slave instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) + sys.exit(1) + if opts.ebs_vol_num > 8: - print >> stderr, "ebs-vol-num cannot be greater than 8" + print("ebs-vol-num cannot be greater than 8", file=stderr) + sys.exit(1) + + # Prevent breaking ami_prefix (/, .git and startswith checks) + # Prevent forks with non spark-ec2 names for now. + if opts.spark_ec2_git_repo.endswith("/") or \ + opts.spark_ec2_git_repo.endswith(".git") or \ + not opts.spark_ec2_git_repo.startswith("https://github.com") or \ + not opts.spark_ec2_git_repo.endswith("spark-ec2"): + print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " + "Furthermore, we currently only support forks named spark-ec2.", file=stderr) + sys.exit(1) + + if not (opts.deploy_root_dir is None or + (os.path.isabs(opts.deploy_root_dir) and + os.path.isdir(opts.deploy_root_dir) and + os.path.exists(opts.deploy_root_dir))): + print("--deploy-root-dir must be an absolute path to a directory that exists " + "on the local file system", file=stderr) sys.exit(1) try: conn = ec2.connect_to_region(opts.region) except Exception as e: - print >> stderr, (e) + print((e), file=stderr) sys.exit(1) # Select an AZ at random if it was not specified. @@ -1019,7 +1322,7 @@ def real_main(): if action == "launch": if opts.slaves <= 0: - print >> sys.stderr, "ERROR: You have to start at least 1 slave" + print("ERROR: You have to start at least 1 slave", file=sys.stderr) sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -1034,26 +1337,27 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": - print "Are you sure you want to destroy the cluster %s?" % cluster_name - print "The following instances will be terminated:" (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name - msg = "ALL DATA ON ALL NODES WILL BE LOST!!\nDestroy cluster %s (y/N): " % cluster_name + if any(master_nodes + slave_nodes): + print("The following instances will be terminated:") + for inst in master_nodes + slave_nodes: + print("> %s" % get_dns_name(inst, opts.private_ips)) + print("ALL DATA ON ALL NODES WILL BE LOST!!") + + msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) response = raw_input(msg) if response == "y": - print "Terminating master..." + print("Terminating master...") for inst in master_nodes: inst.terminate() - print "Terminating slaves..." + print("Terminating slaves...") for inst in slave_nodes: inst.terminate() # Delete security groups as well if opts.delete_groups: - print "Deleting security groups (this will take some time)..." group_names = [cluster_name + "-master", cluster_name + "-slaves"] wait_for_cluster_state( conn=conn, @@ -1061,15 +1365,16 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='terminated' ) + print("Deleting security groups (this will take some time)...") attempt = 1 while attempt <= 3: - print "Attempt %d" % attempt + print("Attempt %d" % attempt) groups = [g for g in conn.get_all_security_groups() if g.name in group_names] success = True # Delete individual rules in all groups before deleting groups to # remove dependencies between them for group in groups: - print "Deleting rules in security group " + group.name + print("Deleting rules in security group " + group.name) for rule in group.rules: for grant in rule.grants: success &= group.revoke(ip_protocol=rule.ip_protocol, @@ -1082,11 +1387,12 @@ def real_main(): time.sleep(30) # Yes, it does have to be this long :-( for group in groups: try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name + # It is needed to use group_id to make it work with VPC + conn.delete_security_group(group_id=group.id) + print("Deleted security group %s" % group.name) except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group " + group.name + print("Failed to delete security group %s" % group.name) # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails @@ -1096,18 +1402,21 @@ def real_main(): attempt += 1 if not success: - print "Failed to delete all security groups after 3 tries." - print "Try re-running in a few minutes." + print("Failed to delete all security groups after 3 tries.") + print("Try re-running in a few minutes.") elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + if not master_nodes[0].public_dns_name and not opts.private_ips: + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") + else: + master = get_dns_name(master_nodes[0], opts.private_ips) + print("Logging into master " + master + "...") + proxy_opt = [] + if opts.proxy_port is not None: + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "reboot-slaves": response = raw_input( @@ -1117,15 +1426,18 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Rebooting slaves..." + print("Rebooting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - print "Rebooting " + inst.id + print("Rebooting " + inst.id) inst.reboot() elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name + if not master_nodes[0].public_dns_name and not opts.private_ips: + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") + else: + print(get_dns_name(master_nodes[0], opts.private_ips)) elif action == "stop": response = raw_input( @@ -1138,11 +1450,11 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Stopping master..." + print("Stopping master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.stop() - print "Stopping slaves..." + print("Stopping slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: if inst.spot_instance_request_id: @@ -1152,11 +1464,11 @@ def real_main(): elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print "Starting slaves..." + print("Starting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - print "Starting master..." + print("Starting master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -1166,18 +1478,29 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='ssh-ready' ) + + # Determine types of running instances + existing_master_type = master_nodes[0].instance_type + existing_slave_type = slave_nodes[0].instance_type + # Setting opts.master_instance_type to the empty string indicates we + # have the same instance type for the master and the slaves + if existing_master_type == existing_slave_type: + existing_master_type = "" + opts.master_instance_type = existing_master_type + opts.instance_type = existing_slave_type + setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: - print >> stderr, "Invalid action: %s" % action + print("Invalid action: %s" % action, file=stderr) sys.exit(1) def main(): try: real_main() - except UsageError, e: - print >> stderr, "\nError:\n", e + except UsageError as e: + print("\nError:\n", e, file=stderr) sys.exit(1) diff --git a/examples/pom.xml b/examples/pom.xml index 8caad2bc2e27..e6884b09dca9 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -90,6 +90,17 @@ org.apache.spark spark-streaming-zeromq_${scala.binary.version} ${project.version} + + + org.spark-project.protobuf + protobuf-java + + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} org.apache.hbase @@ -234,11 +245,12 @@ org.apache.commons commons-math3 + provided com.twitter algebird-core_${scala.binary.version} - 0.8.1 + 0.9.0 org.scalacheck @@ -262,6 +274,22 @@ com.ning compress-lzf + + commons-cli + commons-cli + + + commons-codec + commons-codec + + + commons-lang + commons-lang + + + commons-logging + commons-logging + io.netty netty @@ -270,10 +298,22 @@ jline jline + + net.jpountz.lz4 + lz4 + org.apache.cassandra.deps avro + + org.apache.commons + commons-math3 + + + org.apache.thrift + libthrift + @@ -281,6 +321,17 @@ scopt_${scala.binary.version} 3.2.0 + + + + org.scala-lang + scala-library + provided + + @@ -322,12 +373,6 @@ - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - @@ -350,51 +395,7 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - - - - scala-2.10 - - !scala-2.11 - - - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-scala-sources - generate-sources - - add-source - - - - src/main/scala - scala-2.10/src/main/scala - scala-2.10/src/main/java - - - - - - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 0fbee6e43360..9bbc14ea4087 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -34,8 +34,8 @@ import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating model selection using CrossValidator. @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,14 +112,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = cvModel.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java new file mode 100644 index 000000000000..a377694507d2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.Classifier; +import org.apache.spark.ml.classification.ClassificationModel; +import org.apache.spark.ml.param.IntParam; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.util.Identifiable$; +import org.apache.spark.mllib.linalg.BLAS; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}. + * + * Run with + *
    + * bin/run-example ml.JavaDeveloperApiExample
    + * 
    + */ +public class JavaDeveloperApiExample { + + public static void main(String[] args) throws Exception { + SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // Prepare training data. + List localTraining = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + MyJavaLogisticRegressionModel model = lr.fit(training); + + // Prepare test data. + List localTest = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + DataFrame results = model.transform(test); + double sumPredictions = 0; + for (Row r : results.select("features", "label", "prediction").collect()) { + sumPredictions += r.getDouble(2); + } + if (sumPredictions != 0.0) { + throw new Exception("MyJavaLogisticRegression predicted something other than 0," + + " even though all weights are 0!"); + } + + jsc.stop(); + } +} + +/** + * Example of defining a type of {@link Classifier}. + * + * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to + * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. + * However, this should still compile and run successfully. + */ +class MyJavaLogisticRegression + extends Classifier { + + public MyJavaLogisticRegression() { + init(); + } + + public MyJavaLogisticRegression(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; + } + + /** + * Param for max number of iterations + *

    + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + */ + IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); + + int getMaxIter() { return (Integer) getOrDefault(maxIter); } + + private void init() { + setMaxIter(100); + } + + // The parameter setter is in this class since it should return type MyJavaLogisticRegression. + MyJavaLogisticRegression setMaxIter(int value) { + return (MyJavaLogisticRegression) set(maxIter, value); + } + + // This method is used by fit(). + // In Java, we have to make it public since Java does not understand Scala's protected modifier. + public MyJavaLogisticRegressionModel train(DataFrame dataset) { + // Extract columns from data using helper method. + JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); + + // Do learning to estimate the weight vector. + int numFeatures = oldDataset.take(1).get(0).features().size(); + Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + + // Create a model, and return it. + return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + } + + @Override + public MyJavaLogisticRegression copy(ParamMap extra) { + return defaultCopy(extra); + } +} + +/** + * Example of defining a type of {@link ClassificationModel}. + * + * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to + * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. + * However, this should still compile and run successfully. + */ +class MyJavaLogisticRegressionModel + extends ClassificationModel { + + private Vector weights_; + public Vector weights() { return weights_; } + + public MyJavaLogisticRegressionModel(String uid, Vector weights) { + this.uid_ = uid; + this.weights_ = weights; + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; + } + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public Vector predictRaw(Vector features) { + double margin = BLAS.dot(features, weights_); + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + return Vectors.dense(-margin, margin); + } + + /** + * Number of classes the label can take. 2 indicates binary classification. + */ + public int numClasses() { return 2; } + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + *

    + * This is used for the defaul implementation of [[transform()]]. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + @Override + public MyJavaLogisticRegressionModel copy(ParamMap extra) { + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java new file mode 100644 index 000000000000..be2bf0c7b465 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.KMeansModel; +import org.apache.spark.ml.clustering.KMeans; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +/** + * An example demonstrating a k-means clustering. + * Run with + *

    + * bin/run-example ml.JavaSimpleParamsExample  
    + * 
    + */ +public class JavaKMeansExample { + + private static class ParsePoint implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + if (args.length != 2) { + System.err.println("Usage: ml.JavaKMeansExample "); + System.exit(1); + } + String inputFile = args[0]; + int k = Integer.parseInt(args[1]); + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a k-means model + KMeans kmeans = new KMeans() + .setK(k); + KMeansModel model = kmeans.fit(dataset); + + // Shows the result + Vector[] centers = model.clusterCenters(); + System.out.println("Cluster Centers: "); + for (Vector center: centers) { + System.out.println(center); + } + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java new file mode 100644 index 000000000000..e7f2f6f61507 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.commons.cli.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.ml.classification.OneVsRestModel; +import org.apache.spark.ml.util.MetadataUtils; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + *
    + * bin/run-example ml.JavaOneVsRestExample [options]
    + * 
    + */ +public class JavaOneVsRestExample { + + private static class Params { + String input; + String testInput = null; + Integer maxIter = 100; + double tol = 1E-6; + boolean fitIntercept = true; + Double regParam = null; + Double elasticNetParam = null; + double fracTest = 0.2; + } + + public static void main(String[] args) { + // parse the arguments + Params params = parse(args); + SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // configure the base classifier + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept); + + if (params.regParam != null) { + classifier.setRegParam(params.regParam); + } + if (params.elasticNetParam != null) { + classifier.setElasticNetParam(params.elasticNetParam); + } + + // instantiate the One Vs Rest Classifier + OneVsRest ovr = new OneVsRest().setClassifier(classifier); + + String input = params.input; + RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); + RDD train; + RDD test; + + // compute the train/ test split: if testInput is not provided use part of input + String testInput = params.testInput; + if (testInput != null) { + train = inputData; + // compute the number of features in the training set. + int numFeatures = inputData.first().features().size(); + test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + } else { + double f = params.fracTest; + RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + train = tmp[0]; + test = tmp[1]; + } + + // train the multiclass model + DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); + OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + + // score the model on test data + DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); + DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + .select("prediction", "label"); + + // obtain metrics + MulticlassMetrics metrics = new MulticlassMetrics(predictions); + StructField predictionColSchema = predictions.schema().apply("prediction"); + Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); + + // compute the false positive rate per label + StringBuilder results = new StringBuilder(); + results.append("label\tfpr\n"); + for (int label = 0; label < numClasses; label++) { + results.append(label); + results.append("\t"); + results.append(metrics.falsePositiveRate((double) label)); + results.append("\n"); + } + + Matrix confusionMatrix = metrics.confusionMatrix(); + // output the Confusion Matrix + System.out.println("Confusion Matrix"); + System.out.println(confusionMatrix); + System.out.println(); + System.out.println(results); + + jsc.stop(); + } + + private static Params parse(String[] args) { + Options options = generateCommandlineOptions(); + CommandLineParser parser = new PosixParser(); + Params params = new Params(); + + try { + CommandLine cmd = parser.parse(options, args); + String value; + if (cmd.hasOption("input")) { + params.input = cmd.getOptionValue("input"); + } + if (cmd.hasOption("maxIter")) { + value = cmd.getOptionValue("maxIter"); + params.maxIter = Integer.parseInt(value); + } + if (cmd.hasOption("tol")) { + value = cmd.getOptionValue("tol"); + params.tol = Double.parseDouble(value); + } + if (cmd.hasOption("fitIntercept")) { + value = cmd.getOptionValue("fitIntercept"); + params.fitIntercept = Boolean.parseBoolean(value); + } + if (cmd.hasOption("regParam")) { + value = cmd.getOptionValue("regParam"); + params.regParam = Double.parseDouble(value); + } + if (cmd.hasOption("elasticNetParam")) { + value = cmd.getOptionValue("elasticNetParam"); + params.elasticNetParam = Double.parseDouble(value); + } + if (cmd.hasOption("testInput")) { + value = cmd.getOptionValue("testInput"); + params.testInput = value; + } + if (cmd.hasOption("fracTest")) { + value = cmd.getOptionValue("fracTest"); + params.fracTest = Double.parseDouble(value); + } + + } catch (ParseException e) { + printHelpAndQuit(options); + } + return params; + } + + @SuppressWarnings("static") + private static Options generateCommandlineOptions() { + Option input = OptionBuilder.withArgName("input") + .hasArg() + .isRequired() + .withDescription("input path to labeled examples. This path must be specified") + .create("input"); + Option testInput = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("input path to test examples") + .create("testInput"); + Option fracTest = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("fraction of data to hold out for testing." + + " If given option testInput, this option is ignored. default: 0.2") + .create("fracTest"); + Option maxIter = OptionBuilder.withArgName("maxIter") + .hasArg() + .withDescription("maximum number of iterations for Logistic Regression. default:100") + .create("maxIter"); + Option tol = OptionBuilder.withArgName("tol") + .hasArg() + .withDescription("the convergence tolerance of iterations " + + "for Logistic Regression. default: 1E-6") + .create("tol"); + Option fitIntercept = OptionBuilder.withArgName("fitIntercept") + .hasArg() + .withDescription("fit intercept for logistic regression. default true") + .create("fitIntercept"); + Option regParam = OptionBuilder.withArgName( "regParam" ) + .hasArg() + .withDescription("the regularization parameter for Logistic Regression.") + .create("regParam"); + Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) + .hasArg() + .withDescription("the ElasticNet mixing parameter for Logistic Regression.") + .create("elasticNetParam"); + + Options options = new Options() + .addOption(input) + .addOption(testInput) + .addOption(fracTest) + .addOption(maxIter) + .addOption(tol) + .addOption(fitIntercept) + .addOption(regParam) + .addOption(elasticNetParam); + + return options; + } + + private static void printHelpAndQuit(Options options) { + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp("JavaOneVsRestExample", options); + System.exit(-1); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index eaaa344be49c..94beeced3d47 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -29,8 +29,8 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -71,41 +71,42 @@ public static void main(String[] args) { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); + System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + double thresholds[] = {0.45, 0.55}; + paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); - paramMap2.put(lr.scoreCol().w("probability")); // Change output column name + paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); - System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); + System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. List localTest = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerTempTable("results"); - DataFrame results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); - for (Row r: results.collect()) { + // LogisticRegressionModel.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 82d665a3e138..54738813d001 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -30,8 +30,8 @@ import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -66,7 +66,7 @@ public static void main(String[] args) { .setOutputCol("features"); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01); + .setRegParam(0.001); Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); @@ -77,16 +77,17 @@ public static void main(String[] args) { List localTest = Lists.newArrayList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), + new Document(6L, "spark hadoop spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = model.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java new file mode 100644 index 000000000000..36baf5868736 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +/** + * Java example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt + */ +public class JavaFPGrowthExample { + + public static void main(String[] args) { + String inputFile; + double minSupport = 0.3; + int numPartition = -1; + if (args.length < 1) { + System.err.println( + "Usage: JavaFPGrowth [minSupport] [numPartition]"); + System.exit(1); + } + inputFile = args[0]; + if (args.length >= 2) { + minSupport = Double.parseDouble(args[1]); + } + if (args.length >= 3) { + numPartition = Integer.parseInt(args[2]); + } + + SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD> transactions = sc.textFile(inputFile).map( + new Function>() { + @Override + public ArrayList call(String s) { + return Lists.newArrayList(s.split(" ")); + } + } + ); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartition) + .run(transactions); + + for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java new file mode 100644 index 000000000000..fd53c81cc497 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java new file mode 100644 index 000000000000..6c6f9768f015 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple3; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +/** + * Java example for graph clustering using power iteration clustering (PIC). + */ +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaPowerIterationClusteringExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + @SuppressWarnings("unchecked") + JavaRDD> similarities = sc.parallelize(Lists.newArrayList( + new Tuple3(0L, 1L, 0.9), + new Tuple3(1L, 2L, 0.9), + new Tuple3(2L, 3L, 0.9), + new Tuple3(3L, 4L, 0.1), + new Tuple3(4L, 5L, 0.9))); + + PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); + PowerIterationClusteringModel model = pic.run(similarities); + + for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8defb769ffaa..afee279ec32b 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -94,17 +94,17 @@ public String call(Row row) { System.out.println("=== Data source: Parquet File ==="); // DataFrames can be saved as parquet files, maintaining the schema information. - schemaPeople.saveAsParquetFile("people.parquet"); + schemaPeople.write().parquet("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.read().json(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -133,8 +133,8 @@ public String call(Row row) { // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java new file mode 100644 index 000000000000..bab9f2478e77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import kafka.serializer.StringDecoder; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka.KafkaUtils; +import org.apache.spark.streaming.Durations; + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: DirectKafkaWordCount + * is a list of one or more Kafka brokers + * is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.KafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + */ + +public final class JavaDirectKafkaWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: DirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a list of one or more kafka topics to consume from\n\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + String brokers = args[0]; + String topics = args[1]; + + // Create context with 2 second batch interval + SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); + + HashSet topicsSet = new HashSet(Arrays.asList(topics.split(","))); + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", brokers); + + // Create direct kafka stream with brokers and topics + JavaPairInputDStream messages = KafkaUtils.createDirectStream( + jssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicsSet + ); + + // Get the lines, split them into words, count the words and print + JavaDStream lines = messages.map(new Function, String>() { + @Override + public String call(Tuple2 tuple2) { + return tuple2._2(); + } + }); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + wordCounts.print(); + + // Start the computation + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java new file mode 100644 index 000000000000..e63697a79f23 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +/** Java Bean class to be used with the example JavaSqlNetworkWordCount. */ +public class JavaRecord implements java.io.Serializable { + private String word; + + public String getWord() { + return word; + } + + public void setWord(String word) { + this.word = word; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java new file mode 100644 index 000000000000..46562ddbbcb5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.regex.Pattern; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + * network every second. + * + * Usage: JavaSqlNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example org.apache.spark.examples.streaming.JavaSqlNetworkWordCount localhost 9999` + */ + +public final class JavaSqlNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaSqlNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + + // Create a JavaReceiverInputDStream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + // Note that no duplication in storage level only for running locally. + // Replication necessary in distributed scenario for fault tolerance. + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + // Convert RDDs of the words DStream to DataFrame and run SQL query + words.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaRDD rdd, Time time) { + SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame + JavaRDD rowRDD = rdd.map(new Function() { + public JavaRecord call(String word) { + JavaRecord record = new JavaRecord(); + record.setWord(word); + return record; + } + }); + DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class); + + // Register as table + wordsDataFrame.registerTempTable("words"); + + // Do word count on table using SQL and print it + DataFrame wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word"); + System.out.println("========= " + time + "========="); + wordCountsDataFrame.show(); + return null; + } + }); + + ssc.start(); + ssc.awaitTermination(); + } +} + +/** Lazily instantiated singleton instance of SQLContext */ +class JavaSQLContextSingleton { + static private transient SQLContext instance = null; + static public SQLContext getInstance(SparkContext sparkContext) { + if (instance == null) { + instance = new SQLContext(sparkContext); + } + return instance; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java new file mode 100644 index 000000000000..99b63a2590ae --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every + * second starting with initial value of word count. + * Usage: JavaStatefulNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive + * data. + *

    + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example + * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999` + */ +public class JavaStatefulNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaStatefulNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Update the cumulative count function + final Function2, Optional, Optional> updateFunction = + new Function2, Optional, Optional>() { + @Override + public Optional call(List values, Optional state) { + Integer newSum = state.or(0); + for (Integer value : values) { + newSum += value; + } + return Optional.of(newSum); + } + }; + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint("."); + + // Initial RDD input to updateStateByKey + @SuppressWarnings("unchecked") + List> tuples = Arrays.asList(new Tuple2("hello", 1), + new Tuple2("world", 1)); + JavaPairRDD initialRDD = ssc.sparkContext().parallelizePairs(tuples); + + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); + + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + JavaPairDStream wordsDstream = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + // This will give a Dstream made of state (which is the cumulative count of the words) + JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, + new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD); + + stateDstream.print(); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 70b6146e39a8..1c3a787bd0e9 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -21,7 +21,8 @@ This example requires numpy (http://www.numpy.org/) """ -from os.path import realpath +from __future__ import print_function + import sys import numpy as np @@ -57,9 +58,9 @@ def update(i, vec, mat, ratings): Usage: als [M] [U] [F] [iterations] [partitions]" """ - print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an + print("""WARN: This is a naive implementation of ALS and is given as an example. Please use the ALS method found in pyspark.mllib.recommendation for more - conventional use.""" + conventional use.""", file=sys.stderr) sc = SparkContext(appName="PythonALS") M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 @@ -68,8 +69,8 @@ def update(i, vec, mat, ratings): ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5 partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2 - print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \ - (M, U, F, ITERATIONS, partitions) + print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % + (M, U, F, ITERATIONS, partitions)) R = matrix(rand(M, F)) * matrix(rand(U, F).T) ms = matrix(rand(M, F)) @@ -95,7 +96,7 @@ def update(i, vec, mat, ratings): usb = sc.broadcast(us) error = rmse(R, ms, us) - print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error + print("Iteration %d:" % i) + print("\nRMSE: %5.4f\n" % error) sc.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4626bbb7e3b0..da368ac628a4 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,9 +15,12 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext +from functools import reduce """ Read data file users.avro in local Spark distro: @@ -49,7 +52,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: avro_inputformat [reader_schema_file] Run with example jar: @@ -57,7 +60,7 @@ /path/to/examples/avro_inputformat.py [reader_schema_file] Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -77,6 +80,6 @@ conf=conf) output = avro_rdd.map(lambda x: x[0]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 05f34b74df45..93ca0cfcc930 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, """ + print(""" Usage: cassandra_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/cassandra_inputformat.py Assumes you have some data in Cassandra already, running on , in and - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -77,6 +79,6 @@ conf=conf) output = cass_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index d144539e58b8..5d643eac92f9 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -46,7 +48,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: cassandra_outputformat Run with example jar: @@ -60,7 +62,7 @@ ... fname text, ... lname text ... ); - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3b16010f1cb9..c5ae5d043b8e 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -15,7 +15,10 @@ # limitations under the License. # +from __future__ import print_function + import sys +import json from pyspark import SparkContext @@ -25,43 +28,49 @@ hbase(main):016:0> create 'test', 'f1' 0 row(s) in 1.0430 seconds -hbase(main):017:0> put 'test', 'row1', 'f1', 'value1' +hbase(main):017:0> put 'test', 'row1', 'f1:a', 'value1' 0 row(s) in 0.0130 seconds -hbase(main):018:0> put 'test', 'row2', 'f1', 'value2' +hbase(main):018:0> put 'test', 'row1', 'f1:b', 'value2' 0 row(s) in 0.0030 seconds -hbase(main):019:0> put 'test', 'row3', 'f1', 'value3' +hbase(main):019:0> put 'test', 'row2', 'f1', 'value3' 0 row(s) in 0.0050 seconds -hbase(main):020:0> put 'test', 'row4', 'f1', 'value4' +hbase(main):020:0> put 'test', 'row3', 'f1', 'value4' 0 row(s) in 0.0110 seconds hbase(main):021:0> scan 'test' ROW COLUMN+CELL - row1 column=f1:, timestamp=1401883411986, value=value1 - row2 column=f1:, timestamp=1401883415212, value=value2 - row3 column=f1:, timestamp=1401883417858, value=value3 - row4 column=f1:, timestamp=1401883420805, value=value4 + row1 column=f1:a, timestamp=1401883411986, value=value1 + row1 column=f1:b, timestamp=1401883415212, value=value2 + row2 column=f1:, timestamp=1401883417858, value=value3 + row3 column=f1:, timestamp=1401883420805, value=value4 4 row(s) in 0.0240 seconds """ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: hbase_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ - /path/to/examples/hbase_inputformat.py
    + /path/to/examples/hbase_inputformat.py
    [] Assumes you have some data in HBase already, running on , in
    - """ + optionally, you can specify parent znode for your hbase cluster - + """, file=sys.stderr) exit(-1) host = sys.argv[1] table = sys.argv[2] sc = SparkContext(appName="HBaseInputFormat") + # Other options for configuring scan behavior are available. More information available at + # https://github.com/apache/hbase/blob/master/hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormat.java conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + if len(sys.argv) > 3: + conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], + "hbase.mapreduce.inputtable": table} keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" @@ -72,8 +81,10 @@ keyConverter=keyConv, valueConverter=valueConv, conf=conf) + hbase_rdd = hbase_rdd.flatMapValues(lambda v: v.split("\n")).mapValues(json.loads) + output = hbase_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index abb425b1f886..9e5641789a97 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -40,7 +42,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: hbase_outputformat
    Run with example jar: @@ -48,7 +50,7 @@ /path/to/examples/hbase_outputformat.py Assumes you have created
    with column family in HBase running on already - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 86ef6f32c84e..0ea7cfb7025a 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -22,6 +22,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -47,12 +48,12 @@ def closestPoint(p, centers): if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given + print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on - how to use MLlib's KMeans implementation.""" + how to use MLlib's KMeans implementation.""", file=sys.stderr) sc = SparkContext(appName="PythonKMeans") lines = sc.textFile(sys.argv[1]) @@ -67,15 +68,15 @@ def closestPoint(p, centers): closest = data.map( lambda p: (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1])) newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() + lambda st: (st[0], st[1][0] / st[1][1])).collect() - tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints) - for (x, y) in newPoints: - kPoints[x] = y + for (iK, p) in newPoints: + kPoints[iK] = p - print "Final centers: " + str(kPoints) + print("Final centers: " + str(kPoints)) sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 3aa56b052816..b318b7d87bfd 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -22,10 +22,8 @@ In practice, one may prefer to use the LogisticRegression algorithm in MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py. """ +from __future__ import print_function -from collections import namedtuple -from math import exp -from os.path import realpath import sys import numpy as np @@ -42,19 +40,19 @@ def readPointBatch(iterator): strs = list(iterator) matrix = np.zeros((len(strs), D + 1)) - for i in xrange(len(strs)): - matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ') + for i, s in enumerate(strs): + matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ') return [matrix] if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is + print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! Please refer to examples/src/main/python/mllib/logistic_regression.py - to see how MLlib's implementation is used.""" + to see how MLlib's implementation is used.""", file=sys.stderr) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() @@ -62,7 +60,7 @@ def readPointBatch(iterator): # Initialize w to a random value w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) + print("Initial w: " + str(w)) # Compute logistic regression gradient for a matrix of data points def gradient(matrix, w): @@ -76,9 +74,9 @@ def add(x, y): return x for i in range(iterations): - print "On iteration %i" % (i + 1) + print("On iteration %i" % (i + 1)) w -= points.map(lambda m: gradient(m, w)).reduce(add) - print "Final w: " + str(w) + print("Final w: " + str(w)) sc.stop() diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 000000000000..f0ca97c72494 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 000000000000..6446f0fe5eea --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py new file mode 100644 index 000000000000..150dadd42f33 --- /dev/null +++ b/examples/src/main/python/ml/kmeans_example.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys +import re + +import numpy as np +from pyspark import SparkContext +from pyspark.ml.clustering import KMeans, KMeansModel +from pyspark.mllib.linalg import VectorUDT, _convert_to_vector +from pyspark.sql import SQLContext +from pyspark.sql.types import Row, StructField, StructType + +""" +A simple example demonstrating a k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/kmeans_example.py + +This example requires NumPy (http://www.numpy.org/). +""" + + +def parseVector(line): + array = np.array([float(x) for x in line.split(' ')]) + return _convert_to_vector(array) + + +if __name__ == "__main__": + + FEATURES_COL = "features" + + if len(sys.argv) != 3: + print("Usage: kmeans_example.py ", file=sys.stderr) + exit(-1) + path = sys.argv[1] + k = sys.argv[2] + + sc = SparkContext(appName="PythonKMeansExample") + sqlContext = SQLContext(sc) + + lines = sc.textFile(path) + data = lines.map(parseVector) + row_rdd = data.map(lambda x: Row(x)) + schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)]) + df = sqlContext.createDataFrame(row_rdd, schema) + + kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL) + model = kmeans.fit(df) + centers = model.clusterCenters() + + print("Cluster Centers: ") + for center in centers: + print(center) + + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 000000000000..55afe1b207fe --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 000000000000..c7730e1bfacd --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. + + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. + + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 000000000000..2d6d115d54d0 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. + # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. + print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index c7df3d7b7476..b4f06bf88874 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -15,11 +15,13 @@ # limitations under the License. # +from __future__ import print_function + from pyspark import SparkContext -from pyspark.sql import SQLContext, Row from pyspark.ml import Pipeline -from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext """ @@ -33,47 +35,37 @@ if __name__ == "__main__": sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. - LabeledDocument = Row('id', 'text', 'label') - training = sqlCtx.inferSchema( - sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) - .map(lambda x: LabeledDocument(*x))) + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. - tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - hashingTF = HashingTF() \ - .setInputCol(tokenizer.getOutputCol()) \ - .setOutputCol("features") - lr = LogisticRegression() \ - .setMaxIter(10) \ - .setRegParam(0.01) - pipeline = Pipeline() \ - .setStages([tokenizer, hashingTF, lr]) + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.001) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) # Prepare test documents, which are unlabeled. - Document = Row('id', 'text') - test = sqlCtx.inferSchema( - sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) - .map(lambda x: Document(*x))) + Document = Row("id", "text") + test = sc.parallelize([(4, "spark i j k"), + (5, "l m n"), + (6, "spark hadoop spark"), + (7, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() # Make predictions on test documents and print columns of interest. prediction = model.transform(test) - prediction.registerTempTable("prediction") - selected = sqlCtx.sql("SELECT id, text, prediction from prediction") + selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 4218eca822a9..0e13546b88e6 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -18,6 +18,7 @@ """ Correlations using MLlib. """ +from __future__ import print_function import sys @@ -29,7 +30,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: correlations ()" + print("Usage: correlations ()", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: @@ -41,20 +42,20 @@ points = MLUtils.loadLibSVMFile(sc, filepath)\ .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) - print - print 'Summary of data file: ' + filepath - print '%d data points' % points.count() + print() + print('Summary of data file: ' + filepath) + print('%d data points' % points.count()) # Statistics (correlations) - print - print 'Correlation (%s) between label and each feature' % corrType - print 'Feature\tCorrelation' + print() + print('Correlation (%s) between label and each feature' % corrType) + print('Feature\tCorrelation') numFeatures = points.take(1)[0].features.size labelRDD = points.map(lambda lp: lp.label) for i in range(numFeatures): featureRDD = points.map(lambda lp: lp.features[i]) corr = Statistics.corr(labelRDD, featureRDD, corrType) - print '%d\t%g' % (i, corr) - print + print('%d\t%g' % (i, corr)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index b5a70db2b9a3..e23ecc0c5d30 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -19,6 +19,7 @@ An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ +from __future__ import print_function import os import sys @@ -32,31 +33,31 @@ def summarize(dataset): - print "schema: %s" % dataset.schema().json() + print("schema: %s" % dataset.schema().json()) labels = dataset.map(lambda r: r.label) - print "label average: %f" % labels.mean() + print("label average: %f" % labels.mean()) features = dataset.map(lambda r: r.features) summary = Statistics.colStats(features) - print "features average: %r" % summary.mean() + print("features average: %r" % summary.mean()) if __name__ == "__main__": if len(sys.argv) > 2: - print >> sys.stderr, "Usage: dataset_example.py " + print("Usage: dataset_example.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print "Save dataset as a Parquet file to %s." % tempdir + print("Save dataset as a Parquet file to %s." % tempdir) dataset0.saveAsParquetFile(tempdir) - print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + print("Load it back and summarize it again.") + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index fccabd841b13..513ed8fd5145 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import numpy import os @@ -83,18 +84,17 @@ def reindexClassLabels(data): numClasses = len(classCounts) # origToNewLabels: class --> index in 0,...,numClasses-1 if (numClasses < 2): - print >> sys.stderr, \ - "Dataset for classification should have at least 2 classes." + \ - " The given dataset had only %d classes." % numClasses + print("Dataset for classification should have at least 2 classes." + " The given dataset had only %d classes." % numClasses, file=sys.stderr) exit(1) origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - print "numClasses = %d" % numClasses - print "Per-class example fractions, counts:" - print "Class\tFrac\tCount" + print("numClasses = %d" % numClasses) + print("Per-class example fractions, counts:") + print("Class\tFrac\tCount") for c in sortedClasses: frac = classCounts[c] / (numExamples + 0.0) - print "%g\t%g\t%d" % (c, frac, classCounts[c]) + print("%g\t%g\t%d" % (c, frac, classCounts[c])) if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): return (data, origToNewLabels) @@ -105,8 +105,7 @@ def reindexClassLabels(data): def usage(): - print >> sys.stderr, \ - "Usage: decision_tree_runner [libsvm format data filepath]" + print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) exit(1) @@ -133,13 +132,13 @@ def usage(): model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. - print "Trained DecisionTree for classification:" - print " Model numNodes: %d" % model.numNodes() - print " Model depth: %d" % model.depth() - print " Training accuracy: %g" % getAccuracy(model, reindexedData) + print("Trained DecisionTree for classification:") + print(" Model numNodes: %d" % model.numNodes()) + print(" Model depth: %d" % model.depth()) + print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) if model.numNodes() < 20: - print model.toDebugString() + print(model.toDebugString()) else: - print model + print(model) sc.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index a2cd626c9f19..2cb8010cdc07 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -18,7 +18,8 @@ """ A Gaussian Mixture Model clustering program using MLlib. """ -import sys +from __future__ import print_function + import random import argparse import numpy as np @@ -59,7 +60,7 @@ def parseVector(line): model = GaussianMixture.train(data, args.k, args.convergenceTol, args.maxIterations, args.seed) for i in range(args.k): - print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, - "sigma = ", model.gaussians[i].sigma.toArray()) - print ("Cluster labels (first 100): ", model.predict(data).take(100)) + print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray())) + print(("Cluster labels (first 100): ", model.predict(data).take(100))) sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py index e647773ad906..781bd61c9d2b 100644 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -18,6 +18,7 @@ """ Gradient boosted Trees classification and regression using MLlib. """ +from __future__ import print_function import sys @@ -34,7 +35,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification ensemble model:') @@ -49,7 +50,7 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \ + testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression ensemble model:') @@ -58,7 +59,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: gradient_boosted_trees" + print("Usage: gradient_boosted_trees", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonGradientBoostedTrees") diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 2eeb1abeeb12..002fc7579964 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -34,12 +35,13 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) k = int(sys.argv[2]) model = KMeans.train(data, k) - print "Final centers: " + str(model.clusterCenters) + print("Final centers: " + str(model.clusterCenters)) + print("Total Cost: " + str(model.computeCost(data))) sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 8cae27fc4a52..d4f1d34e2d8c 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -20,11 +20,10 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function -from math import exp import sys -import numpy as np from pyspark import SparkContext from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import LogisticRegressionWithSGD @@ -42,12 +41,12 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) model = LogisticRegressionWithSGD.train(points, iterations) - print "Final weights: " + str(model.weights) - print "Final intercept: " + str(model.intercept) + print("Final weights: " + str(model.weights)) + print("Final intercept: " + str(model.intercept)) sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py index d3c24f766432..4cfdad868c66 100755 --- a/examples/src/main/python/mllib/random_forest_example.py +++ b/examples/src/main/python/mllib/random_forest_example.py @@ -22,6 +22,7 @@ For information on multiclass classification, please refer to the decision_tree_runner.py example. """ +from __future__ import print_function import sys @@ -43,7 +44,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification forest model:') @@ -62,8 +63,8 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\ - / float(testData.count()) + testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ + .sum() / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) @@ -71,7 +72,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: random_forest_example" + print("Usage: random_forest_example", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonRandomForestExample") diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 1e8892741e71..729bae30b152 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -18,6 +18,7 @@ """ Randomly generated RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: random_rdd_generation" + print("Usage: random_rdd_generation", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") @@ -37,19 +38,19 @@ # Example: RandomRDDs.normalRDD normalRDD = RandomRDDs.normalRDD(sc, numExamples) - print 'Generated RDD of %d examples sampled from the standard normal distribution'\ - % normalRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples sampled from the standard normal distribution' + % normalRDD.count()) + print(' First 5 samples:') for sample in normalRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() # Example: RandomRDDs.normalVectorRDD normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2) - print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()) + print(' First 5 samples:') for sample in normalVectorRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index 92af3af5ebd1..b7033ab7daeb 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -18,6 +18,7 @@ """ Randomly sampled RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: sampled_rdds " + print("Usage: sampled_rdds ", file=sys.stderr) exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] @@ -41,24 +42,24 @@ examples = MLUtils.loadLibSVMFile(sc, datapath) numExamples = examples.count() if numExamples == 0: - print >> sys.stderr, "Error: Data file had no samples to load." + print("Error: Data file had no samples to load.", file=sys.stderr) exit(1) - print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() expectedSampleSize = int(numExamples * fraction) - print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ - % (fraction, expectedSampleSize) + print('Sampling RDD using fraction %g. Expected sample size = %d.' + % (fraction, expectedSampleSize)) sampledRDD = examples.sample(withReplacement=True, fraction=fraction) - print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + print(' RDD.sample(): sample has %d examples' % sampledRDD.count()) sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize) - print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + print(' RDD.takeSample(): sample has %d examples' % len(sampledArray)) - print + print() # Example: RDD.sampleByKey() keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) - print ' Keyed data using label (Int) as key ==> Orig' + print(' Keyed data using label (Int) as key ==> Orig') # Count examples per label in original data. keyCountsA = keyedRDD.countByKey() @@ -69,18 +70,18 @@ sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions) keyCountsB = sampledByKeyRDD.countByKey() sizeB = sum(keyCountsB.values()) - print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ - % sizeB + print(' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' + % sizeB) # Compare samples - print ' \tFractions of examples with key' - print 'Key\tOrig\tSample' + print(' \tFractions of examples with key') + print('Key\tOrig\tSample') for k in sorted(keyCountsA.keys()): fracA = keyCountsA[k] / float(numExamples) if sizeB != 0: fracB = keyCountsB.get(k, 0) / float(sizeB) else: fracB = 0 - print '%d\t%g\t%g' % (k, fracA, fracB) + print('%d\t%g\t%g' % (k, fracA, fracB)) sc.stop() diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py index 99fef4276a36..40d1b887927e 100644 --- a/examples/src/main/python/mllib/word2vec.py +++ b/examples/src/main/python/mllib/word2vec.py @@ -23,6 +23,7 @@ # grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines # This was done so that the example can be run in local mode +from __future__ import print_function import sys @@ -34,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) < 2: - print USAGE + print(USAGE) sys.exit("Argument for file not provided") file_path = sys.argv[1] sc = SparkContext(appName='Word2Vec') @@ -46,5 +47,5 @@ synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index a5f25d78c114..2fdc9773d4eb 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -19,6 +19,7 @@ This is an example implementation of PageRank. For more conventional use, Please refer to PageRank implementation provided by graphx """ +from __future__ import print_function import re import sys @@ -42,11 +43,12 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: pagerank " + print("Usage: pagerank ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is - given as an example! Please refer to PageRank implementation provided by graphx""" + print("""WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""", + file=sys.stderr) # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") @@ -62,19 +64,19 @@ def parseNeighbors(urls): links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache() # Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - ranks = links.map(lambda (url, neighbors): (url, 1.0)) + ranks = links.map(lambda url_neighbors: (url_neighbors[0], 1.0)) # Calculates and updates URL ranks continuously using PageRank algorithm. - for iteration in xrange(int(sys.argv[2])): + for iteration in range(int(sys.argv[2])): # Calculates URL contributions to the rank of other URLs. contribs = links.join(ranks).flatMap( - lambda (url, (urls, rank)): computeContribs(urls, rank)) + lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1])) # Re-calculates URL ranks based on neighbor contributions. ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15) # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): - print "%s has rank: %s." % (link, rank) + print("%s has rank: %s." % (link, rank)) sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index fa4c20ab2028..e1fd85b082c0 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,14 +36,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, """ + print(""" Usage: parquet_inputformat.py Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \\ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -50,12 +51,12 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'parquet.avro.AvroParquetInputFormat', + 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') output = parquet_rdd.map(lambda x: x[1]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index a7c74e969cdb..92e5cf45abc8 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,7 +36,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) + count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) + print("Pi is roughly %f" % (4.0 * count / n)) sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index bb686f17518a..f6b0ecb02c10 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -22,7 +24,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: sort " + print("Usage: sort ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonSort") lines = sc.textFile(sys.argv[1], 1) @@ -33,6 +35,6 @@ # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() for (num, unitcount) in output: - print num + print(num) sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 7f5c68e3d0fe..2c188759328f 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -15,11 +15,14 @@ # limitations under the License. # +from __future__ import print_function + import os +import sys from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.sql import Row, StructField, StructType, StringType, IntegerType +from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType if __name__ == "__main__": @@ -31,7 +34,7 @@ Row(name="Smith", age=23), Row(name="Sarah", age=18)]) # Infer schema from the first row, create a DataFrame and print the schema - some_df = sqlContext.inferSchema(some_rdd) + some_df = sqlContext.createDataFrame(some_rdd) some_df.printSchema() # Another RDD is created from a list of tuples @@ -40,7 +43,7 @@ schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) # Create a DataFrame by applying the schema to the RDD and print the schema - another_df = sqlContext.applySchema(another_rdd, schema) + another_df = sqlContext.createDataFrame(another_rdd, schema) another_df.printSchema() # root # |-- age: integer (nullable = true) @@ -48,7 +51,11 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. - path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + if len(sys.argv) < 2: + path = "file://" + \ + os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + else: + path = sys.argv[1] # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root @@ -68,6 +75,6 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") for each in teenagers.collect(): - print each[0] + print(each[0]) sc.stop() diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py new file mode 100644 index 000000000000..49b7902185aa --- /dev/null +++ b/examples/src/main/python/status_api_demo.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import time +import threading +import Queue + +from pyspark import SparkConf, SparkContext + + +def delayed(seconds): + def f(x): + time.sleep(seconds) + return x + return f + + +def call_in_background(f, *args): + result = Queue.Queue(1) + t = threading.Thread(target=lambda: result.put(f(*args))) + t.daemon = True + t.start() + return result + + +def main(): + conf = SparkConf().set("spark.ui.showConsoleProgress", "false") + sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf) + + def run(): + rdd = sc.parallelize(range(10), 10).map(delayed(2)) + reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + return reduced.map(delayed(2)).collect() + + result = call_in_background(run) + status = sc.statusTracker() + while result.empty(): + ids = status.getJobIdsForGroup() + for id in ids: + job = status.getJobInfo(id) + print("Job", id, "status: ", job.status) + for sid in job.stageIds: + info = status.getStageInfo(sid) + if info: + print("Stage %d: %d tasks total (%d active, %d complete)" % + (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)) + time.sleep(1) + + print("Job results are:", result.get()) + sc.stop() + +if __name__ == "__main__": + main() diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 000000000000..ea20678b9aca --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 000000000000..d75bc6daac13 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars \ + external/flume-assembly/target/scala-*/spark-streaming-flume-assembly-*.jar \ + examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f7ffb5379681..f815dd26823d 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -25,6 +25,7 @@ Then create a text file in `localdir` and the words in the file will get counted. """ +from __future__ import print_function import sys @@ -33,7 +34,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: hdfs_wordcount.py " + print("Usage: hdfs_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index ed398a82b8bb..8d697f620f46 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -17,16 +17,18 @@ """ Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - Usage: network_wordcount.py + Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ +from __future__ import print_function import sys @@ -36,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: kafka_wordcount.py " + print("Usage: kafka_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 000000000000..abf9c0e21d30 --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py + + To run this in your local machine, you need to setup a MQTT broker and publisher first, + Mosquitto is one of the open source MQTT Brokers, see + http://mosquitto.org/ + Eclipse paho project provides number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars \ + external/mqtt-assembly/target/scala-*/spark-streaming-mqtt-assembly-*.jar \ + examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: mqtt_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index cfa9c1ff5bfb..2b48bcfd55db 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -25,6 +25,7 @@ and then run the example `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` """ +from __future__ import print_function import sys @@ -33,7 +34,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: network_wordcount.py " + print("Usage: network_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py new file mode 100644 index 000000000000..b3808907f74a --- /dev/null +++ b/examples/src/main/python/streaming/queue_stream.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Create a queue of RDDs that will be mapped/reduced one at a time in + 1 second intervals. + + To run this example use + `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py +""" +import sys +import time + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonStreamingQueueStream") + ssc = StreamingContext(sc, 1) + + # Create the queue through which RDDs can be pushed to + # a QueueInputDStream + rddQueue = [] + for i in range(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)] + + # Create the QueueInputDStream and use it do some processing + inputStream = ssc.queueStream(rddQueue) + mappedStream = inputStream.map(lambda x: (x % 10, 1)) + reducedStream = mappedStream.reduceByKey(lambda a, b: a + b) + reducedStream.pprint() + + ssc.start() + time.sleep(6) + ssc.stop(stopSparkContext=True, stopGraceFully=True) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index fc6827c82bf9..ac91f0a06b17 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -35,6 +35,7 @@ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from the checkpoint data. """ +from __future__ import print_function import os import sys @@ -46,7 +47,7 @@ def createContext(host, port, outputPath): # If you do not see this printed, that means the StreamingContext has been loaded # from the new checkpoint - print "Creating new context" + print("Creating new context") if os.path.exists(outputPath): os.remove(outputPath) sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount") @@ -60,8 +61,8 @@ def createContext(host, port, outputPath): def echo(time, rdd): counts = "Counts at time %s %s" % (time, rdd.collect()) - print counts - print "Appending to " + os.path.abspath(outputPath) + print(counts) + print("Appending to " + os.path.abspath(outputPath)) with open(outputPath, 'a') as f: f.write(counts + "\n") @@ -70,8 +71,8 @@ def echo(time, rdd): if __name__ == "__main__": if len(sys.argv) != 5: - print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\ - " " + print("Usage: recoverable_network_wordcount.py " + " ", file=sys.stderr) exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py new file mode 100644 index 000000000000..da90c07dbd82 --- /dev/null +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: sql_network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999` +""" +from __future__ import print_function + +import os +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.sql import SQLContext, Row + + +def getSqlContextInstance(sparkContext): + if ('sqlContextSingletonInstance' not in globals()): + globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) + return globals()['sqlContextSingletonInstance'] + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: sql_network_wordcount.py ", file=sys.stderr) + exit(-1) + host, port = sys.argv[1:] + sc = SparkContext(appName="PythonSqlNetworkWordCount") + ssc = StreamingContext(sc, 1) + + # Create a socket stream on target ip:port and count the + # words in input stream of \n delimited text (eg. generated by 'nc') + lines = ssc.socketTextStream(host, int(port)) + words = lines.flatMap(lambda line: line.split(" ")) + + # Convert RDDs of the words DStream to DataFrame and run SQL query + def process(time, rdd): + print("========= %s =========" % str(time)) + + try: + # Get the singleton instance of SQLContext + sqlContext = getSqlContextInstance(rdd.context) + + # Convert RDD[String] to RDD[Row] to DataFrame + rowRdd = rdd.map(lambda w: Row(word=w)) + wordsDataFrame = sqlContext.createDataFrame(rowRdd) + + # Register as table + wordsDataFrame.registerTempTable("words") + + # Do word count on table using SQL and print it + wordCountsDataFrame = \ + sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() + except: + pass + + words.foreachRDD(process) + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 18a9a5a452ff..16ef646b7c42 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -29,6 +29,7 @@ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ localhost 9999` """ +from __future__ import print_function import sys @@ -37,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: stateful_network_wordcount.py " + print("Usage: stateful_network_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 00a281bfb650..7bf5fb6ddfe2 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from random import Random @@ -49,20 +51,20 @@ def generateGraph(): # the graph to obtain the path (x, z). # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) + edges = tc.map(lambda x_y: (x_y[1], x_y[0])) - oldCount = 0L + oldCount = 0 nextCount = tc.count() while True: oldCount = nextCount # Perform the join, obtaining an RDD of (y, (z, x)) pairs, # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + new_edges = tc.join(edges).map(lambda __a_b: (__a_b[1][1], __a_b[1][0])) tc = tc.union(new_edges).distinct().cache() nextCount = tc.count() if nextCount == oldCount: break - print "TC has %i edges" % tc.count() + print("TC has %i edges" % tc.count()) sc.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index ae6cd13b83d9..7c0143607b61 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from operator import add @@ -23,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: wordcount " + print("Usage: wordcount ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonWordCount") lines = sc.textFile(sys.argv[1], 1) @@ -32,6 +34,6 @@ .reduceByKey(add) output = counts.collect() for (word, count) in output: - print "%s: %i" % (word, count) + print("%s: %i" % (word, count)) sc.stop() diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 000000000000..aa2336e300a9 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. + +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R new file mode 100644 index 000000000000..53b817144f6a --- /dev/null +++ b/examples/src/main/r/dataframe.R @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Initialize SparkContext and SQLContext +sc <- sparkR.init(appName="SparkR-DataFrame-example") +sqlContext <- sparkRSQL.init(sc) + +# Create a simple local data.frame +localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18)) + +# Convert local data frame to a SparkR DataFrame +df <- createDataFrame(sqlContext, localDF) + +# Print its schema +printSchema(df) +# root +# |-- name: string (nullable = true) +# |-- age: double (nullable = true) + +# Create a DataFrame from a JSON file +path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") +peopleDF <- jsonFile(sqlContext, path) +printSchema(peopleDF) + +# Register this DataFrame as a table. +registerTempTable(peopleDF, "people") + +# SQL statements can be run by using the sql methods provided by sqlContext +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") + +# Call collect to get a local data.frame +teenagersLocalDF <- collect(teenagers) + +# Print the teenagers in our dataset +print(teenagersLocalDF) + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 1b53f3edbe92..d812262fd87d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -29,7 +30,7 @@ object BroadcastTest { val blockSize = if (args.length > 3) args(3) else "4096" val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroaddcastFactory") + .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 11d5c92c5952..fa07c1e5017c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,13 +15,11 @@ * limitations under the License. */ + // scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ListBuffer -import scala.collection.immutable.Map +import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat @@ -31,7 +29,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* @@ -104,8 +101,8 @@ object CassandraCQLTest { val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[CqlPagingInputFormat], - classOf[java.util.Map[String,ByteBuffer]], - classOf[java.util.Map[String,ByteBuffer]]) + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.Map[String, ByteBuffer]]) println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { @@ -118,14 +115,11 @@ object CassandraCQLTest { case (productId, saleCount) => println(productId + ":" + saleCount) } - val casoutputCF = aggregatedRDD.map { + val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { - val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) - val outKey: java.util.Map[String, ByteBuffer] = outColFamKey - var outColFamVal = new ListBuffer[ByteBuffer] - outColFamVal += ByteBufferUtil.bytes(saleCount) - val outVal: java.util.List[ByteBuffer] = outColFamVal - (outKey, outVal) + val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) + val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) + (outKey, outVal) } } @@ -140,3 +134,4 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb..2e56d24c60c3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,13 +15,13 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer +import java.util.Arrays import java.util.SortedMap -import scala.collection.JavaConversions._ - import org.apache.cassandra.db.IColumn import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper @@ -31,7 +31,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra @@ -117,7 +116,7 @@ object CassandraTest { val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + val mutations = Arrays.asList(new Mutation(), new Mutation()) mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) @@ -130,6 +129,7 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala new file mode 100644 index 000000000000..d651fe4d6ee7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import scala.io.Source._ + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.SparkContext._ + +/** + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. Compares the word count results + */ +object DFSReadWriteTest { + + private var localFilePath: File = new File(".") + private var dfsDirPath: String = "" + + private val NPARAMS = 2 + + private def readFile(filename: String): List[String] = { + val lineIter: Iterator[String] = fromFile(filename).getLines() + val lineList: List[String] = lineIter.toList + lineList + } + + private def printUsage(): Unit = { + val usage: String = "DFS Read-Write Test\n" + + "\n" + + "Usage: localFile dfsDir\n" + + "\n" + + "localFile - (string) local file to use in test\n" + + "dfsDir - (string) DFS directory for read/write tests\n" + + println(usage) + } + + private def parseArgs(args: Array[String]): Unit = { + if (args.length != NPARAMS) { + printUsage() + System.exit(1) + } + + var i = 0 + + localFilePath = new File(args(i)) + if (!localFilePath.exists) { + System.err.println("Given path (" + args(i) + ") does not exist.\n") + printUsage() + System.exit(1) + } + + if (!localFilePath.isFile) { + System.err.println("Given path (" + args(i) + ") is not a file.\n") + printUsage() + System.exit(1) + } + + i += 1 + dfsDirPath = args(i) + } + + def runLocalWordCount(fileContents: List[String]): Int = { + fileContents.flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .groupBy(w => w) + .mapValues(_.size) + .values + .sum + } + + def main(args: Array[String]): Unit = { + parseArgs(args) + + println("Performing local word count") + val fileContents = readFile(localFilePath.toString()) + val localWordCount = runLocalWordCount(fileContents) + + println("Creating SparkConf") + val conf = new SparkConf().setAppName("DFS Read Write Test") + + println("Creating SparkContext") + val sc = new SparkContext(conf) + + println("Writing local file to DFS") + val dfsFilename = dfsDirPath + "/dfs_read_write_test" + val fileRDD = sc.parallelize(fileContents) + fileRDD.saveAsTextFile(dfsFilename) + + println("Reading file from DFS and running Word Count") + val readFileRDD = sc.textFile(dfsFilename) + + val dfsWordCount = readFileRDD + .flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .map(w => (w, 1)) + .countByKey() + .values + .sum + + sc.stop() + + if (localWordCount == dfsWordCount) { + println(s"Success! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) agree.") + } else { + println(s"Failure! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) disagree.") + } + + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc..bec61f3cd429 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,9 +15,10 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils @@ -35,10 +36,10 @@ object DriverSubmissionTest { val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") - env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) + env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println) println("System properties containing spark.test:") - properties.filter{case (k, v) => k.toString.contains("spark.test")}.foreach(println) + properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println) for (i <- 1 until numSecondsToSleep) { println(s"Alive for $i out of $numSecondsToSleep seconds") @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b2..fa4a3afeecd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 822673347bdc..244742327a90 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,10 +15,11 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin -import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, TableName} import org.apache.hadoop.hbase.mapreduce.TableInputFormat import org.apache.spark._ @@ -28,7 +29,19 @@ object HBaseTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HBaseTest") val sc = new SparkContext(sparkConf) + + // please ensure HBASE_CONF_DIR is on classpath of spark driver + // e.g: set it through spark.driver.extraClassPath property + // in spark-defaults.conf or through --driver-class-path + // command line option of spark-submit + val conf = HBaseConfiguration.create() + + if (args.length < 1) { + System.err.println("Usage: HBaseTest ") + System.exit(1) + } + // Other options for configuring scan behavior are available. More information available at // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html conf.set(TableInputFormat.INPUT_TABLE, args(0)) @@ -36,7 +49,7 @@ object HBaseTest { // Initialize hBase table if necessary val admin = new HBaseAdmin(conf) if (!admin.isTableAvailable(args(0))) { - val tableDesc = new HTableDescriptor(args(0)) + val tableDesc = new HTableDescriptor(TableName.valueOf(args(0))) admin.createTable(tableDesc) } @@ -47,5 +60,7 @@ object HBaseTest { hBaseRDD.count() sc.stop() + admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f..124dc9af6390 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003..af5f216f28ba 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e..9c8aae53cf48 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 17624c20cff3..e7b28d38bdfc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -40,8 +41,8 @@ object LocalKMeans { val convergeDist = 0.001 val rand = new Random(42) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DenseVector[Double]] = { + def generatePoint(i: Int): DenseVector[Double] = { DenseVector.fill(D){rand.nextDouble * R} } Array.tabulate(N)(generatePoint) @@ -99,7 +100,7 @@ object LocalKMeans { var pointStats = mappings.map { pair => pair._2.reduceLeft [(Int, (Vector[Double], Int))] { - case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1 + y2)) + case ((id1, (p1, c1)), (id2, (p2, c2))) => (id1, (p1 + p2, c1 + c2)) } } @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 92a683ad57ea..4f6b092a59ca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -37,9 +38,9 @@ object LocalLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { - val y = if(i % 2 == 0) -1 else 1 + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb..3d923625f11b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 74620ad007d8..a80de10f4610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -22,7 +23,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. - * + * * Usage: LogQuery [logFile] */ object LogQuery { @@ -54,8 +55,8 @@ object LogQuery { // scalastyle:on /** Tracks the total query count and number of aggregate bytes for a particular group. */ class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) + def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes) + override def toString: String = "bytes=%s\tn=%s".format(numBytes, count) } def extractKey(line: String): (String, String, String) = { @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe1..61ce9db914f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459..3b0b00fe4dd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce1..719e2176fed3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 6c0ac8013ce3..69799b7c2bb3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -117,7 +118,7 @@ object SparkALS { var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users - val Rc = sc.broadcast(R) + val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b..505ea5a4c7a8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index 48e8d11cdf95..c56e1124ad41 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -79,7 +80,7 @@ object SparkKMeans { while(tempDist > convergeDist) { val closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - val pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} + val pointStats = closest.reduceByKey{case ((p1, c1), (p2, c2)) => (p1 + p2, c1 + c2)} val newPoints = pointStats.map {pair => (pair._1, pair._2._1 * (1.0 / pair._2._2))}.collectAsMap() @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 257a7d29f922..d265c227f4ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -42,9 +43,9 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { - val y = if(i % 2 == 0) -1 else 1 + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 8d092b6506d3..0fd79660dd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -51,7 +52,7 @@ object SparkPageRank { showWarning() val sparkConf = new SparkConf().setAppName("PageRank") - val iters = if (args.length > 0) args(1).toInt else 10 + val iters = if (args.length > 1) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b6..818d4f2b81f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f7f83086df3d..95072071ccdd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -31,7 +32,7 @@ object SparkTC { val numVertices = 100 val rand = new Random(42) - def generateGraph = { + def generateGraph: Seq[(Int, Int)] = { val edges: mutable.Set[(Int, Int)] = mutable.Set.empty while (edges.size < numEdges) { val from = rand.nextInt(numVertices) @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b..cfbdae02212a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b10..e46ac655beb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index e322d4ce5a74..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.bagel._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) { - self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size)) - } else { - Array[PRMessage]() - } - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double) - (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int) - : (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)" - .format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } - - override def hashCode: Int = numPartitions -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 859abedf2a55..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println( - "Usage: WikipediaPageRank ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.setAppName("WikipediaPageRank") - sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRank") - val sc = new SparkContext(sparkConf) - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() - } else { - vertices = vertices.cache() - } - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect().mkString) - println(top) - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index 576a3e371b99..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML, NodeSeq} - -import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import scala.reflect.ClassTag - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRankStandalone " + - " ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRankStandalone") - - val sc = new SparkContext(sparkConf) - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) { - input.map(parseArticle _).partitionBy(partitioner).cache() - } else { - input.map(parseArticle _).cache() - } - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, - sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - sc.stop() - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapperIterable, rankWrapperIterable)) => - val linksWrapper = linksWrapperIterable.iterator - val rankWrapper = rankWrapperIterable.iterator - if (linksWrapper.hasNext) { - val linksWrapperHead = linksWrapper.next - if (rankWrapper.hasNext) { - val rankWrapperHead = rankWrapper.next - linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) - } else { - linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends org.apache.spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T: ClassTag](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T: ClassTag](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8f..8dd6c9706e7d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index e809a65b7997..da3ffca1a6f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,13 +15,9 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx -import org.apache.spark.SparkContext._ -import org.apache.spark._ -import org.apache.spark.graphx._ - - /** * Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from * http://snap.stanford.edu/data/soc-LiveJournal1.html. @@ -31,13 +27,13 @@ object LiveJournalPageRank { if (args.length < 1) { System.err.println( "Usage: LiveJournalPageRank \n" + + " --numEPart=\n" + + " The number of partitions for the graph's edge RDD.\n" + " [--tol=]\n" + " The tolerance allowed at convergence (smaller => more accurate). Default is " + "0.001.\n" + " [--output=]\n" + " If specified, the file to write the ranks to.\n" + - " [--numEPart=]\n" + - " The number of partitions for the graph's edge RDD. Default is 4.\n" + " [--partStrategy=RandomVertexCut | EdgePartition1D | EdgePartition2D | " + "CanonicalRandomVertexCut]\n" + " The way edges are assigned to edge partitions. Default is RandomVertexCut.") @@ -47,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b78..46e52aacd90b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 283bb80f1c78..14b358d46f6a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -23,6 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} /** @@ -43,10 +45,10 @@ object CrossValidatorExample { val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -89,23 +91,24 @@ object CrossValidatorExample { crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. - val cvModel = crossval.fit(training) + val cvModel = crossval.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test) - .select("id", "text", "score", "prediction") + cvModel.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala new file mode 100644 index 000000000000..f28671f7869f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{SQLContext, DataFrame} + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.DecisionTreeExample [options] + * }}} + * Note that Decision Trees can take a large amount of memory. If the run-example command above + * fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DecisionTreeExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DecisionTreeExample") { + head("DecisionTreeExample: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + /** Load a dataset from the given path, using the given format */ + private[ml] def loadData( + sc: SparkContext, + path: String, + format: String, + expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + format match { + case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "libsvm" => expectedNumFeatures match { + case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) + case None => MLUtils.loadLibSVMFile(sc, path) + } + case _ => throw new IllegalArgumentException(s"Bad data format: $format") + } + } + + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset) + */ + private[ml] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: String, + fracTest: Double): (DataFrame, DataFrame) = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Load training data + val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + + // Load or create test set + val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + // Load testInput. + val numFeatures = origExamples.take(1)(0).features.size + val origTestExamples: RDD[LabeledPoint] = + loadData(sc, testInput, dataFormat, Some(numFeatures)) + Array(origExamples, origTestExamples) + } else { + // Split input into training, test. + origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) + } + + // For classification, convert labels to Strings since we will index them later with + // StringIndexer. + def labelsToStrings(data: DataFrame): DataFrame = { + algo.toLowerCase match { + case "classification" => + data.withColumn("labelString", data("label").cast(StringType)) + case "regression" => + data + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + } + val dataframes = splits.map(_.toDF()).map(labelsToStrings) + val training = dataframes(0).cache() + val test = dataframes(1).cache() + + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + (training, test) + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"DecisionTreeExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = + loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn Decision Tree + val dt = algo match { + case "classification" => + new DecisionTreeClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case "regression" => + new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Decision Tree from the fitted PipelineModel + algo match { + case "classification" => + val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel] + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. + } + case "regression" => + val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel] + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } + + /** + * Evaluate the given ClassificationModel on data. Print the results. + * @param model Must fit ClassificationModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995 + */ + private[ml] def evaluateClassificationModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + // Print number of classes for reference + val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "Unknown failure when indexing labels for classification.") + } + val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision + println(s" Accuracy ($numClasses classes): $accuracy") + } + + /** + * Evaluate the given RegressionModel on data. Print the results. + * @param model Must fit RegressionModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to RegressionModel once that API is public. SPARK-5995 + */ + private[ml] def evaluateRegressionModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError + println(s" Root mean squared error (RMSE): $RMSE") + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala new file mode 100644 index 000000000000..340c3559b15e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams} +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics [[org.apache.spark.ml.classification.LogisticRegression]]. + * Run with + * {{{ + * bin/run-example ml.DeveloperApiExample + * }}} + */ +object DeveloperApiExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("DeveloperApiExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training data. + val training = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new MyLogisticRegression() + // Print out the parameters, documentation, and any default values. + println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model = lr.fit(training.toDF()) + + // Prepare test data. + val test = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) + + // Make predictions on test data. + val sumPredictions: Double = model.transform(test.toDF()) + .select("features", "label", "prediction") + .collect() + .map { case Row(features: Vector, label: Double, prediction: Double) => + prediction + }.sum + assert(sumPredictions == 0.0, + "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + + sc.stop() + } +} + +/** + * Example of defining a parameter trait for a user-defined type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private trait MyLogisticRegressionParams extends ClassifierParams { + + /** + * Param for max number of iterations + * + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression + * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator + * class since the maxIter parameter is only used during training (not in the Model). + */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = $(maxIter) +} + +/** + * Example of defining a type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegression(override val uid: String) + extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + def this() = this(Identifiable.randomUID("myLogReg")) + + setMaxIter(100) // Initialize + + // The parameter setter is in this class since it should return type MyLogisticRegression. + def setMaxIter(value: Int): this.type = set(maxIter, value) + + // This method is used by fit() + override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { + // Extract columns from data using helper method. + val oldDataset = extractLabeledPoints(dataset) + + // Do learning to estimate the weight vector. + val numFeatures = oldDataset.take(1)(0).features.size + val weights = Vectors.zeros(numFeatures) // Learning would happen here. + + // Create a model, and return it. + new MyLogisticRegressionModel(uid, weights).setParent(this) + } + + override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) +} + +/** + * Example of defining a type of [[ClassificationModel]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegressionModel( + override val uid: String, + val weights: Vector) + extends ClassificationModel[Vector, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + override protected def predictRaw(features: Vector): Vector = { + val margin = BLAS.dot(features, weights) + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + Vectors.dense(-margin, margin) + } + + /** Number of classes the label can take. 2 indicates binary classification. */ + override val numClasses: Int = 2 + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + * + * This is used for the default implementation of [[transform()]]. + */ + override def copy(extra: ParamMap): MyLogisticRegressionModel = { + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala new file mode 100644 index 000000000000..f4a15f806ea8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.GBTExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object GBTExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + maxIter: Int = 10, + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GBTExample") { + head("GBTExample: an example Gradient-Boosted Trees app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("maxIter") + .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"GBTExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"GBTExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn GBT + val dt = algo match { + case "classification" => + new GBTClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case "regression" => + new GBTRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained GBT from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel] + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case "regression" => + val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel] + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala new file mode 100644 index 000000000000..5ce38462d118 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} + + +/** + * An example demonstrating a k-means clustering. + * Run with + * {{{ + * bin/run-example ml.KMeansExample + * }}} + */ +object KMeansExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println("Usage: ml.KMeansExample ") + // scalastyle:on println + System.exit(1) + } + val input = args(0) + val k = args(1).toInt + + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a k-means model + val kmeans = new KMeans() + .setK(k) + .setFeaturesCol(FEATURES_COL) + val model = kmeans.fit(dataset) + + // Shows the result + // scalastyle:off println + println("Final Centers: ") + model.clusterCenters.foreach(println) + // scalastyle:on println + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 000000000000..b73299fb12d3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 000000000000..7682557127b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index b7885829459a..3ae53e57dbdb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -75,7 +76,7 @@ object MovieLensALS { .text("path to a MovieLens dataset of movies") .action((x, c) => c.copy(movies = x)) opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("maxIter") .text(s"max number of iterations, default: ${defaultParams.maxIter}") @@ -93,8 +94,8 @@ object MovieLensALS { | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ | examples/target/scala-*/spark-examples-*.jar \ | --rank 10 --maxIter 15 --regParam 0.1 \ - | --movies path/to/movielens/movies.dat \ - | --ratings path/to/movielens/ratings.dat + | --movies data/mllib/als/sample_movielens_movies.txt \ + | --ratings data/mllib/als/sample_movielens_ratings.txt """.stripMargin) } @@ -109,7 +110,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() @@ -137,9 +138,9 @@ object MovieLensALS { .setRegParam(params.regParam) .setNumBlocks(params.numBlocks) - val model = als.fit(training) + val model = als.fit(training.toDF()) - val predictions = model.transform(test).cache() + val predictions = model.transform(test.toDF()).cache() // Evaluate the model. // TODO: Create an evaluator to compute RMSE. @@ -157,18 +158,25 @@ object MovieLensALS { println(s"Test RMSE = $rmse.") // Inspect false positives. - predictions.registerTempTable("prediction") - sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") - sqlContext.sql( - """ - |SELECT userId, prediction.movieId, title, rating, prediction - | FROM prediction JOIN movie ON prediction.movieId = movie.movieId - | WHERE rating <= 1 AND prediction >= 4 - | LIMIT 100 - """.stripMargin) - .collect() - .foreach(println) + // Note: We reference columns in 2 ways: + // (1) predictions("movieId") lets us specify the movieId column in the predictions + // DataFrame, rather than the movieId column in the movies DataFrame. + // (2) $"userId" specifies the userId column in the predictions DataFrame. + // We could also write predictions("userId") but do not have to since + // the movies DataFrame does not have a column "userId." + val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF() + val falsePositives = predictions.join(movies) + .where((predictions("movieId") === movies("movieId")) + && ($"rating" <= 1) && ($"prediction" >= 4)) + .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction") + val numFalsePositives = falsePositives.count() + println(s"Found $numFalsePositives false positives") + if (numFalsePositives > 0) { + println(s"Example false positives:") + falsePositives.limit(100).collect().foreach(println) + } sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala new file mode 100644 index 000000000000..bab31f585b0e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} + +import scopt.OptionParser + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + * {{{ + * ./bin/run-example ml.OneVsRestExample [options] + * }}} + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object OneVsRestExample { + + case class Params private[ml] ( + input: String = null, + testInput: Option[String] = None, + maxIter: Int = 100, + tol: Double = 1E-6, + fitIntercept: Boolean = true, + regParam: Option[Double] = None, + elasticNetParam: Option[Double] = None, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("OneVsRest Example") { + head("OneVsRest Example: multiclass to binary reduction using OneVsRest") + opt[String]("input") + .text("input path to labeled examples. This path must be specified") + .required() + .action((x, c) => c.copy(input = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text("input path to test dataset. If given, option fracTest is ignored") + .action((x, c) => c.copy(testInput = Some(x))) + opt[Int]("maxIter") + .text(s"maximum number of iterations for Logistic Regression." + + s" default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations for Logistic Regression." + + s" default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Boolean]("fitIntercept") + .text(s"fit intercept for Logistic Regression." + + s" default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("regParam") + .text(s"the regularization parameter for Logistic Regression.") + .action((x, c) => c.copy(regParam = Some(x))) + opt[Double]("elasticNetParam") + .text(s"the ElasticNet mixing parameter for Logistic Regression.") + .action((x, c) => c.copy(elasticNetParam = Some(x))) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") + val sc = new SparkContext(conf) + val inputData = MLUtils.loadLibSVMFile(sc, params.input) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // compute the train/test split: if testInput is not provided use part of input. + val data = params.testInput match { + case Some(t) => { + // compute the number of features in the training set. + val numFeatures = inputData.first().features.size + val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) + Array[RDD[LabeledPoint]](inputData, testData) + } + case None => { + val f = params.fracTest + inputData.randomSplit(Array(1 - f, f), seed = 12345) + } + } + val Array(train, test) = data.map(_.toDF().cache()) + + // instantiate the base classifier + val classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept) + + // Set regParam, elasticNetParam if specified in params + params.regParam.foreach(classifier.setRegParam) + params.elasticNetParam.foreach(classifier.setElasticNetParam) + + // instantiate the One Vs Rest Classifier. + + val ovr = new OneVsRest() + ovr.setClassifier(classifier) + + // train the multiclass model. + val (trainingDuration, ovrModel) = time(ovr.fit(train)) + + // score the model on test data. + val (predictionDuration, predictions) = time(ovrModel.transform(test)) + + // evaluate the model + val predictionsAndLabels = predictions.select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val metrics = new MulticlassMetrics(predictionsAndLabels) + + val confusionMatrix = metrics.confusionMatrix + + // compute the false positive rate per label + val predictionColSchema = predictions.schema("prediction") + val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get + val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + + println(s" Training Time ${trainingDuration} sec\n") + + println(s" Prediction Time ${predictionDuration} sec\n") + + println(s" Confusion Matrix\n ${confusionMatrix.toString}\n") + + println("label\tfpr") + + println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + + sc.stop() + } + + private def time[R](block: => R): (Long, R) = { + val t0 = System.nanoTime() + val result = block // call-by-name + val t1 = System.nanoTime() + (NANO.toSeconds(t1 - t0), result) + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala new file mode 100644 index 000000000000..109178f4137b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.RandomForestExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object RandomForestExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 10, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("RandomForestExample") { + head("RandomForestExample: an example random forest app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"number of features to use per node (supported:" + + s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"RandomForestExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"RandomForestExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn Random Forest + val dt = algo match { + case "classification" => + new RandomForestClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case "regression" => + new RandomForestRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Random Forest from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel] + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case "regression" => + val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel] + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 95cc9801eaeb..f4d1fe57856a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -37,12 +38,12 @@ object SimpleParamsExample { val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training data. - // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into DataFrames, where it uses the bean metadata to infer the schema. - val training = sparkContext.parallelize(Seq( + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes + // into DataFrames, where it uses the case class metadata to infer the schema. + val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -58,45 +59,46 @@ object SimpleParamsExample { .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model1 = lr.fit(training) + val model1 = lr.fit(training.toDF()) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.fittingParamMap) + println("Model 1 was fit using parameters: " + model1.parent.extractParamMap()) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params. // One can also combine ParamMaps. - val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. - val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.fittingParamMap) + val model2 = lr.fit(training.toDF(), paramMapCombined) + println("Model 2 was fit using parameters: " + model2.parent.extractParamMap()) - // Prepare test documents. - val test = sparkContext.parallelize(Seq( + // Prepare test data. + val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) - // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test) - .select("features", "label", "probability", "prediction") + // Make predictions on test data using the Transformer.transform() method. + // LogisticRegressionModel.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test.toDF()) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") } sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 065db62b0f5e..960280137cbf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -23,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} @BeanInfo @@ -44,10 +46,10 @@ object SimpleTextClassificationPipeline { val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -63,28 +65,29 @@ object SimpleTextClassificationPipeline { .setOutputCol("features") val lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01) + .setRegParam(0.001) val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. - val model = pipeline.fit(training) + val model = pipeline.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), - Document(6L, "mapreduce spark"), + Document(6L, "spark hadoop spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. - model.transform(test) - .select("id", "text", "score", "prediction") + model.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b9..1a4016f76c2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e784..026d4ecc6d10 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4..69988cc1b933 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index ab58375649d2..dc13f82488af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -71,7 +72,7 @@ object DatasetExample { val conf = new SparkConf().setAppName(s"DatasetExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ // for implicit conversions + import sqlContext.implicits._ // for implicit conversions // Load input data val origData: RDD[LabeledPoint] = params.dataFormat match { @@ -81,18 +82,18 @@ object DatasetExample { println(s"Loaded ${origData.count()} instances from file: ${params.input}") // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDataFrame + val df: DataFrame = origData.toDF() println(s"Inferred schema:\n${df.schema.prettyJson}") println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to DataFrames. - val labelsDf: DataFrame = origData.select("label") + // Select columns + val labelsDf: DataFrame = df.select("label") val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresDf: DataFrame = origData.select("features") + val featuresDf: DataFrame = df.select("features") val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), @@ -103,10 +104,10 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - df.saveAsParquetFile(outputDir) + df.write.parquet(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.parquetFile(outputDir) + val newDataset = sqlContext.read.parquet(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 205d80dd0268..cc6bce3cb7c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -22,7 +23,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -100,7 +100,7 @@ object DecisionTreeRunner { .action((x, c) => c.copy(numTrees = x)) opt[String]("featureSubsetStrategy") .text(s"feature subset sampling strategy" + - s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " + s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") @@ -126,7 +126,7 @@ object DecisionTreeRunner { .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) - opt[String]("") + opt[String]("dataFormat") .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") .action((x, c) => c.copy(dataFormat = x)) arg[String]("") @@ -272,6 +272,8 @@ object DecisionTreeRunner { case Variance => impurity.Variance } + params.checkpointDir.foreach(sc.setCheckpointDir) + val strategy = new Strategy( algo = params.algo, @@ -282,7 +284,6 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain, useNodeIdCache = params.useNodeIdCache, - checkpointDir = params.checkpointDir, checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { val startTime = System.nanoTime() @@ -353,7 +354,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. */ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -362,4 +367,6 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index df76b45e5081..1fce4ba7efd6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -40,23 +41,23 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - + val ctx = new SparkContext(conf) + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 11e35598baf5..380d85d60e7b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -56,7 +57,7 @@ object DenseKMeans { .text(s"number of clusters, required") .action((x, c) => c.copy(k = x)) opt[Int]("numIterations") - .text(s"number of iterations, default; ${defaultParams.numIterations}") + .text(s"number of iterations, default: ${defaultParams.numIterations}") .action((x, c) => c.copy(numIterations = x)) opt[String]("initMode") .text(s"initialization mode (${InitializationMode.values.mkString(",")}), " + @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala new file mode 100644 index 000000000000..14b930550d55 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.fpm.FPGrowth +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.FPGrowthExample \ + * --minSupport 0.8 --numPartition 2 ./data/mllib/sample_fpgrowth.txt + */ +object FPGrowthExample { + + case class Params( + input: String = null, + minSupport: Double = 0.3, + numPartition: Int = -1) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("FPGrowthExample") { + head("FPGrowth: an example FP-growth app.") + opt[Double]("minSupport") + .text(s"minimal support level, default: ${defaultParams.minSupport}") + .action((x, c) => c.copy(minSupport = x)) + opt[Int]("numPartition") + .text(s"number of partition, default: ${defaultParams.numPartition}") + .action((x, c) => c.copy(numPartition = x)) + arg[String]("") + .text("input paths to input data set, whose file format is that each line " + + "contains a transaction with each item in String and separated by a space") + .required() + .action((x, c) => c.copy(input = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"FPGrowthExample with $params") + val sc = new SparkContext(conf) + val transactions = sc.textFile(params.input).map(_.split(" ")).cache() + + println(s"Number of transactions: ${transactions.count()}") + + val model = new FPGrowth() + .setMinSupport(params.minSupport) + .setNumPartitions(params.numPartition) + .run(transactions) + + println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") + + model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + } + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 431ead8c0c16..e16a6bf03357 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -25,6 +26,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils + /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ @@ -68,7 +70,7 @@ object GradientBoostedTreesRunner { .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) - opt[String]("") + opt[String]("dataFormat") .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") .action((x, c) => c.copy(dataFormat = x)) arg[String]("") @@ -144,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index f4c545ad70e9..75b0f69cf91a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -26,7 +27,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -48,6 +49,7 @@ object LDAExample { topicConcentration: Double = -1, vocabSize: Int = 10000, stopwordFile: String = "", + algorithm: String = "em", checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] @@ -78,6 +80,10 @@ object LDAExample { .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + s" default: ${defaultParams.stopwordFile}") .action((x, c) => c.copy(stopwordFile = x)) + opt[String]("algorithm") + .text(s"inference algorithm to use. em and online are supported." + + s" default: ${defaultParams.algorithm}") + .action((x, c) => c.copy(algorithm = x)) opt[String]("checkpointDir") .text(s"Directory for checkpointing intermediate results." + s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + @@ -128,13 +134,23 @@ object LDAExample { // Run LDA. val lda = new LDA() - lda.setK(params.k) + + val optimizer = params.algorithm.toLowerCase match { + case "em" => new EMLDAOptimizer + // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets. + case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize) + case _ => throw new IllegalArgumentException( + s"Only em, online are supported but got ${params.algorithm}.") + } + + lda.setOptimizer(optimizer) + .setK(params.k) .setMaxIterations(params.maxIterations) .setDocConcentration(params.docConcentration) .setTopicConcentration(params.topicConcentration) .setCheckpointInterval(params.checkpointInterval) if (params.checkpointDir.nonEmpty) { - lda.setCheckpointDir(params.checkpointDir.get) + sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() val ldaModel = lda.run(corpus) @@ -142,9 +158,13 @@ object LDAExample { println(s"Finished training LDA model. Summary:") println(s"\t Training time: $elapsed sec") - val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble - println(s"\t Training data average log likelihood: $avgLogLikelihood") - println() + + if (ldaModel.isInstanceOf[DistributedLDAModel]) { + val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel] + val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble + println(s"\t Training data average log likelihood: $avgLogLikelihood") + println() + } // Print the topics, showing the top-weighted terms for each topic. val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) @@ -159,7 +179,7 @@ object LDAExample { } println() } - + sc.stop() } /** @@ -173,7 +193,9 @@ object LDAExample { stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { // Get dataset of document texts - // One document per line in each text file. + // One document per line in each text file. If the input consists of many small files, + // this can result in a large number of small partitions, which can degrade performance. + // In this case, consider using coalesce() to create fewer, larger partitions. val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) // Split text into words @@ -281,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07..8878061a0970 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 91a0a860d6c7..69691ae297f6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -54,7 +55,7 @@ object MovieLensALS { val parser = new OptionParser[Params]("MovieLensALS") { head("MovieLensALS: an example app for ALS on MovieLens data.") opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("numIterations") .text(s"number of iterations, default: ${defaultParams.numIterations}") @@ -100,7 +101,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) @@ -175,9 +176,12 @@ object MovieLensALS { } /** Compute RMSE (Root Mean Squared Error). */ - def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = { + def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) + : Double = { - def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + def mapPredictedRating(r: Double): Double = { + if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + } val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) val predictionsAndRatings = predictions.map{ x => @@ -186,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284..5f839c75dd58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala new file mode 100644 index 000000000000..072322395461 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.log4j.{Level, Logger} +import scopt.OptionParser + +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.rdd.RDD +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. + * Takes an input of K concentric circles and the number of points in the innermost circle. + * The output should be K clusters - each cluster containing precisely the points associated + * with each of the input circles. + * + * Run with + * {{{ + * ./bin/run-example mllib.PowerIterationClusteringExample [options] + * + * Where options include: + * k: Number of circles/clusters + * n: Number of sampled points on innermost circle.. There are proportionally more points + * within the outer/larger circles + * maxIterations: Number of Power Iterations + * outerRadius: radius of the outermost of the concentric circles + * }}} + * + * Here is a sample run and output: + * + * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 + * + * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], + * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * + * + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object PowerIterationClusteringExample { + + case class Params( + input: String = null, + k: Int = 3, + numPoints: Int = 5, + maxIterations: Int = 10, + outerRadius: Double = 3.0 + ) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("PowerIterationClusteringExample") { + head("PowerIterationClusteringExample: an example PIC app using concentric circles.") + opt[Int]('k', "k") + .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]('n', "n") + .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") + .action((x, c) => c.copy(numPoints = x)) + opt[Int]("maxIterations") + .text(s"number of iterations, default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Double]('r', "r") + .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") + .action((x, c) => c.copy(outerRadius = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf() + .setMaster("local") + .setAppName(s"PowerIterationClustering with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + val model = new PowerIterationClustering() + .setK(params.k) + .setMaxIterations(params.maxIterations) + .run(circlesRdd) + + val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) + val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignmentsStr = assignments + .map { case (k, v) => + s"$k -> ${v.sorted.mkString("[", ",", "]")}" + }.mkString(",") + val sizesStr = assignments.map { + _._2.size + }.sorted.mkString("(", ",", ")") + println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") + + sc.stop() + } + + def generateCircle(radius: Double, n: Int): Seq[(Double, Double)] = { + Seq.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (radius * math.cos(theta), radius * math.sin(theta)) + } + } + + def generateCirclesRdd(sc: SparkContext, + nCircles: Int = 3, + nPoints: Int = 30, + outerRadius: Double): RDD[(Long, Long, Double)] = { + + val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} + val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} + val points = (0 until nCircles).flatMap { cx => + generateCircle(radii(cx), groupSizes(cx)) + }.zipWithIndex + val rdd = sc.parallelize(points) + val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => + if (i0 < i1) { + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + } else { + None + } + } + distancesRdd + } + + /** + * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel + */ + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = { + val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) + val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) + val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) + coeff * math.exp(expCoeff * ssquares) + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af9..bee85ba0f996 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af6..6963f43e082c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5..f81fc292a3bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed..af03724a8ac6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2..b4a5dca031ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d7..b42f4cb5f933 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e30..464fbd385ab5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615..65b4bc46f026 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index a11890d6f2b1..805184e740f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.pythonconverters import java.util.{Collection => JCollection, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.avro.generic.{GenericFixed, IndexedRecord} import org.apache.avro.mapred.AvroWrapper @@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable { return null } schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}") } } @@ -59,7 +58,7 @@ object AvroConversionUtil extends Serializable { val map = new java.util.HashMap[String, Any] obj match { case record: IndexedRecord => - record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + record.getSchema.getFields.asScala.zipWithIndex.foreach { case (f, i) => map.put(f.name, fromAvro(record.get(i), f.schema)) } case other => throw new SparkException( @@ -69,9 +68,9 @@ object AvroConversionUtil extends Serializable { } def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { - obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + obj.asInstanceOf[JMap[_, _]].asScala.map { case (key, value) => (key.toString, fromAvro(value, schema.getValueType)) - } + }.asJava } def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { @@ -92,17 +91,17 @@ object AvroConversionUtil extends Serializable { def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { case c: JCollection[_] => - c.map(fromAvro(_, schema.getElementType)) + c.asScala.map(fromAvro(_, schema.getElementType)).toSeq.asJava case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => - arr.toSeq + arr.toSeq.asJava.asInstanceOf[JCollection[Any]] case arr: Array[_] => - arr.map(fromAvro(_, schema.getElementType)).toSeq + arr.map(fromAvro(_, schema.getElementType)).toSeq.asJava case other => throw new SparkException( s"Unknown ARRAY type ${other.getClass.getName}") } def unpackUnion(obj: Any, schema: Schema): Any = { - schema.getTypes.toList match { + schema.getTypes.asScala.toList match { case List(s) => fromAvro(obj, s) case List(n, s) if n.getType == NULL => fromAvro(obj, s) case List(s, n) if n.getType == NULL => fromAvro(obj, s) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 83feb5703b90..00ce47af4813 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples.pythonconverters -import org.apache.spark.api.python.Converter import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions._ +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra @@ -30,7 +32,7 @@ import collection.JavaConversions._ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int]] { override def convert(obj: Any): java.util.Map[String, Int] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb))) + result.asScala.mapValues(ByteBufferUtil.toInt).asJava } } @@ -41,7 +43,7 @@ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int] class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, String]] { override def convert(obj: Any): java.util.Map[String, String] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) + result.asScala.mapValues(ByteBufferUtil.string).asJava } } @@ -52,7 +54,7 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { val input = obj.asInstanceOf[java.util.Map[String, Int]] - mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + input.asScala.mapValues(ByteBufferUtil.bytes).asJava } } @@ -63,6 +65,6 @@ class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, By class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { override def convert(obj: Any): java.util.List[ByteBuffer] = { val input = obj.asInstanceOf[java.util.List[String]] - seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + input.asScala.map(ByteBufferUtil.bytes).asJava } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 273bee0a8b30..0a25ee7ae56f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -17,21 +17,34 @@ package org.apache.spark.examples.pythonconverters -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.KeyValue.Type +import org.apache.hadoop.hbase.CellUtil /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts an - * HBase Result to a String + * Implementation of [[org.apache.spark.api.python.Converter]] that converts all + * the records in an HBase Result to a String */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) + val output = result.listCells.asScala.map(cell => + Map( + "row" -> Bytes.toStringBinary(CellUtil.cloneRow(cell)), + "columnFamily" -> Bytes.toStringBinary(CellUtil.cloneFamily(cell)), + "qualifier" -> Bytes.toStringBinary(CellUtil.cloneQualifier(cell)), + "timestamp" -> cell.getTimestamp.toString, + "type" -> Type.codeToType(cell.getTypeByte).toString, + "value" -> Bytes.toStringBinary(CellUtil.cloneValue(cell)) + ) + ) + output.map(JSONObject(_).toString()).mkString("\n") } } @@ -63,7 +76,7 @@ class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBy */ class StringListToPutConverter extends Converter[Any, Put] { override def convert(obj: Any): Put = { - val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val output = obj.asInstanceOf[java.util.ArrayList[String]].asScala.map(Bytes.toBytes).toArray val put = new Put(output(0)) put.add(output(1), output(2), output(3)) } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 82a0b637b3cf..2cc56f04e5c1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,11 +15,12 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -32,44 +33,45 @@ object RDDRelation { val sqlContext = new SQLContext(sc) // Importing the SQL context gives access to all the SQL functions and implicit conversions. - import sqlContext._ + import sqlContext.implicits._ - val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF() // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerTempTable("records") + df.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") - sql("SELECT * FROM records").collect().foreach(println) + sqlContext.sql("SELECT * FROM records").collect().foreach(println) // Aggregation queries are also supported. - val count = sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) + val count = sqlContext.sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10") + val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) + df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - rdd.saveAsParquetFile("pair.parquet") + df.write.parquet("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. - val parquetFile = sqlContext.parquetFile("pair.parquet") + val parquetFile = sqlContext.read.parquet("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. parquetFile.registerTempTable("parquetFile") - sql("SELECT * FROM parquetFile").collect().foreach(println) + sqlContext.sql("SELECT * FROM parquetFile").collect().foreach(println) sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 5725da184811..bf40bd1ef13d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -43,7 +44,8 @@ object HiveFromSpark { // HiveContext. When not configured by the hive-site.xml, the context automatically // creates metastore_db and warehouse in the current directory. val hiveContext = new HiveContext(sc) - import hiveContext._ + import hiveContext.implicits._ + import hiveContext.sql sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(s"LOAD DATA LOCAL INPATH '${kv1File.getAbsolutePath}' INTO TABLE src") @@ -67,7 +69,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerTempTable("records") + rdd.toDF().registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") @@ -76,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a..e9c990719876 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -85,13 +86,13 @@ extends Actor with ActorHelper { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - override def preStart = remotePublisher ! SubscribeReceiver(context.self) + override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - def receive = { + def receive: PartialFunction[Any, Unit] = { case msg => store(msg.asInstanceOf[T]) } - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) } @@ -104,10 +105,8 @@ extends Actor with ActorHelper { object FeederActor { def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) + if (args.length < 2){ + System.err.println("Usage: FeederActor \n") System.exit(1) } val Seq(host, port) = args.toSeq @@ -172,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae9..28e9bf520e56 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala new file mode 100644 index 000000000000..bd78526f8c29 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.streaming + +import kafka.serializer.StringDecoder + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.kafka._ +import org.apache.spark.SparkConf + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: DirectKafkaWordCount + * is a list of one or more Kafka brokers + * is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ + * topic1,topic2 + */ +object DirectKafkaWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println(s""" + |Usage: DirectKafkaWordCount + | is a list of one or more Kafka brokers + | is a list of one or more kafka topics to consume from + | + """.stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(brokers, topics) = args + + // Create context with 2 second batch interval + val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create direct kafka stream with brokers and topics + val topicsSet = topics.split(",").toSet + val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topicsSet) + + // Get the lines, split them into words, count the words and print + val lines = messages.map(_._2) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _) + wordCounts.print() + + // Start the computation + ssc.start() + ssc.awaitTermination() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1..91e52e4eff5a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b..2bdbc37e2a28 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e..1f282d437dc3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 75% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 387c0e421334..b40d17e9c2fa 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,11 +15,12 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming -import java.util.Properties +import java.util.HashMap -import kafka.producer._ +import org.apache.kafka.clients.producer.{ProducerConfig, KafkaProducer, ProducerRecord} import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka._ @@ -49,10 +50,10 @@ object KafkaWordCount { val Array(zkQuorum, group, topics, numThreads) = args val sparkConf = new SparkConf().setAppName("KafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) @@ -77,25 +78,28 @@ object KafkaWordCountProducer { val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args // Zookeeper connection properties - val props = new Properties() - props.put("metadata.broker.list", brokers) - props.put("serializer.class", "kafka.serializer.StringEncoder") + val props = new HashMap[String, Object]() + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers) + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringSerializer") + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringSerializer") - val config = new ProducerConfig(props) - val producer = new Producer[String, String](config) + val producer = new KafkaProducer[String, String](props) // Send some messages while(true) { - val messages = (1 to messagesPerSec.toInt).map { messageNum => + (1 to messagesPerSec.toInt).foreach { messageNum => val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString) .mkString(" ") - new KeyedMessage[String, String](topic, str) - }.toArray + val message = new ProducerRecord[String, String](topic, null, str) + producer.send(message) + } - producer.send(messages: _*) - Thread.sleep(100) + Thread.sleep(1000) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 6ff0c47793a2..d772ae309f40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,10 +15,11 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming -import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic} -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -31,8 +32,6 @@ import org.apache.spark.SparkConf */ object MQTTPublisher { - var client: MqttClient = _ - def main(args: Array[String]) { if (args.length < 2) { System.err.println("Usage: MQTTPublisher ") @@ -43,24 +42,35 @@ object MQTTPublisher { val Seq(brokerUrl, topic) = args.toSeq + var client: MqttClient = null + try { - var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp") - client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + val persistence = new MemoryPersistence() + client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) + + client.connect() + + val msgtopic = client.getTopic(topic) + val msgContent = "hello mqtt demo for spark streaming" + val message = new MqttMessage(msgContent.getBytes("utf-8")) + + while (true) { + try { + msgtopic.publish(message) + println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + Thread.sleep(10) + println("Queue is full, wait for to consume data from the message queue") + } + } } catch { case e: MqttException => println("Exception Caught: " + e) + } finally { + if (client != null) { + client.disconnect() + } } - - client.connect() - - val msgtopic: MqttTopic = client.getTopic(topic) - val msg: String = "hello mqtt demo for spark streaming" - - while (true) { - val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes("utf-8")) - msgtopic.publish(message) - println("Published data. topic: " + msgtopic.getName() + " Message: " + message) - } - client.disconnect() } } @@ -87,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -96,11 +108,12 @@ object MQTTWordCount { val sparkConf = new SparkConf().setAppName("MQTTWordCount") val ssc = new StreamingContext(sparkConf, Seconds(2)) val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) - - val words = lines.flatMap(x => x.toString.split(" ")) + val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() ssc.start() ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada1..9a57fe286d1a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb..5322929d177b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index c3a05c89d817..9916882e4f94 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -55,7 +56,8 @@ import org.apache.spark.util.IntParam */ object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) + : StreamingContext = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint @@ -107,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala new file mode 100644 index 000000000000..ed617754cbf1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.util.IntParam +import org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel + +/** + * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + * network every second. + * + * Usage: SqlNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example org.apache.spark.examples.streaming.SqlNetworkWordCount localhost 9999` + */ + +object SqlNetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + // Create the context with a 2 second batch size + val sparkConf = new SparkConf().setAppName("SqlNetworkWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create a socket stream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + // Note that no duplication in storage level only for running locally. + // Replication necessary in distributed scenario for fault tolerance. + val lines = ssc.socketTextStream(args(0), args(1).toInt, StorageLevel.MEMORY_AND_DISK_SER) + val words = lines.flatMap(_.split(" ")) + + // Convert RDDs of the words DStream to DataFrame and run SQL query + words.foreachRDD((rdd: RDD[String], time: Time) => { + // Get the singleton instance of SQLContext + val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + import sqlContext.implicits._ + + // Convert RDD[String] to RDD[case class] to DataFrame + val wordsDataFrame = rdd.map(w => Record(w)).toDF() + + // Register as table + wordsDataFrame.registerTempTable("words") + + // Do word count on table using SQL and print it + val wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word") + println(s"========= $time =========") + wordCountsDataFrame.show() + }) + + ssc.start() + ssc.awaitTermination() + } +} + + +/** Case class for converting RDD to DataFrame */ +case class Record(word: String) + + +/** Lazily instantiated singleton instance of SQLContext */ +object SQLContextSingleton { + + @transient private var instance: SQLContext = _ + + def getInstance(sparkContext: SparkContext): SQLContext = { + if (instance == null) { + instance = new SQLContext(sparkContext) + } + instance + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc44135..02ba1c2eed0f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index 62f49530edb1..825c671a929b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,9 +15,11 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ +import com.twitter.algebird.CMSHasherImplicits._ import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ @@ -67,7 +69,8 @@ object TwitterAlgebirdCMS { val users = stream.map(status => status.getUser.getId) - val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC) var globalCMS = cms.zero val mm = new MapMonoid[Long, Int]() var globalExact = Map[Long, Int]() @@ -111,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8a..49826ede7041 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f..49cee1b43c2d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6510c70bd186..6ac9a72c3794 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -35,7 +36,7 @@ import org.apache.spark.SparkConf */ object SimpleZeroMQPublisher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SimpleZeroMQPublisher ") System.exit(1) @@ -45,7 +46,7 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String) = ByteString(x) + implicit def stringToByteString(x: String): ByteString = ByteString(x) val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) @@ -86,7 +87,7 @@ object ZeroMQWordCount { // Create the context and set the batch size val ssc = new StreamingContext(sparkConf, Seconds(2)) - def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator + def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 8402491b6267..bea7a47cb285 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -57,8 +58,7 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - + val userID = Map((1 to 100).map(_ -> .01) : _*) def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { val rand = new Random().nextDouble() @@ -94,7 +94,7 @@ object PageViewGenerator { while (true) { val socket = listener.accept() new Thread() { - override def run = { + override def run(): Unit = { println("Got client connected from: " + socket.getInetAddress) val out = new PrintWriter(socket.getOutputStream(), true) @@ -109,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690..ec7d39da8b2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 000000000000..561ed4babe5d --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,168 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + provided + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.scala-lang + scala-library + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + + + diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0706f1ebf66e..0664cfb2021e 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -38,15 +38,46 @@ org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index 17cbc6707b5e..d87b86932dd4 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -113,7 +113,9 @@ private[sink] object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. + // scalastyle:off classforname val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + // scalastyle:on classforname bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 4373be443e67..719fca0938b3 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,14 +16,13 @@ */ package org.apache.spark.streaming.flume.sink +import java.util.UUID import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.flume.Channel -import org.apache.commons.lang.RandomStringUtils -import com.google.common.util.concurrent.ThreadFactoryBuilder /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process @@ -45,8 +44,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. @@ -55,7 +53,7 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha // Since the new txn may not have the same sequence number we must guard against accidentally // committing a new transaction. To reduce the probability of that happening a random string is // prepended to the sequence number. Does not change for life of sink - private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqBase = UUID.randomUUID().toString.substring(0, 8) private val seqCounter = new AtomicLong(0) // Protected by `sequenceNumberToProcessor` diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala new file mode 100644 index 000000000000..845fc8debda7 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong + +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } + +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index ea45b14294df..7ad43b1d7b0a 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg(msg) } else { // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("",seqNum, events) + eventBatch = new EventBatch("", seqNum, events) } }) } catch { diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 2a58e9981722..42df8792f147 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c14..d2654700ea72 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -20,20 +20,28 @@ import java.net.InetSocketAddress import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. +// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -158,7 +166,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("capacity", channelCapacity.toString) channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides) + channelContext.putAll(overrides.asJava) channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) @@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 1f2681394c58..14f7daaf417e 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index dc629df4f4ac..48df27b26867 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.io.{ObjectOutput, ObjectInput} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils import org.apache.spark.Logging @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k,v) <- headers) { + for ((k, v) <- headers.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala index 88cc2aa3bf02..b9d4e762ca05 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -155,7 +154,7 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) var j = 0 while (j < events.size()) { - val event = events(j) + val event = events.get(j) val sparkFlumeEvent = new SparkFlumeEvent() sparkFlumeEvent.event.setBody(event.getBody) sparkFlumeEvent.event.setHeaders(event.getHeaders) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2de2a7926bfd..2bf99cb3cba1 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -22,7 +22,7 @@ import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.flume.source.avro.AvroSourceProtocol @@ -37,8 +37,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver -import org.jboss.netty.channel.ChannelPipelineFactory -import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ @@ -100,7 +99,7 @@ class SparkFlumeEvent() extends Externalizable { val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { + for ((k, v) <- event.getHeaders.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) @@ -128,8 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) + events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } @@ -153,9 +151,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -187,24 +185,23 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def preferredLocation = Some(host) - - /** A Netty Pipeline factory that will decompress incoming data from + override def preferredLocation: Option[String] = Option(host) + + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - - def getPipeline() = { + def getPipeline(): ChannelPipeline = { val pipeline = Channels.pipeline() val encoder = new ZlibEncoder(6) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) pipeline + } } } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89..0bc46209b836 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{LinkedBlockingQueue, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -94,9 +94,7 @@ private[streaming] class FlumePollingReceiver( override def onStop(): Unit = { logInfo("Shutting down Flume Polling Receiver") receiverExecutor.shutdownNow() - connections.foreach(connection => { - connection.transceiver.close() - }) + connections.asScala.foreach(_.transceiver.close()) channelFactory.releaseExternalResources() } @@ -110,7 +108,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 000000000000..70018c86f92b --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.Collections + +import scala.collection.JavaConverters._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: Seq[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Collections.singletonMap("test", "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.asJava) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 4b732c1592ab..c719b80aca7e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,11 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} -import org.apache.spark.annotation.Experimental +import scala.collection.JavaConverters._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -121,7 +126,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, hostname: String, @@ -138,7 +142,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses representing the hosts to connect to. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -159,7 +162,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -178,7 +180,6 @@ object FlumeUtils { * @param hostname Hostname of the host on which the Spark Sink is running * @param port Port of the host at which the Spark Sink is listening */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -195,7 +196,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -212,7 +212,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses on which the Spark Sink is running. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], @@ -233,7 +232,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], @@ -244,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private[flume] class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.size() == ports.size()) + val addresses = hosts.asScala.zip(ports.asScala).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.asScala.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 000000000000..a2ab320957db --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{Map => JMap, Collections} + +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): Seq[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput(outputHeaders: Seq[JMap[String, String]], outputBodies: Seq[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies(j) && + eventHeaderToVerify == outputHeaders(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Collections.singletonMap(s"test-$t", "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index b57a1c71e35b..ff2fb8eed204 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,63 +1,50 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder - -import org.scalatest.{BeforeAndAfter, FunSuite} +import com.google.common.base.Charsets.UTF_8 +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -88,69 +75,33 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - ssc.start() + try { + val port = utils.startSingleSink() - writeAndVerify(Seq(channel), ssc, outputBuffer) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() + try { + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } + } + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) @@ -158,87 +109,21 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging ssc.start() try { - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() - } - } - - def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - channels.map(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - for (i <- 0 until channels.size) { - executorCompletion.take() - } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < batchCount * channels.size && - System.currentTimeMillis() - startTime < 15000) { - logInfo("output.size = " + outputBuffer.size) - Thread.sleep(100) - } - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - - def assertChannelIsEmpty(channel: MemoryChannel) = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach - clock.addToTime(batchDuration.milliseconds) + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + case (key, value) => (key.toString, value.toString) + }).map(_.asJava) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f333e3891b5f..5ffb60bd602f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,46 +17,26 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer -import java.nio.charset.Charset - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} +import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} -import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,18 +49,29 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -97,61 +88,10 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - val decoder = Charset.forName("UTF-8").newDecoder() - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) - output should be (input) - } - } - /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 503fc129dc4f..6f4e2a89e9af 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -33,9 +33,6 @@ streaming-kafka-assembly - scala-${scala.binary.version} - spark-streaming-kafka-assembly-${project.version}.jar - ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} @@ -50,6 +47,90 @@ ${project.version} provided + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + net.jpountz.lz4 + lz4 + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + @@ -61,7 +142,6 @@ maven-shade-plugin false - ${spark.jar} *:* diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index b29b0509656b..ded863bd985e 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,10 +41,17 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} - 0.8.0 + 0.8.2.1 com.sun.jmx diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala new file mode 100644 index 000000000000..9159051ba06e --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.annotation.Experimental + +/** + * Represents the host and port info for a Kafka broker. + * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID. + */ +final class Broker private( + /** Broker's hostname */ + val host: String, + /** Broker's port */ + val port: Int) extends Serializable { + override def equals(obj: Any): Boolean = obj match { + case that: Broker => + this.host == that.host && + this.port == that.port + case _ => false + } + + override def hashCode: Int = { + 41 * (41 + host.hashCode) + port + } + + override def toString(): String = { + s"Broker($host, $port)" + } +} + +/** + * :: Experimental :: + * Companion object that provides methods to create instances of [[Broker]]. + */ +@Experimental +object Broker { + def create(host: String, port: Int): Broker = + new Broker(host, port) + + def apply(host: String, port: Int): Broker = + new Broker(host, port) + + def unapply(broker: Broker): Option[(String, Int)] = { + if (broker == null) { + None + } else { + Some((broker.host, broker.port)) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala new file mode 100644 index 000000000000..1000094e93cb --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.reflect.ClassTag + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +/** + * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where + * each given Kafka topic/partition corresponds to an RDD partition. + * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number + * of messages + * per second that each '''partition''' will accept. + * Starting offsets are specified in advance, + * and this DStream is not responsible for committing offsets, + * so that you can control exactly-once semantics. + * For an easy interface to Kafka-managed offsets, + * see {@link org.apache.spark.streaming.kafka.KafkaCluster} + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler function for translating each message into the desired type + */ +private[streaming] +class DirectKafkaInputDStream[ + K: ClassTag, + V: ClassTag, + U <: Decoder[K]: ClassTag, + T <: Decoder[V]: ClassTag, + R: ClassTag]( + @transient ssc_ : StreamingContext, + val kafkaParams: Map[String, String], + val fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R + ) extends InputDStream[R](ssc_) with Logging { + val maxRetries = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRetries", 1) + + // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]") + private[streaming] override def name: String = s"Kafka direct stream [$id]" + + protected[streaming] override val checkpointData = + new DirectKafkaInputDStreamCheckpointData + + + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new DirectKafkaRateController(id, + RateEstimator.create(ssc.conf, ssc_.graph.batchDuration))) + } else { + None + } + } + + protected val kc = new KafkaCluster(kafkaParams) + + private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRatePerPartition", 0) + protected def maxMessagesPerPartition: Option[Long] = { + val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val numPartitions = currentOffsets.keys.size + + val effectiveRateLimitPerPartition = estimatedRateLimit + .filter(_ > 0) + .map { limit => + if (maxRateLimitPerPartition > 0) { + Math.min(maxRateLimitPerPartition, (limit / numPartitions)) + } else { + limit / numPartitions + } + }.getOrElse(maxRateLimitPerPartition) + + if (effectiveRateLimitPerPartition > 0) { + val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 + Some((secsPerBatch * effectiveRateLimitPerPartition).toLong) + } else { + None + } + } + + protected var currentOffsets = fromOffsets + + @tailrec + protected final def latestLeaderOffsets(retries: Int): Map[TopicAndPartition, LeaderOffset] = { + val o = kc.getLatestLeaderOffsets(currentOffsets.keySet) + // Either.fold would confuse @tailrec, do it manually + if (o.isLeft) { + val err = o.left.get.toString + if (retries <= 0) { + throw new SparkException(err) + } else { + log.error(err) + Thread.sleep(kc.config.refreshLeaderBackoffMs) + latestLeaderOffsets(retries - 1) + } + } else { + o.right.get + } + } + + // limits the maximum number of messages per partition + protected def clamp( + leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = { + maxMessagesPerPartition.map { mmp => + leaderOffsets.map { case (tp, lo) => + tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset)) + } + }.getOrElse(leaderOffsets) + } + + override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = { + val untilOffsets = clamp(latestLeaderOffsets(maxRetries)) + val rdd = KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) + + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. + offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + + currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) + Some(rdd) + } + + override def start(): Unit = { + } + + def stop(): Unit = { + } + + private[streaming] + class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } + + override def update(time: Time) { + batchForTime.clear() + generatedRDDs.foreach { kv => + val a = kv._2.asInstanceOf[KafkaRDD[K, V, U, T, R]].offsetRanges.map(_.toTuple).toArray + batchForTime += kv._1 -> a + } + } + + override def cleanup(time: Time) { } + + override def restore() { + // this is assuming that the topics don't change during execution, which is true currently + val topics = fromOffsets.keySet + val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) + + batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + } + } + } + + /** + * A RateController to retrieve the rate from RateEstimator. + */ + private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala new file mode 100644 index 000000000000..8465432c5850 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.control.NonFatal +import scala.util.Random +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ +import java.util.Properties +import kafka.api._ +import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition} +import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import org.apache.spark.SparkException + +/** + * Convenience methods for interacting with a Kafka cluster. + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form + */ +private[spark] +class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { + import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} + + // ConsumerConfig isn't serializable + @transient private var _config: SimpleConsumerConfig = null + + def config: SimpleConsumerConfig = this.synchronized { + if (_config == null) { + _config = SimpleConsumerConfig(kafkaParams) + } + _config + } + + def connect(host: String, port: Int): SimpleConsumer = + new SimpleConsumer(host, port, config.socketTimeoutMs, + config.socketReceiveBufferBytes, config.clientId) + + def connectLeader(topic: String, partition: Int): Either[Err, SimpleConsumer] = + findLeader(topic, partition).right.map(hp => connect(hp._1, hp._2)) + + // Metadata api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-MetadataAPI + // scalastyle:on + + def findLeader(topic: String, partition: Int): Either[Err, (String, Int)] = { + val req = TopicMetadataRequest(TopicMetadataRequest.CurrentVersion, + 0, config.clientId, Seq(topic)) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + resp.topicsMetadata.find(_.topic == topic).flatMap { tm: TopicMetadata => + tm.partitionsMetadata.find(_.partitionId == partition) + }.foreach { pm: PartitionMetadata => + pm.leader.foreach { leader => + return Right((leader.host, leader.port)) + } + } + } + Left(errs) + } + + def findLeaders( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, (String, Int)]] = { + val topics = topicAndPartitions.map(_.topic) + val response = getPartitionMetadata(topics).right + val answer = response.flatMap { tms: Set[TopicMetadata] => + val leaderMap = tms.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.flatMap { pm: PartitionMetadata => + val tp = TopicAndPartition(tm.topic, pm.partitionId) + if (topicAndPartitions(tp)) { + pm.leader.map { l => + tp -> (l.host -> l.port) + } + } else { + None + } + } + }.toMap + + if (leaderMap.keys.size == topicAndPartitions.size) { + Right(leaderMap) + } else { + val missing = topicAndPartitions.diff(leaderMap.keySet) + val err = new Err + err.append(new SparkException(s"Couldn't find leaders for ${missing}")) + Left(err) + } + } + answer + } + + def getPartitions(topics: Set[String]): Either[Err, Set[TopicAndPartition]] = { + getPartitionMetadata(topics).right.map { r => + r.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.map { pm: PartitionMetadata => + TopicAndPartition(tm.topic, pm.partitionId) + } + } + } + } + + def getPartitionMetadata(topics: Set[String]): Either[Err, Set[TopicMetadata]] = { + val req = TopicMetadataRequest( + TopicMetadataRequest.CurrentVersion, 0, config.clientId, topics.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + val respErrs = resp.topicsMetadata.filter(m => m.errorCode != ErrorMapping.NoError) + + if (respErrs.isEmpty) { + return Right(resp.topicsMetadata.toSet) + } else { + respErrs.foreach { m => + val cause = ErrorMapping.exceptionFor(m.errorCode) + val msg = s"Error getting partition metadata for '${m.topic}'. Does the topic exist?" + errs.append(new SparkException(msg, cause)) + } + } + } + Left(errs) + } + + // Leader offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetAPI + // scalastyle:on + + def getLatestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.LatestTime) + + def getEarliestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.EarliestTime) + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = { + getLeaderOffsets(topicAndPartitions, before, 1).right.map { r => + r.map { kv => + // mapValues isnt serializable, see SI-7005 + kv._1 -> kv._2.head + } + } + } + + private def flip[K, V](m: Map[K, V]): Map[V, Seq[K]] = + m.groupBy(_._2).map { kv => + kv._1 -> kv._2.keys.toSeq + } + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long, + maxNumOffsets: Int + ): Either[Err, Map[TopicAndPartition, Seq[LeaderOffset]]] = { + findLeaders(topicAndPartitions).right.flatMap { tpToLeader => + val leaderToTp: Map[(String, Int), Seq[TopicAndPartition]] = flip(tpToLeader) + val leaders = leaderToTp.keys + var result = Map[TopicAndPartition, Seq[LeaderOffset]]() + val errs = new Err + withBrokers(leaders, errs) { consumer => + val partitionsToGetOffsets: Seq[TopicAndPartition] = + leaderToTp((consumer.host, consumer.port)) + val reqMap = partitionsToGetOffsets.map { tp: TopicAndPartition => + tp -> PartitionOffsetRequestInfo(before, maxNumOffsets) + }.toMap + val req = OffsetRequest(reqMap) + val resp = consumer.getOffsetsBefore(req) + val respMap = resp.partitionErrorAndOffsets + partitionsToGetOffsets.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { por: PartitionOffsetsResponse => + if (por.error == ErrorMapping.NoError) { + if (por.offsets.nonEmpty) { + result += tp -> por.offsets.map { off => + LeaderOffset(consumer.host, consumer.port, off) + } + } else { + errs.append(new SparkException( + s"Empty offsets for ${tp}, is ${before} before log beginning?")) + } + } else { + errs.append(ErrorMapping.exceptionFor(por.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find leader offsets for ${missing}")) + Left(errs) + } + } + + // Consumer offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI + // scalastyle:on + + // this 0 here indicates api version, in this case the original ZK backed api. + private def defaultConsumerApiVersion: Short = 0 + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, Long]] = + getConsumerOffsets(groupId, topicAndPartitions, defaultConsumerApiVersion) + + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition], + consumerApiVersion: Short + ): Either[Err, Map[TopicAndPartition, Long]] = { + getConsumerOffsetMetadata(groupId, topicAndPartitions, consumerApiVersion).right.map { r => + r.map { kv => + kv._1 -> kv._2.offset + } + } + } + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = + getConsumerOffsetMetadata(groupId, topicAndPartitions, defaultConsumerApiVersion) + + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition], + consumerApiVersion: Short + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = { + var result = Map[TopicAndPartition, OffsetMetadataAndError]() + val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq, consumerApiVersion) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.fetchOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { ome: OffsetMetadataAndError => + if (ome.error == ErrorMapping.NoError) { + result += tp -> ome + } else { + errs.append(ErrorMapping.exceptionFor(ome.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find consumer offsets for ${missing}")) + Left(errs) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long] + ): Either[Err, Map[TopicAndPartition, Short]] = + setConsumerOffsets(groupId, offsets, defaultConsumerApiVersion) + + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long], + consumerApiVersion: Short + ): Either[Err, Map[TopicAndPartition, Short]] = { + val meta = offsets.map { kv => + kv._1 -> OffsetAndMetadata(kv._2) + } + setConsumerOffsetMetadata(groupId, meta, consumerApiVersion) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetAndMetadata] + ): Either[Err, Map[TopicAndPartition, Short]] = + setConsumerOffsetMetadata(groupId, metadata, defaultConsumerApiVersion) + + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetAndMetadata], + consumerApiVersion: Short + ): Either[Err, Map[TopicAndPartition, Short]] = { + var result = Map[TopicAndPartition, Short]() + val req = OffsetCommitRequest(groupId, metadata, consumerApiVersion) + val errs = new Err + val topicAndPartitions = metadata.keySet + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.commitOffsets(req) + val respMap = resp.commitStatus + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { err: Short => + if (err == ErrorMapping.NoError) { + result += tp -> err + } else { + errs.append(ErrorMapping.exceptionFor(err)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't set offsets for ${missing}")) + Left(errs) + } + + // Try a call against potentially multiple brokers, accumulating errors + private def withBrokers(brokers: Iterable[(String, Int)], errs: Err) + (fn: SimpleConsumer => Any): Unit = { + brokers.foreach { hp => + var consumer: SimpleConsumer = null + try { + consumer = connect(hp._1, hp._2) + fn(consumer) + } catch { + case NonFatal(e) => + errs.append(e) + } finally { + if (consumer != null) { + consumer.close() + } + } + } + } +} + +private[spark] +object KafkaCluster { + type Err = ArrayBuffer[Throwable] + + /** If the result is right, return it, otherwise throw SparkException */ + def checkErrors[T](result: Either[Err, T]): T = { + result.fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + + private[spark] + case class LeaderOffset(host: String, port: Int, offset: Long) + + /** + * High-level kafka consumers connect to ZK. ConsumerConfig assumes this use case. + * Simple consumers connect directly to brokers, but need many of the same configs. + * This subclass won't warn about missing ZK params, or presence of broker params. + */ + private[spark] + class SimpleConsumerConfig private(brokers: String, originalProps: Properties) + extends ConsumerConfig(originalProps) { + val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => + val hpa = hp.split(":") + if (hpa.size == 1) { + throw new SparkException(s"Broker not the in correct format of : [$brokers]") + } + (hpa(0), hpa(1).toInt) + } + } + + private[spark] + object SimpleConsumerConfig { + /** + * Make a consumer config without requiring group.id or zookeeper.connect, + * since communicating with brokers also needs common settings such as timeout + */ + def apply(kafkaParams: Map[String, String]): SimpleConsumerConfig = { + // These keys are from other pre-existing kafka configs for specifying brokers, accept either + val brokers = kafkaParams.get("metadata.broker.list") + .orElse(kafkaParams.get("bootstrap.servers")) + .getOrElse(throw new SparkException( + "Must specify metadata.broker.list or bootstrap.servers")) + + val props = new Properties() + kafkaParams.foreach { case (key, value) => + // prevent warnings on parameters ConsumerConfig doesn't know about + if (key != "metadata.broker.list" && key != "bootstrap.servers") { + props.put(key, value) + } + } + + Seq("zookeeper.connect", "group.id").foreach { s => + if (!props.containsKey(s)) { + props.setProperty(s, "") + } + } + + new SimpleConsumerConfig(brokers, props) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 4d26b640e8d7..04b2dc10d39e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * Input stream that pulls messages from a Kafka Broker. @@ -111,7 +111,8 @@ class KafkaReceiver[ val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") + val executorPool = + ThreadUtils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -134,7 +135,7 @@ class KafkaReceiver[ store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { - case e: Throwable => logError("Error handling message; exiting", e) + case e: Throwable => reportError("Error handling message; exiting", e) } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala new file mode 100644 index 000000000000..ea5f842c6caf --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{PartialResult, BoundedDouble} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + +import kafka.api.{FetchRequestBuilder, FetchResponse} +import kafka.common.{ErrorMapping, TopicAndPartition} +import kafka.consumer.SimpleConsumer +import kafka.message.{MessageAndMetadata, MessageAndOffset} +import kafka.serializer.Decoder +import kafka.utils.VerifiableProperties + +/** + * A batch-oriented interface for consuming from Kafka. + * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param messageHandler function for translating each message into the desired type + */ +private[kafka] +class KafkaRDD[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag] private[spark] ( + sc: SparkContext, + kafkaParams: Map[String, String], + val offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, (String, Int)], + messageHandler: MessageAndMetadata[K, V] => R + ) extends RDD[R](sc, Nil) with Logging with HasOffsetRanges { + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => + val (host, port) = leaders(TopicAndPartition(o.topic, o.partition)) + new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port) + }.toArray + } + + override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[R] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.size < 1) { + return new Array[R](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[R] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, + parts.keys.toArray) + res.foreach(buf ++= _) + buf.toArray + } + + override def getPreferredLocations(thePart: Partition): Seq[String] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + // TODO is additional hostname resolution necessary here + Seq(part.host) + } + + private def errBeginAfterEnd(part: KafkaRDDPartition): String = + s"Beginning offset ${part.fromOffset} is after the ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition}. " + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged" + + private def errRanOutBeforeEnd(part: KafkaRDDPartition): String = + s"Ran out of messages before reaching ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates that messages may have been lost" + + private def errOvershotEnd(itemOffset: Long, part: KafkaRDDPartition): String = + s"Got ${itemOffset} > ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates a message may have been skipped" + + override def compute(thePart: Partition, context: TaskContext): Iterator[R] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + if (part.fromOffset == part.untilOffset) { + log.info(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + s"skipping ${part.topic} ${part.partition}") + Iterator.empty + } else { + new KafkaRDDIterator(part, context) + } + } + + private class KafkaRDDIterator( + part: KafkaRDDPartition, + context: TaskContext) extends NextIterator[R] { + + context.addTaskCompletionListener{ context => closeIfNeeded() } + + log.info(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + + val kc = new KafkaCluster(kafkaParams) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[K]] + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[V]] + val consumer = connectLeader + var requestOffset = part.fromOffset + var iter: Iterator[MessageAndOffset] = null + + // The idea is to use the provided preferred host, except on task retry atttempts, + // to minimize number of kafka metadata requests + private def connectLeader: SimpleConsumer = { + if (context.attemptNumber > 0) { + kc.connectLeader(part.topic, part.partition).fold( + errs => throw new SparkException( + s"Couldn't connect to leader for topic ${part.topic} ${part.partition}: " + + errs.mkString("\n")), + consumer => consumer + ) + } else { + kc.connect(part.host, part.port) + } + } + + private def handleFetchErr(resp: FetchResponse) { + if (resp.hasError) { + val err = resp.errorCode(part.topic, part.partition) + if (err == ErrorMapping.LeaderNotAvailableCode || + err == ErrorMapping.NotLeaderForPartitionCode) { + log.error(s"Lost leader for topic ${part.topic} partition ${part.partition}, " + + s" sleeping for ${kc.config.refreshLeaderBackoffMs}ms") + Thread.sleep(kc.config.refreshLeaderBackoffMs) + } + // Let normal rdd retry sort out reconnect attempts + throw ErrorMapping.exceptionFor(err) + } + } + + private def fetchBatch: Iterator[MessageAndOffset] = { + val req = new FetchRequestBuilder() + .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) + .build() + val resp = consumer.fetch(req) + handleFetchErr(resp) + // kafka may return a batch that starts before the requested offset + resp.messageSet(part.topic, part.partition) + .iterator + .dropWhile(_.offset < requestOffset) + } + + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } + + override def getNext(): R = { + if (iter == null || !iter.hasNext) { + iter = fetchBatch + } + if (!iter.hasNext) { + assert(requestOffset == part.untilOffset, errRanOutBeforeEnd(part)) + finished = true + null.asInstanceOf[R] + } else { + val item = iter.next() + if (item.offset >= part.untilOffset) { + assert(item.offset == part.untilOffset, errOvershotEnd(item.offset, part)) + finished = true + null.asInstanceOf[R] + } else { + requestOffset = item.nextOffset + messageHandler(new MessageAndMetadata( + part.topic, part.partition, item.message, item.offset, keyDecoder, valueDecoder)) + } + } + } + } +} + +private[kafka] +object KafkaRDD { + import KafkaCluster.LeaderOffset + + /** + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the batch + * @param untilOffsets per-topic/partition Kafka offsets defining the (exclusive) + * ending point of the batch + * @param messageHandler function for translating each message into the desired type + */ + def apply[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + untilOffsets: Map[TopicAndPartition, LeaderOffset], + messageHandler: MessageAndMetadata[K, V] => R + ): KafkaRDD[K, V, U, T, R] = { + val leaders = untilOffsets.map { case (tp, lo) => + tp -> (lo.host, lo.port) + }.toMap + + val offsetRanges = fromOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + }.toArray + + new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala new file mode 100644 index 000000000000..a660d2a00c35 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.Partition + +/** @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ +private[kafka] +class KafkaRDDPartition( + val index: Int, + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long, + val host: String, + val port: Int +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala new file mode 100644 index 000000000000..c9fd715d3d55 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.concurrent.TimeoutException +import java.util.{Map => JMap, Properties} + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.language.postfixOps +import scala.util.control.NonFatal + +import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} +import kafka.serializer.StringEncoder +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.{ZKStringSerializer, ZkUtils} +import org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} + +import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. + */ +private[kafka] class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkClient: ZkClient = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkClient = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkClient).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration) + server = new KafkaServer(brokerConf) + server.startup() + (server, port) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + AdminUtils.createTopic(zkClient, topic, 1, 1) + // wait until metadata is propagated + waitUntilMetadataIsPropagated(topic, 0) + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Unit = { + producer = new Producer[String, String](new ProducerConfig(producerConfiguration)) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) + producer.close() + producer = null + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerAddress) + props.put("serializer.class", classOf[StringEncoder].getName) + // wait for all in-sync replicas to ack sends + props.put("request.required.acks", "-1") + props + } + + // A simplified version of scalatest eventually, rewritten here to avoid adding extra test + // dependency + def eventually[T](timeout: Time, interval: Time)(func: => T): T = { + def makeAttempt(): Either[Throwable, T] = { + try { + Right(func) + } catch { + case e if NonFatal(e) => Left(e) + } + } + + val startTime = System.currentTimeMillis() + @tailrec + def tryAgain(attempt: Int): T = { + makeAttempt() match { + case Right(result) => result + case Left(e) => + val duration = System.currentTimeMillis() - startTime + if (duration < timeout.milliseconds) { + Thread.sleep(interval.milliseconds) + } else { + throw new TimeoutException(e.getMessage) + } + + tryAgain(attempt + 1) + } + } + + tryAgain(1) + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + ZkUtils.getLeaderForPartition(zkClient, topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + eventually(Time(10000), Time(100)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index df725f0c65a6..312822207753 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,22 +17,29 @@ package org.apache.spark.streaming.kafka -import java.lang.{Integer => JInt} -import java.util.{Map => JMap} +import java.lang.{Integer => JInt, Long => JLong} +import java.util.{List => JList, Map => JMap, Set => JSet} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import scala.collection.JavaConversions._ -import kafka.serializer.{Decoder, StringDecoder} +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.{SparkContext, SparkException} object KafkaUtils { /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) * @param groupId The group id for this consumer @@ -56,7 +63,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param kafkaParams Map of kafka configuration parameters, * see http://kafka.apache.org/08/configuration.html @@ -70,12 +77,12 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + val walEnabled = WriteAheadLogUtils.enableReceiverLog(ssc.conf) new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) @@ -89,11 +96,11 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*)) } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. @@ -108,15 +115,15 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), storageLevel) } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object - * @param keyTypeClass Key type of RDD - * @param valueTypeClass value type of RDD + * @param keyTypeClass Key type of DStream + * @param valueTypeClass value type of Dstream * @param keyDecoderClass Type of kafka key decoder * @param valueDecoderClass Type of kafka value decoder * @param kafkaParams Map of kafka configuration parameters, @@ -142,6 +149,523 @@ object KafkaUtils { implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( - jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + jssc.ssc, + kafkaParams.asScala.toMap, + Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), + storageLevel) + } + + /** get leaders for the given offset ranges, or throw an exception */ + private def leadersForRanges( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { + val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet + val leaders = kc.findLeaders(topics) + KafkaCluster.checkErrors(leaders) + } + + /** Make sure offsets are available in kafka, or throw an exception */ + private def checkOffsets( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Unit = { + val topics = offsetRanges.map(_.topicAndPartition).toSet + val result = for { + low <- kc.getEarliestLeaderOffsets(topics).right + high <- kc.getLatestLeaderOffsets(topics).right + } yield { + offsetRanges.filterNot { o => + low(o.topicAndPartition).offset <= o.fromOffset && + o.untilOffset <= high(o.topicAndPartition).offset + } + } + val badRanges = KafkaCluster.checkErrors(result) + if (!badRanges.isEmpty) { + throw new SparkException("Offsets not available on leader: " + badRanges.mkString(",")) + } + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange] + ): RDD[(K, V)] = sc.withScope { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val kc = new KafkaCluster(kafkaParams) + val leaders = leadersForRanges(kc, offsetRanges) + checkOffsets(kc, offsetRanges) + new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. + * @param messageHandler Function for translating each message and metadata into the desired type + */ + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[K, V] => R + ): RDD[R] = sc.withScope { + val kc = new KafkaCluster(kafkaParams) + val leaderMap = if (leaders.isEmpty) { + leadersForRanges(kc, offsetRanges) + } else { + // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker + leaders.map { + case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) + }.toMap + } + val cleanedHandler = sc.clean(messageHandler) + checkOffsets(kc, offsetRanges) + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange] + ): JavaPairRDD[K, V] = jsc.sc.withScope { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + new JavaPairRDD(createRDD[K, V, KD, VD]( + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges)) + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. + * @param messageHandler Function for translating each message and metadata into the desired type + */ + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaRDD[R] = jsc.sc.withScope { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val leaderMap = Map(leaders.asScala.toSeq: _*) + createRDD[K, V, KD, VD, R]( + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges, leaderMap, messageHandler.call(_)) + } + + /** + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R + ): InputDStream[R] = { + val cleanedHandler = ssc.sc.clean(messageHandler) + new DirectKafkaInputDStream[K, V, KD, VD, R]( + ssc, kafkaParams, fromOffsets, cleanedHandler) + } + + /** + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + topics: Set[String] + ): InputDStream[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val kc = new KafkaCluster(kafkaParams) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + + val result = for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + val fromOffsets = leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) + } + KafkaCluster.checkErrors(result) + } + + /** + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class of the value decoder + * @param recordClass Class of the records in DStream + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaInputDStream[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) + createDirectStream[K, V, KD, VD, R]( + jssc.ssc, + Map(kafkaParams.asScala.toSeq: _*), + Map(fromOffsets.asScala.mapValues(_.longValue()).toSeq: _*), + cleanedHandler + ) + } + + /** + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class type of the value decoder + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + topics: JSet[String] + ): JavaPairInputDStream[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + createDirectStream[K, V, KD, VD]( + jssc.ssc, + Map(kafkaParams.asScala.toSeq: _*), + Set(topics.asScala.toSeq: _*) + ) + } +} + +/** + * This is a helper class that wraps the KafkaUtils.createStream() into more + * Python-friendly class and function so that it can be easily + * instantiated and called from Python's KafkaUtils (see SPARK-6027). + * + * The zero-arg constructor helps instantiate this class from the Class object + * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() + * takes care of known parameters instead of passing them from Python + */ +private[kafka] class KafkaUtilsPythonHelper { + def createStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JMap[String, JInt], + storageLevel: StorageLevel): JavaPairReceiverInputDStream[Array[Byte], Array[Byte]] = { + KafkaUtils.createStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics, + storageLevel) + } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) + if (topicsFromOffsets != topics.asScala.toSet) { + throw new IllegalStateException( + s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq.asJava } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala new file mode 100644 index 000000000000..8a5f37149451 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import kafka.common.TopicAndPartition + +/** + * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * offset ranges in RDDs generated by the direct Kafka DStream (see + * [[KafkaUtils.createDirectStream()]]). + * {{{ + * KafkaUtils.createDirectStream(...).foreachRDD { rdd => + * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + * ... + * } + * }}} + */ +trait HasOffsetRanges { + def offsetRanges: Array[OffsetRange] +} + +/** + * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class + * can be created with `OffsetRange.create()`. + * @param topic Kafka topic name + * @param partition Kafka partition id + * @param fromOffset Inclusive starting offset + * @param untilOffset Exclusive ending offset + */ +final class OffsetRange private( + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long) extends Serializable { + import OffsetRange.OffsetRangeTuple + + /** Kafka TopicAndPartition object, for convenience */ + def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + + override def equals(obj: Any): Boolean = obj match { + case that: OffsetRange => + this.topic == that.topic && + this.partition == that.partition && + this.fromOffset == that.fromOffset && + this.untilOffset == that.untilOffset + case _ => false + } + + override def hashCode(): Int = { + toTuple.hashCode() + } + + override def toString(): String = { + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" + } + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[streaming] + def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset) +} + +/** + * Companion object the provides methods to create instances of [[OffsetRange]]. + */ +object OffsetRange { + def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def create( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + def apply(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def apply( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[kafka] + type OffsetRangeTuple = (String, Int, Long, Long) + + private[kafka] + def apply(t: OffsetRangeTuple) = + new OffsetRange(t._1, t._2, t._3, t._4) +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index be734b80272d..764d170934aa 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -33,7 +33,7 @@ import org.I0Itec.zkclient.ZkClient import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. @@ -96,7 +96,7 @@ class ReliableKafkaReceiver[ blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() // Initialize the block generator for storing Kafka message. - blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + @@ -121,7 +121,7 @@ class ReliableKafkaReceiver[ zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) - messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + messageHandlerThreadPool = ThreadUtils.newDaemonFixedThreadPool( topics.values.sum, "KafkaMessageHandler") blockGenerator.start() @@ -201,12 +201,31 @@ class ReliableKafkaReceiver[ topicPartitionOffsetMap.clear() } - /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + /** + * Store the ready-to-be-stored block and commit the related offsets to zookeeper. This method + * will try a fixed number of times to push the block. If the push fails, the receiver is stopped. + */ private def storeBlockAndCommitOffset( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { - store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) - Option(blockOffsetMap.get(blockId)).foreach(commitOffset) - blockOffsetMap.remove(blockId) + var count = 0 + var pushed = false + var exception: Exception = null + while (!pushed && count <= 3) { + try { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + pushed = true + } catch { + case ex: Exception => + count += 1 + exception = ex + } + } + if (pushed) { + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } else { + stop("Error while storing block into Spark", exception) + } } /** @@ -248,7 +267,7 @@ class ReliableKafkaReceiver[ } } catch { case e: Exception => - logError("Error handling message", e) + reportError("Error handling message", e) } } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java new file mode 100644 index 000000000000..9db07d0507fe --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; + +import scala.Tuple2; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaDirectKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient KafkaTestUtils kafkaTestUtils = null; + + @Before + public void setUp() { + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + } + + @After + public void tearDown() { + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } + } + + @Test + public void testKafkaStream() throws InterruptedException { + final String topic1 = "topic1"; + final String topic2 = "topic2"; + // hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference<>(); + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashSet sent = new HashSet(); + sent.addAll(Arrays.asList(topic1data)); + sent.addAll(Arrays.asList(topic2data)); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaDStream stream1 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicToSet(topic1) + ).transformToPair( + // Make sure you can get offset ranges from the rdd + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + Assert.assertEquals(offsets[0].topic(), topic1); + return rdd; + } + } + ).map( + new Function, String>() { + @Override + public String call(Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaDStream stream2 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + topicOffsetToMap(topic2, (long) 0), + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + JavaDStream unifiedStream = stream1.union(stream2); + + final Set result = Collections.synchronizedSet(new HashSet()); + unifiedStream.foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaRDD rdd) throws Exception { + result.addAll(rdd.collect()); + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + return null; + } + } + ); + ssc.start(); + long startTime = System.currentTimeMillis(); + boolean matches = false; + while (!matches && System.currentTimeMillis() - startTime < 20000) { + matches = sent.size() == result.size(); + Thread.sleep(50); + } + Assert.assertEquals(sent, result); + ssc.stop(); + } + + private HashSet topicToSet(String topic) { + HashSet topicSet = new HashSet(); + topicSet.add(topic); + return topicSet; + } + + private HashMap topicOffsetToMap(String topic, Long offsetToStart) { + HashMap topicMap = new HashMap(); + topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); + return topicMap; + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java new file mode 100644 index 000000000000..a9dc6e50613c --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; + +import scala.Tuple2; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +public class JavaKafkaRDDSuite implements Serializable { + private transient JavaSparkContext sc = null; + private transient KafkaTestUtils kafkaTestUtils = null; + + @Before + public void setUp() { + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + sc = new JavaSparkContext(sparkConf); + } + + @After + public void tearDown() { + if (sc != null) { + sc.stop(); + sc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } + } + + @Test + public void testKafkaRDD() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); + + OffsetRange[] offsetRanges = { + OffsetRange.create(topic1, 0, 0, 1), + OffsetRange.create(topic2, 0, 0, 1) + }; + + HashMap emptyLeaders = new HashMap(); + HashMap leaders = new HashMap(); + String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); + Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); + leaders.put(new TopicAndPartition(topic1, 0), broker); + leaders.put(new TopicAndPartition(topic2, 0), broker); + + JavaRDD rdd1 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + offsetRanges + ).map( + new Function, String>() { + @Override + public String call(Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaRDD rdd2 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + emptyLeaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + JavaRDD rdd3 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + leaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + // just making sure the java user apis work; the scala tests handle logic corner cases + long count1 = rdd1.count(); + long count2 = rdd2.count(); + long count3 = rdd3.count(); + Assert.assertTrue(count1 > 0); + Assert.assertEquals(count1, count2); + Assert.assertEquals(count1, count3); + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 6e1abf3f385e..e4c659215b76 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -18,41 +18,34 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Random; +import java.util.*; -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.Duration; -import scala.Predef; import scala.Tuple2; -import scala.collection.JavaConverters; - -import junit.framework.Assert; import kafka.serializer.StringDecoder; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.Test; -import org.junit.After; -import org.junit.Before; - public class JavaKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; private transient Random random = new Random(); - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); ssc = new JavaStreamingContext(sparkConf, new Duration(500)); @@ -60,10 +53,15 @@ public void setUp() { @After public void tearDown() { - ssc.stop(); - ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -77,14 +75,11 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - suiteBase.createTopic(topic); - HashMap tmp = new HashMap(sent); - suiteBase.produceAndSendMessage(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, sent); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +92,7 @@ public void testKafkaStream() throws InterruptedException { topics, StorageLevel.MEMORY_ONLY_SER()); - final HashMap result = new HashMap(); + final Map result = Collections.synchronizedMap(new HashMap()); JavaDStream words = stream.map( new Function, String>() { @@ -127,6 +122,7 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); + long startTime = System.currentTimeMillis(); boolean sizeMatches = false; while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { @@ -137,6 +133,5 @@ public Void call(JavaPairRDD rdd) throws Exception { for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } - ssc.stop(); } } diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala new file mode 100644 index 000000000000..02225d5aa7cc --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.postfixOps + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.Utils + +class DirectKafkaStreamSuite + extends SparkFunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Eventually + with Logging { + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + + private var sc: SparkContext = _ + private var ssc: StreamingContext = _ + private var testDir: File = _ + + private var kafkaTestUtils: KafkaTestUtils = _ + + override def beforeAll { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + } + + override def afterAll { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + after { + if (ssc != null) { + ssc.stop() + sc = null + } + if (sc != null) { + sc.stop() + } + if (testDir != null) { + Utils.deleteRecursively(testDir) + } + } + + + test("basic stream receiving with multiple topics and smallest starting offset") { + val topics = Set("basic1", "basic2", "basic3") + val data = Map("a" -> 7, "b" -> 9) + topics.foreach { t => + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) + } + val totalSent = data.values.sum * topics.size + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics) + } + + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.foreachRDD { rdd => + for (o <- offsetRanges) { + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + val collected = rdd.mapPartitionsWithIndex { (i, iter) => + // For each partition, get size of the range in the partition, + // and the number of items in the partition + val off = offsetRanges(i) + val all = iter.toSeq + val partSize = all.size + val rangeSize = off.untilOffset - off.fromOffset + Iterator((partSize, rangeSize)) + }.collect + + // Verify whether number of elements in each partition + // matches with the corresponding offset range + collected.foreach { case (partSize, rangeSize) => + assert(partSize === rangeSize, "offset ranges are wrong") + } + } + stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === totalSent, + "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + } + ssc.stop() + } + + test("receiving from largest starting offset") { + val topic = "largest" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + kafkaTestUtils.sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() > 3) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + kafkaTestUtils.sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + + test("creating stream by offset") { + val topic = "offset" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + kafkaTestUtils.sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() >= 10) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String]( + ssc, kafkaParams, Map(topicPartition -> 11L), + (m: MessageAndMetadata[String, String]) => m.message()) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + stream.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + kafkaTestUtils.sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + // Test to verify the offset ranges can be recovered from the checkpoints + test("offset recovery") { + val topic = "recovery" + kafkaTestUtils.createTopic(topic) + testDir = Utils.createTempDir() + + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + // Send data to Kafka and wait for it to be received + def sendDataAndWaitForReceive(data: Seq[Int]) { + val strings = data.map { _.toString} + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) + } + } + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val kafkaStream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + val keyedStream = kafkaStream.map { v => "key" -> v._2.toInt } + val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) => + Some(values.sum + state.getOrElse(0)) + } + ssc.checkpoint(testDir.getAbsolutePath) + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + DirectKafkaStreamSuite.collectedData.appendAll(data) + } + + // This is ensure all the data is eventually receiving only once + stateStream.foreachRDD { (rdd: RDD[(String, Int)]) => + rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 } + } + ssc.start() + + // Send some data and wait for them to be received + for (i <- (1 to 10).grouped(4)) { + sendDataAndWaitForReceive(i) + } + + // Verify that offset ranges were generated + val offsetRangesBeforeStop = getOffsetRanges(kafkaStream) + assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated") + assert( + offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 }, + "starting offset not zero" + ) + ssc.stop() + logInfo("====== RESTARTING ========") + + // Recover context from checkpoints + ssc = new StreamingContext(testDir.getAbsolutePath) + val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]] + + // Verify offset ranges have been recovered + val recoveredOffsetRanges = getOffsetRanges(recoveredStream) + assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered") + val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) } + assert( + recoveredOffsetRanges.forall { or => + earlierOffsetRangesAsSets.contains((or._1, or._2.toSet)) + }, + "Recovered ranges are not the same as the ones generated" + ) + // Restart context, give more data and verify the total at the end + // If the total is write that means each records has been received only once + ssc.start() + sendDataAndWaitForReceive(11 to 20) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total === (1 to 20).sum) + } + ssc.stop() + } + + test("Direct Kafka stream report input information") { + val topic = "report-test" + val data = Map("a" -> 7, "b" -> 9) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) + + val totalSent = data.values.sum + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + import DirectKafkaStreamSuite._ + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val collector = new InputInfoCollector + ssc.addStreamingListener(collector) + + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + + stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === totalSent, + "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + + // Calculate all the record number collected in the StreamingListener. + assert(collector.numRecordsSubmitted.get() === totalSent) + assert(collector.numRecordsStarted.get() === totalSent) + assert(collector.numRecordsCompleted.get() === totalSent) + } + ssc.stop() + } + + test("using rate controller") { + val topic = "backpressure" + val topicPartition = TopicAndPartition(topic, 0) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 100 + val estimator = new ConstantEstimator(100) + val messageKeys = (1 to 200).map(_.toString) + val messages = messageKeys.map((_, 1)).toMap + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(Set(topicPartition)) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, estimator)) + } + } + + val collectedData = + new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + + // Used for assertion failure messages. + def dataToString: String = + collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + collectedData += data + } + + ssc.start() + + // Try different rate limits. + // Send data to Kafka and wait for arrays of data to appear matching the rate. + Seq(100, 50, 20).foreach { rate => + collectedData.clear() // Empty this buffer on each pass. + estimator.updateRate(rate) // Set a new rate. + // Expect blocks of data equal to "rate", scaled by the interval length in secs. + val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) + kafkaTestUtils.sendMessages(topic, messages) + eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { + // Assert that rate estimator values are used to determine maxMessagesPerPartition. + // Funky "-" in message makes the complete assertion message read better. + assert(collectedData.exists(_.size == expectedSize), + s" - No arrays of size $expectedSize for rate $rate found in $dataToString") + } + } + + ssc.stop() + } + + /** Get the generated offset ranges from the DirectKafkaStream */ + private def getOffsetRanges[K, V]( + kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { + kafkaStream.generatedRDDs.mapValues { rdd => + rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges + }.toSeq.sortBy { _._1 } + } +} + +object DirectKafkaStreamSuite { + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + @volatile var total = -1L + + class InputInfoCollector extends StreamingListener { + val numRecordsSubmitted = new AtomicLong(0L) + val numRecordsStarted = new AtomicLong(0L) + val numRecordsCompleted = new AtomicLong(0L) + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords) + } + } +} + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} + diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala new file mode 100644 index 000000000000..d66830cbacde --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.common.TopicAndPartition +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { + private val topic = "kcsuitetopic" + Random.nextInt(10000) + private val topicAndPartition = TopicAndPartition(topic, 0) + private var kc: KafkaCluster = null + + private var kafkaTestUtils: KafkaTestUtils = _ + + override def beforeAll() { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) + } + + override def afterAll() { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + test("metadata apis") { + val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) + val leaderAddress = s"${leader._1}:${leader._2}" + assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") + + val parts = kc.getPartitions(Set(topic)).right.get + assert(parts(topicAndPartition), "didn't get partitions") + + val err = kc.getPartitions(Set(topic + "BAD")) + assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") + } + + test("leader offset apis") { + val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get + assert(earliest(topicAndPartition).offset === 0, "didn't get earliest") + + val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get + assert(latest(topicAndPartition).offset === 1, "didn't get latest") + } + + test("consumer offset apis") { + val group = "kcsuitegroup" + Random.nextInt(10000) + + val offset = Random.nextInt(10000) + + val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset)) + assert(set.isRight, "didn't set consumer offsets") + + val get = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get + assert(get(topicAndPartition) === offset, "didn't get consumer offsets") + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala new file mode 100644 index 000000000000..f52a738afd65 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.serializer.StringDecoder +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var kafkaTestUtils: KafkaTestUtils = _ + + private val sparkConf = new SparkConf().setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ + + override def beforeAll { + sc = new SparkContext(sparkConf) + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + } + + override def afterAll { + if (sc != null) { + sc.stop + sc = null + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + test("basic usage") { + val topic = s"topicbasic-${Random.nextInt}" + kafkaTestUtils.createTopic(topic) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) + + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt}") + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, offsetRanges) + + val received = rdd.map(_._2).collect.toSet + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head._2 === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0))) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, badRanges) + } + } + + test("iterator boundary conditions") { + // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd + val topic = s"topicboundary-${Random.nextInt}" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + kafkaTestUtils.createTopic(topic) + + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt}") + + val kc = new KafkaCluster(kafkaParams) + + // this is the "lots of messages" case + kafkaTestUtils.sendMessages(topic, sent) + val sentCount = sent.values.sum + + // rdd defined from leaders after sending messages, should get the number sent + val rdd = getRdd(kc, Set(topic)) + + assert(rdd.isDefined) + + val ranges = rdd.get.asInstanceOf[HasOffsetRanges].offsetRanges + val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum + + assert(rangeCount === sentCount, "offset range didn't include all sent messages") + assert(rdd.get.count === sentCount, "didn't get all sent messages") + + val rangesMap = ranges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + + // make sure consumer offsets are committed before the next getRdd call + kc.setConsumerOffsets(kafkaParams("group.id"), rangesMap).fold( + err => throw new Exception(err.mkString("\n")), + _ => () + ) + + // this is the "0 messages" case + val rdd2 = getRdd(kc, Set(topic)) + // shouldn't get anything, since message is sent after rdd was defined + val sentOnlyOne = Map("d" -> 1) + + kafkaTestUtils.sendMessages(topic, sentOnlyOne) + + assert(rdd2.isDefined) + assert(rdd2.get.count === 0, "got messages when there shouldn't be any") + + // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above + val rdd3 = getRdd(kc, Set(topic)) + // send lots of messages after rdd was defined, they shouldn't show up + kafkaTestUtils.sendMessages(topic, Map("extra" -> 22)) + + assert(rdd3.isDefined) + assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") + + } + + // get an rdd from the committed consumer offsets until the latest leader offsets, + private def getRdd(kc: KafkaCluster, topics: Set[String]) = { + val groupId = kc.kafkaParams("group.id") + def consumerOffsets(topicPartitions: Set[TopicAndPartition]) = { + kc.getConsumerOffsets(groupId, topicPartitions).right.toOption.orElse( + kc.getEarliestLeaderOffsets(topicPartitions).right.toOption.map { offs => + offs.map(kv => kv._1 -> kv._2.offset) + } + ) + } + kc.getPartitions(topics).right.toOption.flatMap { topicPartitions => + consumerOffsets(topicPartitions).flatMap { from => + kc.getLatestLeaderOffsets(topicPartitions).right.toOption.map { until => + val offsetRanges = from.map { case (tp: TopicAndPartition, fromOffset: Long) => + OffsetRange(tp.topic, tp.partition, fromOffset, until(tp).offset) + }.toArray + + val leaders = until.map { case (tp: TopicAndPartition, lo: KafkaCluster.LeaderOffset) => + tp -> Broker(lo.host, lo.port) + }.toMap + + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder, String]( + sc, kc.kafkaParams, offsetRanges, leaders, + (mmd: MessageAndMetadata[String, String]) => s"${mmd.offset} ${mmd.message}") + } + } + } + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index b19c053ebfc4..797b07f80d8e 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,199 +17,38 @@ package org.apache.spark.streaming.kafka -import java.io.File -import java.net.InetSocketAddress -import java.util.Properties - import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.CreateTopicCommand -import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, Producer, ProducerConfig} -import kafka.serializer.{StringDecoder, StringEncoder} -import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient.ZkClient -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.scalatest.{BeforeAndAfter, FunSuite} +import kafka.serializer.StringDecoder +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils - -/** - * This is an abstract base class for Kafka testsuites. This has the functionality to set up - * and tear down local Kafka servers, and to push data using Kafka producers. - */ -abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - - var zkAddress: String = _ - var zkClient: ZkClient = _ - - private val zkHost = "localhost" - private val zkConnectionTimeout = 6000 - private val zkSessionTimeout = 6000 - private var zookeeper: EmbeddedZookeeper = _ - private var zkPort: Int = 0 - private var brokerPort = 9092 - private var brokerConf: KafkaConfig = _ - private var server: KafkaServer = _ - private var producer: Producer[String, String] = _ - - def setupKafka() { - // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") - // Get the actual zookeeper binding port - zkPort = zookeeper.actualPort - zkAddress = s"$zkHost:$zkPort" - logInfo("==================== 0 ====================") - - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, - ZKStringSerializer) - logInfo("==================== 1 ====================") - - // Kafka broker startup - var bindSuccess: Boolean = false - while(!bindSuccess) { - try { - val brokerProps = getBrokerConfig() - brokerConf = new KafkaConfig(brokerProps) - server = new KafkaServer(brokerConf) - logInfo("==================== 2 ====================") - server.startup() - logInfo("==================== 3 ====================") - bindSuccess = true - } catch { - case e: KafkaException => - if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { - brokerPort += 1 - } - case e: Exception => throw new Exception("Kafka server create failed", e) - } - } - - Thread.sleep(2000) - logInfo("==================== 4 ====================") - } - - def tearDownKafka() { - if (producer != null) { - producer.close() - producer = null - } - - if (server != null) { - server.shutdown() - server = null - } - - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - if (zkClient != null) { - zkClient.close() - zkClient = null - } - - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } - - private def createTestMessage(topic: String, sent: Map[String, Int]) - : Seq[KeyedMessage[String, String]] = { - val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { - new KeyedMessage[String, String](topic, s) - } - messages.toSeq - } - - def createTopic(topic: String) { - CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") - logInfo("==================== 5 ====================") - // wait until metadata is propagated - waitUntilMetadataIsPropagated(topic, 0) - } +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { + private var ssc: StreamingContext = _ + private var kafkaTestUtils: KafkaTestUtils = _ - def produceAndSendMessage(topic: String, sent: Map[String, Int]) { - producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(createTestMessage(topic, sent): _*) - producer.close() - logInfo("==================== 6 ====================") + override def beforeAll(): Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } - private def getBrokerConfig(): Properties = { - val props = new Properties() - props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkAddress) - props.put("log.flush.interval.messages", "1") - props.put("replica.socket.timeout.ms", "1500") - props - } - - private def getProducerConfig(): Properties = { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - val props = new Properties() - props.put("metadata.broker.list", brokerAddr) - props.put("serializer.class", classOf[StringEncoder].getName) - props - } - - private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { - assert( - server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) - } - } - - class EmbeddedZookeeper(val zkConnect: String) { - val random = new Random() - val snapshotDir = Utils.createTempDir() - val logDir = Utils.createTempDir() - - val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) - val (ip, port) = { - val splits = zkConnect.split(":") - (splits(0), splits(1).toInt) - } - val factory = new NIOServerCnxnFactory() - factory.configure(new InetSocketAddress(ip, port), 16) - factory.startup(zookeeper) - - val actualPort = factory.getLocalPort - - def shutdown() { - factory.shutdown() - Utils.deleteRecursively(snapshotDir) - Utils.deleteRecursively(logDir) - } - } -} - - -class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { - var ssc: StreamingContext = _ - - before { - setupKafka() - } - - after { + override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("Kafka input stream") { @@ -217,16 +56,16 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - produceAndSendMessage(topic, sent) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkAddress, + val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() + val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] stream.map(_._2).countByValue().foreachRDD { r => val ret = r.collect() ret.toMap.foreach { kv => @@ -234,14 +73,11 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { result.put(kv._1, count) } } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert(sent.size === result.size) - sent.keys.foreach { k => - assert(sent(k) === result(k).toInt) - } + assert(sent === result) } - ssc.stop() } } - diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 64ccc92c81fa..80e2df62de3f 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kafka - import java.io.File import scala.collection.mutable @@ -25,61 +24,71 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.google.common.io.Files import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.apache.commons.io.FileUtils -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { +class ReliableKafkaStreamSuite extends SparkFunSuite + with BeforeAndAfterAll with BeforeAndAfter with Eventually { - val sparkConf = new SparkConf() + private val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.receiver.writeAheadLog.enable", "true") - val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private var kafkaTestUtils: KafkaTestUtils = _ - var groupId: String = _ - var kafkaParams: Map[String, String] = _ - var ssc: StreamingContext = _ - var tempDirectory: File = null + private var groupId: String = _ + private var kafkaParams: Map[String, String] = _ + private var ssc: StreamingContext = _ + private var tempDirectory: File = null + + override def beforeAll() : Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() - before { - setupKafka() groupId = s"test-consumer-${Random.nextInt(10000)}" kafkaParams = Map( - "zookeeper.connect" -> zkAddress, + "zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> groupId, "auto.offset.reset" -> "smallest" ) + tempDirectory = Utils.createTempDir() + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDirectory) + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + before { ssc = new StreamingContext(sparkConf, Milliseconds(500)) - tempDirectory = Files.createTempDir() ssc.checkpoint(tempDirectory.getAbsolutePath) } after { if (ssc != null) { ssc.stop() + ssc = null } - if (tempDirectory != null && tempDirectory.exists()) { - FileUtils.deleteDirectory(tempDirectory) - tempDirectory = null - } - tearDownKafka() } - test("Reliable Kafka input stream with single topic") { - var topic = "test-topic" - createTopic(topic) - produceAndSendMessage(topic, data) + val topic = "test-topic" + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -95,6 +104,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter } } ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { // A basic process verification for ReliableKafkaReceiver. // Verify whether received message number is equal to the sent message number. @@ -104,14 +114,13 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter // Verify the offset number whether it is equal to the total message number. assert(getCommitOffset(groupId, topic, 0) === Some(29L)) } - ssc.stop() } test("Reliable Kafka input stream with multiple topics") { val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => - createTopic(t) - produceAndSendMessage(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. @@ -122,19 +131,18 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) stream.foreachRDD(_ => Unit) ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { // Verify the offset for each group/topic to see whether they are equal to the expected one. topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } } - ssc.stop() } /** Getting partition offset from Zookeeper. */ private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { - assert(zkClient != null, "Zookeeper client is not initialized") val topicDirs = new ZKGroupTopicDirs(groupId, topic) val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" - ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + ZkUtils.readDataMaybeNull(kafkaTestUtils.zookeeperClient, zkPath)._1.map(_.toLong) } } diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 000000000000..841260063373 --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,175 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 560c8b9d1827..69b309876a0d 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 @@ -71,5 +78,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 000000000000..ecab5b360eb3 --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + / + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 1ef91dd49284..7c2f18cb35bd 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,25 +17,12 @@ package org.apache.spark.streaming.mqtt -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ -import scala.reflect.ClassTag - -import java.util.Properties -import java.util.concurrent.Executors -import java.io.IOException - +import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken -import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ @@ -57,6 +44,8 @@ class MQTTInputDStream( storageLevel: StorageLevel ) extends ReceiverInputDStream[String](ssc_) { + private[streaming] override def name: String = s"MQTT stream [$id]" + def getReceiver(): Receiver[String] = { new MQTTReceiver(brokerUrl, topic, storageLevel) } @@ -82,18 +71,18 @@ class MQTTReceiver( val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) // Callback automatically triggers as and when new message arrives on specified topic - val callback: MqttCallback = new MqttCallback() { + val callback = new MqttCallback() { // Handles Mqtt message - override def messageArrived(arg0: String, arg1: MqttMessage) { - store(new String(arg1.getPayload(),"utf-8")) + override def messageArrived(topic: String, message: MqttMessage) { + store(new String(message.getPayload(), "utf-8")) } - override def deliveryComplete(arg0: IMqttDeliveryToken) { + override def deliveryComplete(token: IMqttDeliveryToken) { } - override def connectionLost(arg0: Throwable) { - restart("Connection lost ", arg0) + override def connectionLost(cause: Throwable) { + restart("Connection lost ", cause) } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index c5ffe51f9986..7b8d56d6faf2 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.mqtt +import scala.reflect.ClassTag + import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import scala.reflect.ClassTag -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object MQTTUtils { /** @@ -73,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. + */ +private[mqtt] class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index e84adc088a68..a6a9249db8ed 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,45 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf -import org.apache.spark.util.Utils +import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) - private val master: String = "local[2]" - private val framework: String = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort + private val master = "local[2]" + private val framework = this.getClass.getSimpleName private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -63,14 +48,17 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream: ReceiverInputDStream[String] = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -78,86 +66,14 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence: MqttClientPersistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic: MqttTopic = client.getTopic(topic) - val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(50) // wait for Spark streaming to consume something from the message queue - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 000000000000..1618e2c088b7 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private[mqtt] class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index da6ffe7662f6..178ae8de13b5 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 4eacc47da569..7cf02d85d73d 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -70,7 +70,7 @@ class TwitterReceiver( try { val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { + def onStatus(status: Status): Unit = { store(status) } // Unimplemented diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 64bfc5745088..9a3569789d2e 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d8..d9acb568879f 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e919c2c9b19e..37bfd10d4366 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 554705878ee7..588e6bac7b14 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -29,13 +29,16 @@ import org.apache.spark.streaming.receiver.ActorHelper /** * A receiver to subscribe to ZeroMQ stream. */ -private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) +private[streaming] class ZeroMQReceiver[T: ClassTag]( + publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[ByteString] => Iterator[T]) extends Actor with ActorHelper with Logging { - override def preStart() = ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + override def preStart(): Unit = { + ZeroMQExtension(context.system) + .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + } def receive: Receive = { diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 0469d0af8864..4ea218eaa4de 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -18,15 +18,17 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString import akka.zeromq.Subscribe + import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream} +import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.ActorSupervisorStrategy object ZeroMQUtils { @@ -75,7 +77,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) } @@ -99,7 +102,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) } @@ -122,7 +126,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn) } } diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d89..35d2e62c6848 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 0fb431808bac..3636a9037d43 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 287c8e356350..eb3b1999eb99 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml new file mode 100644 index 000000000000..51af3e6f2225 --- /dev/null +++ b/extras/kinesis-asl-assembly/pom.xml @@ -0,0 +1,181 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kinesis-asl-assembly_2.10 + jar + Spark Project Kinesis Assembly + http://spark.apache.org/ + + + streaming-kinesis-asl-assembly + + + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + com.fasterxml.jackson.core + jackson-databind + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c815eda52bda..521b53e230c4 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl @@ -40,6 +40,13 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -59,7 +66,7 @@ org.mockito - mockito-all + mockito-core test @@ -67,11 +74,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - com.novocode junit-interface diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index b0bff27a61c1..06e0ff28afd9 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.regex.Pattern; +import com.amazonaws.regions.RegionUtils; import org.apache.log4j.Logger; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -40,140 +41,146 @@ import com.google.common.collect.Lists; /** - * Java-friendly Kinesis Spark Streaming WordCount example + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details - * on the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: JavaKinesisWordCountASL [app-name] [stream-name] [endpoint-url] [region-name] + * [app-name] is the name of the consumer app, used to track the read data in DynamoDB + * [stream-name] name of the Kinesis stream (ie. mySparkStream) + * [endpoint-url] endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: JavaKinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID=[your-access-key] * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * # run the example + * $ SPARK_HOME/bin/run-example streaming.JavaKinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data - * onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ public final class JavaKinesisWordCountASL { // needs to be public for access from run-example - private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); - private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); - - /* Make the constructor private to enforce singleton */ - private JavaKinesisWordCountASL() { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + public static void main(String[] args) { + // Check that all required args were passed in. + if (args.length != 3) { + System.err.println( + "Usage: JavaKinesisWordCountASL \n\n" + + " is the name of the app, used to track the read data in DynamoDB\n" + + " is the name of the Kinesis stream\n" + + " is the endpoint of the Kinesis service\n" + + " (e.g. https://kinesis.us-east-1.amazonaws.com)\n" + + "Generate data for the Kinesis stream using the example KinesisWordProducerASL.\n" + + "See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more\n" + + "details.\n" + ); + System.exit(1); } - public static void main(String[] args) { - /* Check that all required args were passed in. */ - if (args.length < 2) { - System.err.println( - "Usage: JavaKinesisWordCountASL \n" + - " is the name of the Kinesis stream\n" + - " is the endpoint of the Kinesis service\n" + - " (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - /* Populate the appropriate variables from the given args */ - String streamName = args[0]; - String endpointUrl = args[1]; - /* Set the batch interval to a fixed 2000 millis (2 seconds) */ - Duration batchInterval = new Duration(2000); - - /* Create a Kinesis client in order to determine the number of shards for the given stream */ - AmazonKinesisClient kinesisClient = new AmazonKinesisClient( - new DefaultAWSCredentialsProviderChain()); - kinesisClient.setEndpoint(endpointUrl); - - /* Determine the number of shards from the stream */ - int numShards = kinesisClient.describeStream(streamName) - .getStreamDescription().getShards().size(); - - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ - int numStreams = numShards; - - /* Setup the Spark config. */ - SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount"); - - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ - Duration checkpointInterval = batchInterval; + // Set default log4j logging level to WARN to hide Spark logs + StreamingExamples.setStreamingLogLevels(); + + // Populate the appropriate variables from the given args + String kinesisAppName = args[0]; + String streamName = args[1]; + String endpointUrl = args[2]; + + // Create a Kinesis client in order to determine the number of shards for the given stream + AmazonKinesisClient kinesisClient = + new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + int numShards = + kinesisClient.describeStream(streamName).getStreamDescription().getShards().size(); + + + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. + int numStreams = numShards; + + // Spark Streaming batch interval + Duration batchInterval = new Duration(2000); + + // Kinesis checkpoint interval. Same as batchInterval for this example. + Duration kinesisCheckpointInterval = batchInterval; + + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + + // Setup the Spark config and StreamingContext + SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + // Create the Kinesis DStreams + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2()) + ); + } - /* Setup the StreamingContext */ - JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + // Union all the streams if there is more than 1 stream + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + // Otherwise, just use the 1 stream + unionStreams = streamsList.get(0); + } - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ - List> streamsList = new ArrayList>(numStreams); - for (int i = 0; i < numStreams; i++) { - streamsList.add( - KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) - ); + // Convert each line of Array[Byte] to String, and split into words + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + // Map each word to a (word, 1) tuple so we can reduce by key to count the words + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } } - - /* Union all the streams if there is more than 1 stream */ - JavaDStream unionStreams; - if (streamsList.size() > 1) { - unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); - } else { - /* Otherwise, just use the 1 stream */ - unionStreams = streamsList.get(0); + ).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } } + ); - /* - * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. - * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. - */ - JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { - @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); - } - }); - - /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. */ - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); - - /* Print the first 10 wordCounts */ - wordCounts.print(); - - /* Start the streaming context and await termination */ - jssc.start(); - jssc.awaitTermination(); - } + // Print the first 10 wordCounts + wordCounts.print(); + + // Start the streaming context and await termination + jssc.start(); + jssc.awaitTermination(); + } } diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py new file mode 100644 index 000000000000..f428f64da3c4 --- /dev/null +++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Consumes messages from a Amazon Kinesis streams and does wordcount. + + This example spins up 1 Kinesis Receiver per shard for the given stream. + It then starts pulling from the last checkpointed sequence number of the given stream. + + Usage: kinesis_wordcount_asl.py + is the name of the consumer app, used to track the read data in DynamoDB + name of the Kinesis stream (ie. mySparkStream) + endpoint of the Kinesis service + (e.g. https://kinesis.us-east-1.amazonaws.com) + + + Example: + # export AWS keys if necessary + $ export AWS_ACCESS_KEY_ID= + $ export AWS_SECRET_KEY= + + # run the example + $ bin/spark-submit -jar extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com + + There is a companion helper class called KinesisWordProducerASL which puts dummy data + onto the Kinesis stream. + + This code uses the DefaultAWSCredentialsProviderChain to find credentials + in the following order: + Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + Java System Properties - aws.accessKeyId and aws.secretKey + Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + Instance profile credentials - delivered through the Amazon EC2 metadata service + For more information, see + http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + + See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + the Kinesis Spark Streaming integration. +""" +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + +if __name__ == "__main__": + if len(sys.argv) != 5: + print( + "Usage: kinesis_wordcount_asl.py ", + file=sys.stderr) + sys.exit(-1) + + sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl") + ssc = StreamingContext(sc, 1) + appName, streamName, endpointUrl, regionName = sys.argv[1:] + lines = KinesisUtils.createStream( + ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties index 97348fb5b612..6cdc9286c5d7 100644 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 32da0858d1a1..de749626ec09 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,226 +15,253 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer + import scala.util.Random -import org.apache.spark.Logging -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.streaming.kinesis.KinesisUtils -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain + +import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest -import org.apache.log4j.Logger -import org.apache.log4j.Level +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils + /** - * Kinesis Spark Streaming WordCount example. + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on - * the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: KinesisWordCountASL + * is the name of the consumer app, used to track the read data in DynamoDB + * name of the Kinesis stream (ie. mySparkStream) + * endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: KinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * + * # run the example + * $ SPARK_HOME/bin/run-example streaming.KinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com * - * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class below called KinesisWordCountProducerASL which puts - * dummy data onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ -private object KinesisWordCountASL extends Logging { +object KinesisWordCountASL extends Logging { def main(args: Array[String]) { - /* Check that all required args were passed in. */ - if (args.length < 2) { + // Check that all required args were passed in. + if (args.length != 3) { System.err.println( """ - |Usage: KinesisWordCount + |Usage: KinesisWordCountASL + | + | is the name of the consumer app, used to track the read data in DynamoDB | is the name of the Kinesis stream | is the endpoint of the Kinesis service | (e.g. https://kinesis.us-east-1.amazonaws.com) + | + |Generate input data for Kinesis stream using the example KinesisWordProducerASL. + |See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more + |details. """.stripMargin) System.exit(1) } StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ - val Array(streamName, endpointUrl) = args + // Populate the appropriate variables from the given args + val Array(appName, streamName, endpointUrl) = args - /* Determine the number of shards from the stream */ - val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + + // Determine the number of shards from the stream using the low-level Kinesis Client + // from the AWS Java SDK. + val credentials = new DefaultAWSCredentialsProviderChain().getCredentials() + require(credentials != null, + "No AWS credentials found. Please specify credentials using one of the methods specified " + + "in http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html") + val kinesisClient = new AmazonKinesisClient(credentials) kinesisClient.setEndpoint(endpointUrl) - val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() - .size() + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards().size + - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. val numStreams = numShards - /* Setup the and SparkConfig and StreamingContext */ - /* Spark Streaming batch interval */ + // Spark Streaming batch interval val batchInterval = Milliseconds(2000) - val sparkConfig = new SparkConf().setAppName("KinesisWordCount") - val ssc = new StreamingContext(sparkConfig, batchInterval) - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information + // on sequence number of records that have been received. Same as batchInterval for this + // example. val kinesisCheckpointInterval = batchInterval - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + + // Setup the SparkConfig and StreamingContext + val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + // Create the Kinesis DStreams val kinesisStreams = (0 until numStreams).map { i => - KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream(ssc, appName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2) } - /* Union all the streams */ + // Union all the streams val unionStreams = ssc.union(kinesisStreams) - /* Convert each line of Array[Byte] to String, split into words, and count them */ - val words = unionStreams.flatMap(byteArray => new String(byteArray) - .split(" ")) + // Convert each line of Array[Byte] to String, and split into words + val words = unionStreams.flatMap(byteArray => new String(byteArray).split(" ")) - /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - /* Print the first 10 wordCounts */ + // Print the first 10 wordCounts wordCounts.print() - /* Start the streaming context and await termination */ + // Start the streaming context and await termination ssc.start() ssc.awaitTermination() } } /** - * Usage: KinesisWordCountProducerASL - * + * Usage: KinesisWordProducerASL \ + * + * * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service + * is the endpoint of the Kinesis service * (ie. https://kinesis.us-east-1.amazonaws.com) * is the rate of records per second to put onto the stream * is the rate of records per second to put onto the stream * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com 10 5 + * $ SPARK_HOME/bin/run-example streaming.KinesisWordProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com us-east-1 10 5 */ -private object KinesisWordCountProducerASL { +object KinesisWordProducerASL { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: KinesisWordCountProducerASL " + - " ") + if (args.length != 4) { + System.err.println( + """ + |Usage: KinesisWordProducerASL + + | + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + | is the rate of records per second to put onto the stream + | is the rate of records per second to put onto the stream + | + """.stripMargin) + System.exit(1) } + // Set default log4j logging level to WARN to hide Spark logs StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ + // Populate the appropriate variables from the given args val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args - /* Generate the records and return the totals */ - val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + // Generate the records and return the totals + val totals = generate(stream, endpoint, recordsPerSecond.toInt, + wordsPerRecord.toInt) - /* Print the array of (index, total) tuples */ - println("Totals") - totals.foreach(total => println(total.toString())) + // Print the array of (word, total) tuples + println("Totals for the words sent") + totals.foreach(println(_)) } def generate(stream: String, endpoint: String, recordsPerSecond: Int, - wordsPerRecord: Int): Seq[(Int, Int)] = { + wordsPerRecord: Int): Seq[(String, Int)] = { - val MaxRandomInts = 10 + val randomWords = List("spark", "you", "are", "my", "father") + val totals = scala.collection.mutable.Map[String, Int]() - /* Create the Kinesis client */ + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + - s" $recordsPerSecond records per second and $wordsPerRecord words per record"); - - val totals = new Array[Int](MaxRandomInts) - /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ - for (i <- 1 to 5) { - - /* Generate recordsPerSec records to put onto the stream */ - val records = (1 to recordsPerSecond.toInt).map { recordNum => - /* - * Randomly generate each wordsPerRec words between 0 (inclusive) - * and MAX_RANDOM_INTS (exclusive) - */ + s" $recordsPerSecond records per second and $wordsPerRecord words per record") + + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord + for (i <- 1 to 10) { + // Generate recordsPerSec records to put onto the stream + val records = (1 to recordsPerSecond.toInt).foreach { recordNum => + // Randomly generate wordsPerRecord number of words val data = (1 to wordsPerRecord.toInt).map(x => { - /* Generate the random int */ - val randomInt = Random.nextInt(MaxRandomInts) + // Get a random index to a word + val randomWordIdx = Random.nextInt(randomWords.size) + val randomWord = randomWords(randomWordIdx) - /* Keep track of the totals */ - totals(randomInt) += 1 + // Increment total count to compare to server counts later + totals(randomWord) = totals.getOrElse(randomWord, 0) + 1 - randomInt.toString() + randomWord }).mkString(" ") - /* Create a partitionKey based on recordNum */ + // Create a partitionKey based on recordNum val partitionKey = s"partitionKey-$recordNum" - /* Create a PutRecordRequest with an Array[Byte] version of the data */ + // Create a PutRecordRequest with an Array[Byte] version of the data val putRecordRequest = new PutRecordRequest().withStreamName(stream) .withPartitionKey(partitionKey) - .withData(ByteBuffer.wrap(data.getBytes())); + .withData(ByteBuffer.wrap(data.getBytes())) - /* Put the record onto the stream and capture the PutRecordResult */ - val putRecordResult = kinesisClient.putRecord(putRecordRequest); + // Put the record onto the stream and capture the PutRecordResult + val putRecordResult = kinesisClient.putRecord(putRecordRequest) } - /* Sleep for a second */ + // Sleep for a second Thread.sleep(1000) println("Sent " + recordsPerSecond + " records") } - - /* Convert the totals to (index, total) tuple */ - (0 to (MaxRandomInts - 1)).zip(totals) + // Convert the totals to (index, total) tuple + totals.toSeq.sortBy(_._1) } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { - - /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + // Set reasonable logging levels for streaming if the user has not configured log4j. def setStreamingLogLevels() { val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized) { @@ -246,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala new file mode 100644 index 000000000000..5d32fa699ae5 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.NextIterator + + +/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */ +private[kinesis] +case class SequenceNumberRange( + streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + +/** Class representing an array of Kinesis sequence number ranges */ +private[kinesis] +case class SequenceNumberRanges(ranges: Seq[SequenceNumberRange]) { + def isEmpty(): Boolean = ranges.isEmpty + + def nonEmpty(): Boolean = ranges.nonEmpty + + override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")") +} + +private[kinesis] +object SequenceNumberRanges { + def apply(range: SequenceNumberRange): SequenceNumberRanges = { + new SequenceNumberRanges(Seq(range)) + } +} + + +/** Partition storing the information of the ranges of Kinesis sequence numbers to read */ +private[kinesis] +class KinesisBackedBlockRDDPartition( + idx: Int, + blockId: BlockId, + val isBlockIdValid: Boolean, + val seqNumberRanges: SequenceNumberRanges + ) extends BlockRDDPartition(blockId, idx) + +/** + * A BlockRDD where the block data is backed by Kinesis, which can accessed using the + * sequence numbers of the corresponding blocks. + */ +private[kinesis] +class KinesisBackedBlockRDD( + @transient sc: SparkContext, + val regionName: String, + val endpointUrl: String, + @transient blockIds: Array[BlockId], + @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], + @transient isBlockIdValid: Array[Boolean] = Array.empty, + val retryTimeoutMs: Int = 10000, + val awsCredentialsOption: Option[SerializableAWSCredentials] = None + ) extends BlockRDD[Array[Byte]](sc, blockIds) { + + require(blockIds.length == arrayOfseqNumberRanges.length, + "Number of blockIds is not equal to the number of sequence number ranges") + + override def isValid(): Boolean = true + + override def getPartitions: Array[Partition] = { + Array.tabulate(blockIds.length) { i => + val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) + new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] + val blockId = partition.blockId + + def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + logDebug(s"Read partition data of $this from block manager, block $blockId") + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + } + + def getBlockFromKinesis(): Iterator[Array[Byte]] = { + val credenentials = awsCredentialsOption.getOrElse { + new DefaultAWSCredentialsProviderChain().getCredentials() + } + partition.seqNumberRanges.ranges.iterator.flatMap { range => + new KinesisSequenceRangeIterator( + credenentials, endpointUrl, regionName, range, retryTimeoutMs) + } + } + if (partition.isBlockIdValid) { + getBlockFromBlockManager().getOrElse { getBlockFromKinesis() } + } else { + getBlockFromKinesis() + } + } +} + + +/** + * An iterator that return the Kinesis data based on the given range of sequence numbers. + * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber, + * until the endSequenceNumber is reached. + */ +private[kinesis] +class KinesisSequenceRangeIterator( + credentials: AWSCredentials, + endpointUrl: String, + regionId: String, + range: SequenceNumberRange, + retryTimeoutMs: Int + ) extends NextIterator[Array[Byte]] with Logging { + + private val client = new AmazonKinesisClient(credentials) + private val streamName = range.streamName + private val shardId = range.shardId + + private var toSeqNumberReceived = false + private var lastSeqNumber: String = null + private var internalIterator: Iterator[Record] = null + + client.setEndpoint(endpointUrl, "kinesis", regionId) + + override protected def getNext(): Array[Byte] = { + var nextBytes: Array[Byte] = null + if (toSeqNumberReceived) { + finished = true + } else { + + if (internalIterator == null) { + + // If the internal iterator has not been initialized, + // then fetch records from starting sequence number + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + } else if (!internalIterator.hasNext) { + + // If the internal iterator does not have any more records, + // then fetch more records after the last consumed sequence number + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } + + if (!internalIterator.hasNext) { + + // If the internal iterator still does not have any data, then throw exception + // and terminate this iterator + finished = true + throw new SparkException( + s"Could not read until the end sequence number of the range: $range") + } else { + + // Get the record, copy the data into a byte array and remember its sequence number + val nextRecord: Record = internalIterator.next() + val byteBuffer = nextRecord.getData() + nextBytes = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(nextBytes) + lastSeqNumber = nextRecord.getSequenceNumber() + + // If the this record's sequence number matches the stopping sequence number, then make sure + // the iterator is marked finished next time getNext() is called + if (nextRecord.getSequenceNumber == range.toSeqNumber) { + toSeqNumberReceived = true + } + } + + } + nextBytes + } + + override protected def close(): Unit = { + client.shutdown() + } + + /** + * Get records starting from or after the given sequence number. + */ + private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + val shardIterator = getKinesisIterator(iteratorType, seqNum) + val result = getRecordsAndNextKinesisIterator(shardIterator) + result._1 + } + + /** + * Get the records starting from using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for next consumption. + */ + private def getRecordsAndNextKinesisIterator( + shardIterator: String): (Iterator[Record], String) = { + val getRecordsRequest = new GetRecordsRequest + getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setShardIterator(shardIterator) + val getRecordsResult = retryOrTimeout[GetRecordsResult]( + s"getting records using shard iterator") { + client.getRecords(getRecordsRequest) + } + (getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number. + */ + private def getKinesisIterator( + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest + getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setStreamName(streamName) + getShardIteratorRequest.setShardId(shardId) + getShardIteratorRequest.setShardIteratorType(iteratorType.toString) + getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) + val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + s"getting shard iterator from sequence number $sequenceNumber") { + client.getShardIterator(getShardIteratorRequest) + } + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String)(body: => T): T = { + import KinesisSequenceRangeIterator._ + + var startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} + +private[streaming] +object KinesisSequenceRangeIterator { + val MAX_RETRIES = 3 + val MIN_RETRY_WAIT_TIME_MS = 100 +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 0b80b611cdce..83a453755951 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -18,39 +18,37 @@ package org.apache.spark.streaming.kinesis import org.apache.spark.Logging import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.util.SystemClock +import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint */ def shouldCheckpoint(): Boolean = { - new SystemClock().currentTime() > checkpointClock.currentTime() + new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() } /** * Advance the checkpoint clock by the checkpoint interval. */ - def advanceCheckpoint() = { - checkpointClock.addToTime(checkpointInterval.milliseconds) + def advanceCheckpoint(): Unit = { + checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala new file mode 100644 index 000000000000..2e4204dcb6f1 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.{Duration, StreamingContext, Time} + +private[kinesis] class KinesisInputDStream( + @transient _ssc: StreamingContext, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointAppName: String, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends ReceiverInputDStream[Array[Byte]](_ssc) { + + private[streaming] + override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = { + + // This returns true even for when blockInfos is empty + val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty) + + if (allBlocksHaveRanges) { + // Create a KinesisBackedBlockRDD, even when there are no blocks + val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray + val seqNumRanges = blockInfos.map { + _.metadataOption.get.asInstanceOf[SequenceNumberRanges] }.toArray + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " + + s"seq number ranges: ${seqNumRanges.mkString(", ")} ") + new KinesisBackedBlockRDD( + context.sc, regionName, endpointUrl, blockIds, seqNumRanges, + isBlockIdValid = isBlockIdValid, + retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, + awsCredentialsOption = awsCredentialsOption) + } else { + logWarning("Kinesis sequence number information was not present with some block metadata," + + " it may not be possible to recover from failures") + super.createBlockRDD(time, blockInfos) + } + } + + override def getReceiver(): Receiver[Array[Byte]] = { + new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, + checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1bd1f324298e..6e0988c1af8a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -16,134 +16,321 @@ */ package org.apache.spark.streaming.kinesis -import java.net.InetAddress import java.util.UUID -import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.model.Record + +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkEnv} -import com.amazonaws.auth.AWSCredentialsProvider -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +private[kinesis] +case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) + extends AWSCredentials { + override def getAWSAccessKeyId: String = accessKeyId + override def getAWSSecretKey: String = secretKey +} /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) - * as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers - * to run within a Spark Executor. * - * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams - * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing - * DynamoDB table with the same name this Kinesis application. + * The way this Receiver works is as follows: + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. + * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. + * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. + * @param regionName Region name used by the Kinesis Client Library for + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * - * @return ReceiverInputDStream[Array[Byte]] + * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies + * the credentials */ private[kinesis] class KinesisReceiver( - appName: String, - streamName: String, + val streamName: String, endpointUrl: String, - checkpointInterval: Duration, + regionName: String, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel) - extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => + checkpointAppName: String, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => /* - * The following vars are built in the onStart() method which executes in the Spark Worker after - * this code is serialized and shipped remotely. + * ================================================================================= + * The following vars are initialize in the onStart() method which executes in the + * Spark worker after this Receiver is serialized and shipped to the worker. + * ================================================================================= */ - /* - * workerId should be based on the ip address of the actual Spark Worker where this code runs - * (not the Driver's ip address.) + /** + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * where this code runs (not the driver's IP address.) */ - var workerId: String = null + @volatile private var workerId: String = null - /* - * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file at the default location (~/.aws/credentials) shared by all - * AWS SDKs and the AWS CLI - * Instance profile credentials delivered through the Amazon EC2 metadata service + /** + * Worker is the core client abstraction from the Kinesis Client Library (KCL). + * A worker can process more than one shards from the given stream. + * Each shard is assigned its own IRecordProcessor and the worker run multiple such + * processors. */ - var credentialsProvider: AWSCredentialsProvider = null + @volatile private var worker: Worker = null + @volatile private var workerThread: Thread = null - /* KCL config instance. */ - var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + /** BlockGenerator used to generates blocks out of Kinesis data */ + @volatile private var blockGenerator: BlockGenerator = null - /* - * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the - * IRecordProcessor.processRecords() method. - * We're using our custom KinesisRecordProcessor in this case. + /** + * Sequence number ranges added to the current block being generated. + * Accessing and updating of this map is synchronized by locks in BlockGenerator. */ - var recordProcessorFactory: IRecordProcessorFactory = null + private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] - /* - * Create a Kinesis Worker. - * This is the core client abstraction from the Kinesis Client Library (KCL). - * We pass the RecordProcessorFactory from above as well as the KCL config instance. - * A Kinesis Worker can process 1..* shards from the given stream - each with its - * own RecordProcessor. - */ - var worker: Worker = null + /** Sequence number ranges of data added to each generated block */ + private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] + with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] /** - * This is called when the KinesisReceiver starts and must be non-blocking. - * The KCL creates and manages the receiving/processing thread pool through the Worker.run() - * method. + * Latest sequence number ranges that have been stored successfully. + * This is used for checkpointing through KCL */ + private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String] + with mutable.SynchronizedMap[String, String] + /** + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { - workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() - credentialsProvider = new DefaultAWSCredentialsProviderChain() - kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, - credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) - recordProcessorFactory = new IRecordProcessorFactory { + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) + + workerId = Utils.localHostName() + ":" + UUID.randomUUID() + + // KCL config instance + val awsCredProvider = resolveAWSCredentialsProvider() + val kinesisClientLibConfiguration = + new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + val recordProcessorFactory = new IRecordProcessorFactory { override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, workerId, new KinesisCheckpointState(checkpointInterval)) } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) - worker.run() + workerThread = new Thread() { + override def run(): Unit = { + try { + worker.run() + } catch { + case NonFatal(e) => + restart("Error running the KCL worker in Receiver", e) + } + } + } + + blockIdToSeqNumRanges.clear() + blockGenerator.start() + + workerThread.setName(s"Kinesis Receiver ${streamId}") + workerThread.setDaemon(true) + workerThread.start() logInfo(s"Started receiver with workerId $workerId") } /** - * This is called when the KinesisReceiver stops. - * The KCL worker.shutdown() method stops the receiving/processing threads. - * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop() { - worker.shutdown() - logInfo(s"Shut down receiver with workerId $workerId") + if (workerThread != null) { + if (worker != null) { + worker.shutdown() + worker = null + } + workerThread.join() + workerThread = null + logInfo(s"Stopped receiver for workerId $workerId") + } workerId = null - credentialsProvider = null - kinesisClientLibConfiguration = null - recordProcessorFactory = null - worker = null + } + + /** Add records of the given shard to the current block being generated */ + private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { + if (records.size > 0) { + val dataIterator = records.iterator().asScala.map { record => + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } + val metadata = SequenceNumberRange(streamName, shardId, + records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) + blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) + + } + } + + /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ + private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { + shardIdToLatestStoredSeqNum.get(shardId) + } + + /** + * Remember the range of sequence numbers that was added to the currently active block. + * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`. + */ + private def rememberAddedRange(range: SequenceNumberRange): Unit = { + seqNumRangesInCurrentBlock += range + } + + /** + * Finalize the ranges added to the block that was active and prepare the ranges buffer + * for next block. Internally, this is synchronized with `rememberAddedRange()`. + */ + private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { + blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray) + seqNumRangesInCurrentBlock.clear() + logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") + } + + /** Store the block along with its associated ranges */ + private def storeBlockWithRanges( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = { + val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) + if (rangesToReportOption.isEmpty) { + stop("Error while storing block into Spark, could not find sequence number ranges " + + s"for block $blockId") + return + } + + val rangesToReport = rangesToReportOption.get + var attempt = 0 + var stored = false + var throwable: Throwable = null + while (!stored && attempt <= 3) { + try { + store(arrayBuffer, rangesToReport) + stored = true + } catch { + case NonFatal(th) => + attempt += 1 + throwable = th + } + } + if (!stored) { + stop("Error while storing block into Spark", throwable) + } + + // Update the latest sequence number that have been successfully stored for each shard + // Note that we are doing this sequentially because the array of sequence number ranges + // is assumed to be + rangesToReport.ranges.foreach { range => + shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber + } + } + + /** + * If AWS credential is provided, return a AWSCredentialProvider returning that credential. + * Otherwise, return the DefaultAWSCredentialsProviderChain. + */ + private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { + awsCredentialsOption match { + case Some(awsCredentials) => + logInfo("Using provided AWS credentials") + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = awsCredentials + override def refresh(): Unit = { } + } + case None => + logInfo("Using DefaultAWSCredentialsProviderChain") + new DefaultAWSCredentialsProviderChain() + } + } + + + /** + * Class to handle blocks generated by this receiver's block generator. Specifically, in + * the context of the Kinesis Receiver, this handler does the following. + * + * - When an array of records is added to the current active block in the block generator, + * this handler keeps track of the corresponding sequence number range. + * - When the currently active block is ready to sealed (not more records), this handler + * keep track of the list of ranges added into this block in another H + */ + private class GeneratedBlockHandler extends BlockGeneratorListener { + + /** + * Callback method called after a data item is added into the BlockGenerator. + * The data addition, block generation, and calls to onAddData and onGenerateBlock + * are all synchronized through the same lock. + */ + def onAddData(data: Any, metadata: Any): Unit = { + rememberAddedRange(metadata.asInstanceOf[SequenceNumberRange]) + } + + /** + * Callback method called after a block has been generated. + * The data addition, block generation, and calls to onAddData and onGenerateBlock + * are all synchronized through the same lock. + */ + def onGenerateBlock(blockId: StreamBlockId): Unit = { + finalizeRangesForCurrentBlock(blockId) + } + + /** Callback method called when a block is ready to be pushed / stored. */ + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + storeBlockWithRanges(blockId, + arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]]) + } + + /** Callback called in case of any error in internal of the BlockGenerator */ + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 8ecc2d90160b..b2405123321e 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -18,24 +18,23 @@ package org.apache.spark.streaming.kinesis import java.util.List -import scala.collection.JavaConversions.asScalaBuffer import scala.util.Random +import scala.util.control.NonFatal -import org.apache.spark.Logging - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.Logging + /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes @@ -47,8 +46,9 @@ private[kinesis] class KinesisRecordProcessor( workerId: String, checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { - /* shardId to be populated during initialize() */ - var shardId: String = _ + // shardId to be populated during initialize() + @volatile + private var shardId: String = _ /** * The Kinesis Client Library calls this method during IRecordProcessor initialization. @@ -56,8 +56,8 @@ private[kinesis] class KinesisRecordProcessor( * @param shardId assigned by the KCL to this particular RecordProcessor. */ override def initialize(shardId: String) { - logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") this.shardId = shardId + logInfo(s"Initialized workerId $workerId with shardId $shardId") } /** @@ -66,48 +66,44 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { - /* - * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - */ - batch.foreach(record => receiver.store(record.getData().array())) - - logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + receiver.addRecords(shardId, batch) + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored - * in the batch. - * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the - * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. - * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). - * However, if the worker dies unexpectedly, a checkpoint may not happen. - * This could lead to records being processed more than once. + * + * Checkpoint the sequence number of the last record successfully stored. + * Note that in this current implementation, the checkpointing occurs only when after + * checkpointIntervalMillis from the last checkpoint, AND when there is new record + * to process. This leads to the checkpointing lagging behind what records have been + * stored by the receiver. Ofcourse, this can lead records processed more than once, + * under failures and restarts. + * + * TODO: Instead of checkpointing here, run a separate timer task to perform + * checkpointing so that it checkpoints in a timely manner independent of whether + * new records are available or not. */ if (checkpointState.shouldCheckpoint()) { - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() + /* Update the next checkpoint time */ + checkpointState.advanceCheckpoint() - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + logDebug(s"Checkpoint: Next checkpoint is at " + + s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") + } } } catch { - case e: Throwable => { + case NonFatal(e) => { /* * If there is a failure within the batch, the batch will not be checkpointed. * This will potentially cause records since the last checkpoint to be processed @@ -116,22 +112,22 @@ private[kinesis] class KinesisRecordProcessor( logError(s"Exception: WorkerId $workerId encountered and exception while storing " + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e } } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -145,8 +141,12 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + case ShutdownReason.TERMINATE => + val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId) + if (latestSeqNumToCheckpointOption.nonEmpty) { + KinesisRecordProcessor.retryRandom( + checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100) + } /* * ZOMBIE Use Case. NoOp. @@ -190,7 +190,7 @@ private[kinesis] object KinesisRecordProcessor extends Logging { logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) } - /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + /* Throw: Shutdown has been requested by the Kinesis Client Library. */ case _: ShutdownException => { logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala new file mode 100644 index 000000000000..634bf9452107 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Random, Success, Try} + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient +import com.amazonaws.services.dynamodbv2.document.DynamoDB +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark.Logging + +/** + * Shared utility methods for performing Kinesis tests that actually transfer data + */ +private[kinesis] class KinesisTestUtils extends Logging { + + val endpointUrl = KinesisTestUtils.endpointUrl + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val streamShardCount = 2 + + private val createStreamTimeoutSeconds = 300 + private val describeStreamPollTimeSeconds = 1 + + @volatile + private var streamCreated = false + + @volatile + private var _streamName: String = _ + + private lazy val kinesisClient = { + val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) + client.setEndpoint(endpointUrl) + client + } + + private lazy val dynamoDB = { + val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain()) + dynamoDBClient.setRegion(RegionUtils.getRegion(regionName)) + new DynamoDB(dynamoDBClient) + } + + def streamName: String = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + _streamName + } + + def createStream(): Unit = { + require(!streamCreated, "Stream already created") + _streamName = findNonExistentStreamName() + + // Create a stream. The number of shards determines the provisioned throughput. + logInfo(s"Creating stream ${_streamName}") + val createStreamRequest = new CreateStreamRequest() + createStreamRequest.setStreamName(_streamName) + createStreamRequest.setShardCount(2) + kinesisClient.createStream(createStreamRequest) + + // The stream is now being created. Wait for it to become active. + waitForStreamToBeActive(_streamName) + streamCreated = true + logInfo(s"Created stream ${_streamName}") + } + + /** + * Push data to Kinesis stream and return a map of + * shardId -> seq of (data, seq number) pushed to corresponding shard + */ + def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + + testData.foreach { num => + val str = num.toString + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(ByteBuffer.wrap(str.getBytes())) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") + shardIdToSeqNumbers.toMap + } + + /** + * Expose a Python friendly API. + */ + def pushData(testData: java.util.List[Int]): Unit = { + pushData(testData.asScala) + } + + def deleteStream(): Unit = { + try { + if (streamCreated) { + kinesisClient.deleteStream(streamName) + } + } catch { + case e: Exception => + logWarning(s"Could not delete stream $streamName") + } + } + + def deleteDynamoDBTable(tableName: String): Unit = { + try { + val table = dynamoDB.getTable(tableName) + table.delete() + table.waitForDelete() + } catch { + case e: Exception => + logWarning(s"Could not delete DynamoDB table $tableName") + } + } + + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + + private def findNonExistentStreamName(): String = { + var testStreamName: String = null + do { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}" + } while (describeStream(testStreamName).nonEmpty) + testStreamName + } + + private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = { + val startTime = System.currentTimeMillis() + val endTime = startTime + TimeUnit.SECONDS.toMillis(createStreamTimeoutSeconds) + while (System.currentTimeMillis() < endTime) { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + describeStream(streamNameToWaitFor).foreach { description => + val streamStatus = description.getStreamStatus() + logDebug(s"\t- current state: $streamStatus\n") + if ("ACTIVE".equals(streamStatus)) { + return + } + } + } + require(false, s"Stream $streamName never became active") + } +} + +private[kinesis] object KinesisTestUtils { + + val envVarNameForEnablingTests = "ENABLE_KINESIS_TESTS" + val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL" + val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + + lazy val shouldRunTests = { + val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1") + if (isEnvSet) { + // scalastyle:off println + // Print this so that they are easily visible on the console and not hidden in the log4j logs. + println( + s""" + |Kinesis tests that actually send data has been enabled by setting the environment + |variable $envVarNameForEnablingTests to 1. This will create Kinesis Streams and + |DynamoDB tables in AWS. Please be aware that this may incur some AWS costs. + |By default, the tests use the endpoint URL $defaultEndpointUrl to create Kinesis streams. + |To change this endpoint URL to a different region, you can set the environment variable + |$endVarNameForEndpoint to the desired endpoint URL + |(e.g. $endVarNameForEndpoint="https://kinesis.us-west-2.amazonaws.com"). + """.stripMargin) + // scalastyle:on println + } + isEnvSet + } + + lazy val endpointUrl = { + val url = sys.env.getOrElse(endVarNameForEndpoint, defaultEndpointUrl) + // scalastyle:off println + // Print this so that they are easily visible on the console and not hidden in the log4j logs. + println(s"Using endpoint URL $url for creating Kinesis streams for tests.") + // scalastyle:on println + url + } + + def isAWSCredentialsPresent: Boolean = { + Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess + } + + def getAWSCredentials(): AWSCredentials = { + assert(shouldRunTests, + "Kinesis test not enabled, should not attempt to get AWS credentials") + Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { + case Success(cred) => cred + case Failure(e) => + throw new Exception( + s""" + |Kinesis tests enabled using environment variable $envVarNameForEnablingTests + |but could not find AWS credentials. Please follow instructions in AWS documentation + |to set the credentials in your system such that the DefaultAWSCredentialsProviderChain + |can find the credentials. + """.stripMargin) + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 96f4399accd3..c799fadf2d5c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,82 +16,331 @@ */ package org.apache.spark.streaming.kinesis -import org.apache.spark.annotation.Experimental +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream -import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.apache.spark.streaming.{Duration, StreamingContext} -/** - * Helper class to create Amazon Kinesis Input Stream - * :: Experimental :: - */ -@Experimental object KinesisUtils { /** - * Create an InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: - * @param ssc StreamingContext object + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): ReceiverInputDStream[Array[Byte]] = { + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. * - * @return ReceiverInputDStream[Array[Byte]] + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental + @deprecated("use other forms of createStream", "1.4.0") def createStream( ssc: StreamingContext, streamName: String, endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, - checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), + initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None) + } } /** - * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + */ + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. * - * @return JavaReceiverInputDStream[Array[Byte]] + * @param jssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental + @deprecated("use other forms of createStream", "1.4.0") def createStream( - jssc: JavaStreamingContext, - streamName: String, - endpointUrl: String, + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { - jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, - endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream( + jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel) + } + + private def getRegionByEndpoint(endpointUrl: String): String = { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() } + + private def validateRegion(regionName: String): String = { + Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { + throw new IllegalArgumentException(s"Region name '$regionName' is not valid") + } + } +} + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { + initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + } + + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + } diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index 853ef0ed2986..edbecdae9209 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala new file mode 100644 index 000000000000..a89e5627e014 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.{SparkConf, SparkContext, SparkException} + +class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { + + private val testData = 1 to 8 + + private var testUtils: KinesisTestUtils = null + private var shardIds: Seq[String] = null + private var shardIdToData: Map[String, Seq[Int]] = null + private var shardIdToSeqNumbers: Map[String, Seq[String]] = null + private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null + private var shardIdToRange: Map[String, SequenceNumberRange] = null + private var allRanges: Seq[SequenceNumberRange] = null + + private var sc: SparkContext = null + private var blockManager: BlockManager = null + + + override def beforeAll(): Unit = { + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils() + testUtils.createStream() + + shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") + + shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq + shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }} + shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} + shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => + val seqNumRange = SequenceNumberRange( + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + (shardId, seqNumRange) + } + allRanges = shardIdToRange.values.toSeq + + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.deleteStream() + } + if (sc != null) { + sc.stop() + } + } + + testIfEnabled("Basic reading from Kinesis") { + // Verify all data using multiple ranges in a single RDD partition + val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, + fakeBlockIds(1), + Array(SequenceNumberRanges(allRanges.toArray)) + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData1.toSet === testData.toSet) + + // Verify all data using one range in each of the multiple RDD partitions + val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData2.toSet === testData.toSet) + + // Verify ordering within each partition + val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collectPartitions() + assert(receivedData3.length === allRanges.size) + for (i <- 0 until allRanges.size) { + assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId)) + } + } + + testIfEnabled("Read data available in both block manager and Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available only in block manager, not in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0) + } + + testIfEnabled("Read data available only in Kinesis, not in block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available partially in block manager, rest in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1) + } + + testIfEnabled("Test isBlockValid skips block fetching from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0, + testIsBlockValid = true) + } + + testIfEnabled("Test whether RDD is valid after removing blocks from block anager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, + testBlockRemove = true) + } + + /** + * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager + * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * It can also test if the partitions that were read from the log were again stored in + * block manager. + * + * + * + * @param numPartitions Number of partitions in RDD + * @param numPartitionsInBM Number of partitions to write to the BlockManager. + * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager + * @param numPartitionsInKinesis Number of partitions to write to the Kinesis. + * Partitions (numPartitions - 1 - numPartitionsInKinesis) to + * (numPartitions - 1) will be written to Kinesis + * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching + * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with + * reads falling back to the WAL + * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInWAL = 4 + * + * numPartitionsInBM = 3 + * |------------------| + * | | + * 0 1 2 3 4 + * | | + * |-------------------------| + * numPartitionsInKinesis = 4 + */ + private def testRDD( + numPartitions: Int, + numPartitionsInBM: Int, + numPartitionsInKinesis: Int, + testIsBlockValid: Boolean = false, + testBlockRemove: Boolean = false + ): Unit = { + require(shardIds.size > 1, "Need at least 2 shards to test") + require(numPartitionsInBM <= shardIds.size , + "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") + require(numPartitionsInKinesis <= shardIds.size , + "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") + require(numPartitionsInBM <= numPartitions, + "Number of partitions in BlockManager cannot be more than that in RDD") + require(numPartitionsInKinesis <= numPartitions, + "Number of partitions in Kinesis cannot be more than that in RDD") + + // Put necessary blocks in the block manager + val blockIds = fakeBlockIds(numPartitions) + blockIds.foreach(blockManager.removeBlock(_)) + (0 until numPartitionsInBM).foreach { i => + val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() } + blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY) + } + + // Create the necessary ranges to use in the RDD + val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + val realRanges = Array.tabulate(numPartitionsInKinesis) { i => + val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) + SequenceNumberRanges(Array(range)) + } + val ranges = (fakeRanges ++ realRanges) + + + // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + + require( + blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the right sequence `numPartitionsInKinesis` are configured, and others are not + require( + ranges.takeRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName == testUtils.streamName } + }, "Incorrect configuration of RDD, expected ranges not set: " + ) + + require( + ranges.dropRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName != testUtils.streamName } + }, "Incorrect configuration of RDD, unexpected ranges set" + ) + + val rdd = new KinesisBackedBlockRDD( + sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges) + val collectedData = rdd.map { bytes => + new String(bytes).toInt + }.collect() + assert(collectedData.toSet === testData.toSet) + + // Verify that the block fetching is skipped when isBlockValid is set to false. + // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // Using that RDD will throw exception, as it skips block fetching even if the blocks are in + // in BlockManager. + if (testIsBlockValid) { + require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") + require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") + val rdd2 = new KinesisBackedBlockRDD( + sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges, + isBlockIdValid = Array.fill(blockIds.length)(false)) + intercept[SparkException] { + rdd2.collect() + } + } + + // Verify that the RDD is not invalid after the blocks are removed and can still read data + // from write ahead log + if (testBlockRemove) { + require(numPartitions === numPartitionsInKinesis, + "All partitions must be in WAL for this test") + require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test") + rdd.removeBlocks() + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet) + } + } + + /** Generate fake block ids */ + private def fakeBlockIds(num: Int): Array[BlockId] = { + Array.tabulate(num) { i => new StreamBlockId(0, i) } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala new file mode 100644 index 000000000000..ee428f31d6ce --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.SparkFunSuite + +/** + * Helper class that runs Kinesis real data transfer tests or + * ignores them based on env variable is set or not. + */ +trait KinesisFunSuite extends SparkFunSuite { + import KinesisTestUtils._ + + /** Run the test if environment variable is set or ignore the test */ + def testIfEnabled(testName: String)(testBody: => Unit) { + if (shouldRunTests) { + test(testName)(testBody) + } else { + ignore(s"$testName [enable by setting env var $envVarNameForEnablingTests=1]")(testBody) + } + } + + /** Run the give body of code only if Kinesis tests are enabled */ + def runIfTestsEnabled(message: String)(body: => Unit): Unit = { + if (shouldRunTests) { + body + } else { + ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")() + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 41dbd64c2b1f..3d136aec2e70 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -17,47 +17,40 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays -import scala.collection.JavaConversions.seqAsJavaList - -import org.apache.spark.annotation.Experimental -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.Seconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock -import org.scalatest.BeforeAndAfter -import org.scalatest.Matchers -import org.scalatest.mock.EasyMockSugar - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfter, Matchers} + +import org.apache.spark.streaming.{Milliseconds, TestSuiteBase} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** - * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with EasyMockSugar { + with MockitoSugar { val app = "TestKinesisReceiver" val stream = "mySparkStream" val endpoint = "endpoint-url" val workerId = "dummyWorkerId" val shardId = "dummyShardId" + val seqNum = "dummySeqNum" + val someSeqNum = Some(seqNum) val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) - val batch = List[Record](record1, record2) + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val batch = Arrays.asList(record1, record2) var receiverMock: KinesisReceiver = _ var checkpointerMock: IRecordProcessorCheckpointer = _ @@ -65,7 +58,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ - override def beforeFunction() = { + override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] @@ -73,203 +66,203 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } - test("kinesis utils api") { - val ssc = new StreamingContext(master, framework, batchDuration) - // Tests the API, does not actually test data receiving - val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); - ssc.stop() + override def afterFunction(): Unit = { + super.afterFunction() + // Since this suite was originally written using EasyMock, add this to preserve the old + // mocking semantics (see SPARK-5735 for more details) + verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, + checkpointStateMock, currentClockMock) + } + + test("check serializability of SerializableAWSCredentials") { + Utils.deserialize[SerializableAWSCredentials]( + Utils.serialize(new SerializableAWSCredentials("x", "y"))) } test("process records including store and checkpoint") { - val expectedCheckpointIntervalMillis = 10 - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).once() - receiverMock.store(record2.getData().array()).once() - checkpointStateMock.shouldCheckpoint().andReturn(true).once() - checkpointerMock.checkpoint().once() - checkpointStateMock.advanceCheckpoint().once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) + verify(checkpointStateMock, times(1)).shouldCheckpoint() + verify(checkpointerMock, times(1)).checkpoint(anyString) + verify(checkpointStateMock, times(1)).advanceCheckpoint() } test("shouldn't store and checkpoint when receiver is stopped") { - expecting { - receiverMock.isStopped().andReturn(true).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) + verify(checkpointerMock, never).checkpoint(anyString) } test("shouldn't checkpoint when exception occurs during store") { - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when( + receiverMock.addRecords(anyString, anyListOf(classOf[Record])) + ).thenThrow(new RuntimeException()) + + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) + recordProcessor.processRecords(batch, checkpointerMock) } + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(checkpointerMock, never).checkpoint(anyString) } test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { + when(currentClockMock.getTimeMillis()).thenReturn(0) + val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - } + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should checkpoint if we have exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should add to time when advancing checkpoint") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointIntervalMillis = 10 + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shutdown should checkpoint if the reason is TERMINATE") { - expecting { - checkpointerMock.checkpoint().once() - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - val reason = ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) - } + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) + + verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) + verify(checkpointerMock, times(1)).checkpoint(anyString) } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - expecting { - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) - } + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + + verify(checkpointerMock, never).checkpoint(anyString) } test("retry success on first attempt") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()).thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(1)).isStopped() } test("retry success on second attempt after a Kinesis throttling exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new ThrottlingException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new ThrottlingException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry success on second attempt after a Kinesis dependency exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new KinesisClientLibDependencyException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry failed after a shutdown exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[ShutdownException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) + + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after an invalid state exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[InvalidStateException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) + + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after unexpected exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) + + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after exhausing all retries") { val expectedErrorMessage = "final try error message" - expecting { - checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) - .andThrow(new ThrottlingException(expectedErrorMessage)).once() - } - whenExecuting(checkpointerMock) { - val exception = intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - exception.getMessage().shouldBe(expectedErrorMessage) + when(checkpointerMock.checkpoint()) + .thenThrow(new ThrottlingException("error message")) + .thenThrow(new ThrottlingException(expectedErrorMessage)) + + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + exception.getMessage().shouldBe(expectedErrorMessage) + + verify(checkpointerMock, times(2)).checkpoint() } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala new file mode 100644 index 000000000000..1177dc758100 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.kinesis.KinesisTestUtils._ +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} + +class KinesisStreamSuite extends KinesisFunSuite + with Eventually with BeforeAndAfter with BeforeAndAfterAll { + + // This is the name that KCL will use to save metadata to DynamoDB + private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + private val batchDuration = Seconds(1) + + // Dummy parameters for API testing + private val dummyEndpointUrl = defaultEndpointUrl + private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName() + private val dummyAWSAccessKey = "dummyAccessKey" + private val dummyAWSSecretKey = "dummySecretKey" + + private var testUtils: KinesisTestUtils = null + private var ssc: StreamingContext = null + private var sc: SparkContext = null + + override def beforeAll(): Unit = { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) + + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils() + testUtils.createStream() + } + } + + override def afterAll(): Unit = { + if (ssc != null) { + ssc.stop() + } + if (sc != null) { + sc.stop() + } + if (testUtils != null) { + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + testUtils.deleteStream() + testUtils.deleteDynamoDBTable(appName) + } + } + + before { + ssc = new StreamingContext(sc, batchDuration) + } + + after { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + ssc = null + } + if (testUtils != null) { + testUtils.deleteDynamoDBTable(appName) + } + } + + test("KinesisUtils API") { + // Tests the API, does not actually test data receiving + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", + dummyEndpointUrl, Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + dummyEndpointUrl, dummyRegionName, + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + dummyEndpointUrl, dummyRegionName, + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + dummyAWSAccessKey, dummyAWSSecretKey) + } + + test("RDD generation") { + val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), + StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey) + assert(inputStream.isInstanceOf[KinesisInputDStream]) + + val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream] + val time = Time(1000) + + // Generate block info data for testing + val seqNumRanges1 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")) + val blockId1 = StreamBlockId(kinesisStream.id, 123) + val blockInfo1 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) + + val seqNumRanges2 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb")) + val blockId2 = StreamBlockId(kinesisStream.id, 345) + val blockInfo2 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) + + // Verify that the generated KinesisBackedBlockRDD has the all the right information + val blockInfos = Seq(blockInfo1, blockInfo2) + val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD] + assert(kinesisRDD.regionName === dummyRegionName) + assert(kinesisRDD.endpointUrl === dummyEndpointUrl) + assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) + assert(kinesisRDD.awsCredentialsOption === + Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(nonEmptyRDD.partitions.size === blockInfos.size) + nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } + val partitions = nonEmptyRDD.partitions.map { + _.asInstanceOf[KinesisBackedBlockRDDPartition] }.toSeq + assert(partitions.map { _.seqNumberRanges } === Seq(seqNumRanges1, seqNumRanges2)) + assert(partitions.map { _.blockId } === Seq(blockId1, blockId2)) + assert(partitions.forall { _.isBlockIdValid === true }) + + // Verify that KinesisBackedBlockRDD is generated even when there are no blocks + val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty) + emptyRDD shouldBe a [KinesisBackedBlockRDD] + emptyRDD.partitions shouldBe empty + + // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid + blockInfos.foreach { _.setBlockIdInvalid() } + kinesisStream.createBlockRDD(time, blockInfos).partitions.foreach { partition => + assert(partition.asInstanceOf[KinesisBackedBlockRDDPartition].isBlockIdValid === false) + } + } + + + /** + * Test the stream by sending data to a Kinesis stream and receiving from it. + * This test is not run by default as it requires AWS credentials that the test + * environment may not have. Even if there is AWS credentials available, the user + * may not want to run these tests to avoid the Kinesis costs. To enable this test, + * you must have AWS credentials available through the default AWS provider chain, + * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . + */ + testIfEnabled("basic operation") { + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + testUtils.pushData(testData) + assert(collected === testData.toSet, "\nData received does not match data sent") + } + ssc.stop(stopSparkContext = false) + } + + testIfEnabled("failure recovery") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + val checkpointDir = Utils.createTempDir().getAbsolutePath + + ssc = new StreamingContext(sc, Milliseconds(1000)) + ssc.checkpoint(checkpointDir) + + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] + with mutable.SynchronizedMap[Time, (Array[SequenceNumberRanges], Seq[Int])] + + val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch + kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD] + val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq + collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + }) + + ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint + ssc.start() + + def numBatchesWithData: Int = collectedData.count(_._2._2.nonEmpty) + + def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty + + // Run until there are at least 10 batches with some data in them + // If this times out because numBatchesWithData is empty, then its likely that foreachRDD + // function failed with exceptions, and nothing got added to `collectedData` + eventually(timeout(2 minutes), interval(1 seconds)) { + testUtils.pushData(1 to 5) + assert(isCheckpointPresent && numBatchesWithData > 10) + } + ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused + + // Restart the context from checkpoint and verify whether the + logInfo("Restarting from checkpoint") + ssc = new StreamingContext(checkpointDir) + ssc.start() + val recoveredKinesisStream = ssc.graph.getInputStreams().head + + // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges + // and return the same data + val times = collectedData.keySet + times.foreach { time => + val (arrayOfSeqNumRanges, data) = collectedData(time) + val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] + rdd shouldBe a [KinesisBackedBlockRDD] + + // Verify the recovered sequence ranges + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD] + assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) + arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => + assert(expected.ranges.toSeq === found.ranges.toSeq) + } + + // Verify the recovered data + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) + } + ssc.stop() + } + +} diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index f2f0aa78b0a4..478d0019a25f 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8fac24b6ed86..853dea9a7795 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -40,14 +40,26 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava - org.jblas - jblas - ${jblas.version} + com.github.fommil.netlib + core + ${netlib.java.version} + + + net.sourceforge.f2j + arpack_combined_all + 0.1 org.scalacheck diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index f70715fca6ee..23430179f12e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -49,3 +49,19 @@ abstract class EdgeContext[VD, ED, A] { et } } + +object EdgeContext { + + /** + * Extractor mainly used for Graph#aggregateMessages*. + * Example: + * {{{ + * val messages = graph.aggregateMessages( + * case ctx @ EdgeContext(_, _, _, _, attr) => + * ctx.sendToDst(attr) + * , _ + _) + * }}} + */ + def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = + Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 6f03eb143977..ce1054ed92ba 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -26,20 +26,20 @@ class EdgeDirection private (private val name: String) extends Serializable { * out becomes in and both and either remain the same. */ def reverse: EdgeDirection = this match { - case EdgeDirection.In => EdgeDirection.Out - case EdgeDirection.Out => EdgeDirection.In + case EdgeDirection.In => EdgeDirection.Out + case EdgeDirection.Out => EdgeDirection.In case EdgeDirection.Either => EdgeDirection.Either case EdgeDirection.Both => EdgeDirection.Both } override def toString: String = "EdgeDirection." + name - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case other: EdgeDirection => other.name == name case _ => false } - override def hashCode = name.hashCode + override def hashCode: Int = name.hashCode } @@ -48,14 +48,14 @@ class EdgeDirection private (private val name: String) extends Serializable { */ object EdgeDirection { /** Edges arriving at a vertex. */ - final val In = new EdgeDirection("In") + final val In: EdgeDirection = new EdgeDirection("In") /** Edges originating from a vertex. */ - final val Out = new EdgeDirection("Out") + final val Out: EdgeDirection = new EdgeDirection("Out") /** Edges originating from *or* arriving at a vertex of interest. */ - final val Either = new EdgeDirection("Either") + final val Either: EdgeDirection = new EdgeDirection("Either") /** Edges originating from *and* arriving at a vertex of interest. */ - final val Both = new EdgeDirection("Both") + final val Both: EdgeDirection = new EdgeDirection("Both") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd..4611a3ace219 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index 9d473d5ebda4..65f82429d202 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * Set the edge properties of this triplet. */ - protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = { + protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { srcId = other.srcId dstId = other.dstId attr = other.attr @@ -62,7 +62,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { def vertexAttr(vid: VertexId): VD = if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } - override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 8494d06b1cdb..db73a8abc573 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * satisfy the predicates */ def subgraph( - epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + epred: EdgeTriplet[VD, ED] => Boolean = (x => true), vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] @@ -409,7 +409,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") * val inDeg: RDD[(VertexId, Int)] = - * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) * }}} * * @note By expressing computation at the edge level we achieve diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 4933aecba128..21187be7678a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -77,7 +77,7 @@ object GraphLoader extends Logging { if (!line.isEmpty && line(0) != '#') { val lineArray = line.split("\\s+") if (lineArray.length < 2) { - logWarning("Invalid line: " + line) + throw new IllegalArgumentException("Invalid line: " + line) } val srcId = lineArray(0).toLong val dstId = lineArray(1).toLong diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index dc8b4789c4b6..9451ff1e5c0e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -113,7 +113,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Collect the neighbor vertex attributes for each vertex. * * @note This function could be highly inefficient on power-law - * graphs where high degree vertices may force a large ammount of + * graphs where high degree vertices may force a large amount of * information to be collected to a single location. * * @param edgeDirection the direction along which to collect @@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { val nbrs = edgeDirection match { case EdgeDirection.Either => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => { ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), (a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => @@ -187,7 +187,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali /** * Join the vertices with an RDD and then apply a function from the - * the vertex and RDD entry to a new vertex value. The input table + * vertex and RDD entry to a new vertex value. The input table * should contain at most one entry for each vertex. If no entry is * provided the map function is skipped and the old value is used. * @@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def filter[VD2: ClassTag, ED2: ClassTag]( preprocess: Graph[VD, ED] => Graph[VD2, ED2], epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true, - vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = { + vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = { graph.mask(preprocess(graph).subgraph(epred, vpred)) } @@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) @@ -372,6 +372,31 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali PageRank.runUntilConvergence(graph, tol, resetProb) } + + /** + * Run personalized PageRank for a given vertex, such that all random walks + * are started relative to the source node. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]] + */ + def personalizedPageRank(src: VertexId, tol: Double, + resetProb: Double = 0.15) : Graph[Double, Double] = { + PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src)) + } + + /** + * Run Personalized PageRank for a fixed number of iterations with + * with all iterations originating at the source node + * returning a graph with vertex attributes + * containing the PageRank and edge attributes the normalized edge weight. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]] + */ + def staticPersonalizedPageRank(src: VertexId, numIter: Int, + resetProb: Double = 0.15) : Graph[Double, Double] = { + PageRank.runWithOptions(graph, numIter, resetProb, Some(src)) + } + /** * Run PageRank for a fixed number of iterations returning a graph with vertex attributes * containing the PageRank and edge attributes the normalized edge weight. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 7372dfbd9fe9..70a7592da8ae 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. * * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: @@ -61,26 +61,36 @@ object PartitionStrategy { * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, * P6)` or the last * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be - * replicated to at most `2 * sqrt(numParts) - 1` machines. + * replicated to at most `2 * sqrt(numParts)` machines. * * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the * vertex locations. * - * One of the limitations of this approach is that the number of machines must either be a - * perfect square. We partially address this limitation by computing the machine assignment to - * the next - * largest perfect square and then mapping back down to the actual number of machines. - * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect - * square is used. + * When the number of partitions requested is not a perfect square we use a slightly different + * method where the last column can have a different number of rows than the others while still + * maintaining the same size per block. */ case object EdgePartition2D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt - (col * ceilSqrtNumParts + row) % numParts + if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) { + // Use old method for perfect squared to ensure we get same results + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt + (col * ceilSqrtNumParts + row) % numParts + + } else { + // Otherwise use new method + val cols = ceilSqrtNumParts + val rows = (numParts + cols - 1) / cols + val lastColRows = numParts - rows * (cols - 1) + val col = (math.abs(src * mixingPrime) % numParts / rows).toInt + val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt + col * rows + row + + } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 5e55620147df..2ca60d51f833 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -78,8 +78,8 @@ object Pregel extends Logging { * * @param graph the input graph. * - * @param initialMsg the message each vertex will receive at the on - * the first iteration + * @param initialMsg the message each vertex will receive at the first + * iteration * * @param maxIterations the maximum number of iterations to run for * @@ -127,30 +127,27 @@ object Pregel extends Logging { var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages. Vertices that didn't get any messages do not appear in newVerts. - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Update the graph with the new vertices. + // Receive the messages and update the vertices. prevG = g - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } - g.cache() + g = g.joinVertices(messages)(vprog).cache() val oldMessages = messages - // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't - // get to send messages. We must cache messages so it can be materialized on the next line, - // allowing us to uncache the previous iteration. - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache() - // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This - // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the - // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages + // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages + // and the vertices of g). activeMessages = messages.count() logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs - oldMessages.unpersist(blocking=false) - newVerts.unpersist(blocking=false) - prevG.unpersistVertices(blocking=false) - prevG.edges.unpersist(blocking=false) + oldMessages.unpersist(blocking = false) + prevG.unpersistVertices(blocking = false) + prevG.edges.unpersist(blocking = false) // count the iteration i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 09ae3f9f6c09..a9f04b559c3d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -122,8 +122,36 @@ abstract class VertexRDD[VD]( def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] /** - * Hides vertices that are the same between `this` and `other`; for vertices that are different, - * keeps the values from `other`. + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other an RDD to run the set operation against + */ + def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other a VertexRDD to run the set operation against + */ + def minus(other: VertexRDD[VD]): VertexRDD[VD] + + /** + * For each vertex present in both `this` and `other`, `diff` returns only those vertices with + * differing values; for values that are different, keeps the values from `other`. This is + * only guaranteed to work if the VertexRDDs share a common ancestor. + * + * @param other the other RDD[(VertexId, VD)] with which to diff against. + */ + def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each vertex present in both `this` and `other`, `diff` returns only those vertices with + * differing values; for values that are different, keeps the values from `other`. This is + * only guaranteed to work if the VertexRDDs share a common ancestor. + * + * @param other the other VertexRDD with which to diff against. */ def diff(other: VertexRDD[VD]): VertexRDD[VD] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 373af7544837..ab021a252eb8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -156,8 +156,8 @@ class EdgePartition[ val size = data.size var i = 0 while (i < size) { - edge.srcId = srcIds(i) - edge.dstId = dstIds(i) + edge.srcId = srcIds(i) + edge.dstId = dstIds(i) edge.attr = data(i) newData(i) = f(edge) i += 1 @@ -324,7 +324,7 @@ class EdgePartition[ * * @return an iterator over edges in the partition */ - def iterator = new Iterator[Edge[ED]] { + def iterator: Iterator[Edge[ED]] = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] private[this] var pos = 0 @@ -351,7 +351,7 @@ class EdgePartition[ override def hasNext: Boolean = pos < EdgePartition.this.size - override def next() = { + override def next(): EdgeTriplet[VD, ED] = { val triplet = new EdgeTriplet[VD, ED] val localSrcId = localSrcIds(pos) val localDstId = localDstIds(pos) @@ -518,11 +518,11 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - override def srcId = _srcId - override def dstId = _dstId - override def srcAttr = _srcAttr - override def dstAttr = _dstAttr - override def attr = _attr + override def srcId: VertexId = _srcId + override def dstId: VertexId = _dstId + override def srcAttr: VD = _srcAttr + override def dstAttr: VD = _dstAttr + override def attr: ED = _attr override def sendToSrc(msg: A) { send(_localSrcId, msg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 56cb41661e30..c88b2f65a86c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -70,9 +70,9 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 90a74d23a26c..da95314440d8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -332,9 +332,9 @@ object GraphImpl { edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { val edgeRDD = EdgeRDD.fromEdges(edges)(classTag[ED], classTag[VD]) - .withTargetStorageLevel(edgeStorageLevel).cache() + .withTargetStorageLevel(edgeStorageLevel) val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr) - .withTargetStorageLevel(vertexStorageLevel).cache() + .withTargetStorageLevel(vertexStorageLevel) GraphImpl(vertexRDD, edgeRDD) } @@ -346,9 +346,14 @@ object GraphImpl { def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + + vertices.cache() + // Convert the vertex partitions in edges to the correct type val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) + .cache() + GraphImpl.fromExistingRDDs(vertices, newEdges) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 8ab255bd4038..1df86449fa0c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -50,7 +50,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to * match. */ - def reverse() = { + def reverse(): ReplicatedVertexView[VD, ED] = { val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 4fd2548b7faf..b90f9fa32705 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -88,6 +88,21 @@ private[graphx] abstract class VertexPartitionBaseOps this.withMask(newMask) } + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Self[VD]): Self[VD] = { + if (self.index != other.index) { + logWarning("Minus operations on two VertexPartitions with different indexes is slow.") + minus(createUsingIndex(other.iterator)) + } else { + self.withMask(self.mask.andNot(other.mask)) + } + } + + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Iterator[(VertexId, VD)]): Self[VD] = { + minus(createUsingIndex(other)) + } + /** * Hides vertices that are the same between this and other. For vertices that are different, keeps * the values from `other`. The indices of `this` and `other` must be the same. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 6dad167fa741..7f4e7e9d79d6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,9 +71,9 @@ class VertexRDDImpl[VD] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } @@ -87,7 +87,7 @@ class VertexRDDImpl[VD] private[graphx] ( /** The number of vertices in the RDD. */ override def count(): Long = { - partitionsRDD.map(_.size).reduce(_ + _) + partitionsRDD.map(_.size.toLong).reduce(_ + _) } override private[graphx] def mapVertexPartitions[VD2: ClassTag]( @@ -103,9 +103,44 @@ class VertexRDDImpl[VD] private[graphx] ( override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = this.mapVertexPartitions(_.map(f)) + override def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + minus(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + + override def minus (other: VertexRDD[VD]): VertexRDD[VD] = { + other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true) { + (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.minus(otherPart)) + }) + case _ => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.minus(msgs)) + } + ) + } + } + + override def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + diff(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val otherPartition = other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + other.partitionsRDD + case _ => + VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD + } val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true + otherPartition, preservesPartitioning = true ) { (thisIter, otherIter) => val thisPart = thisIter.next() val otherPart = otherIter.next() @@ -133,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient leftZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => leftZipJoin(other)(f) case _ => this.withPartitionsRDD[VD3]( @@ -162,7 +197,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient innerZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => innerZipJoin(other)(f) case _ => this.withPartitionsRDD( diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index e2f6cc138958..859f89603904 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -37,7 +37,7 @@ object ConnectedComponents { */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(edge: EdgeTriplet[VertexId, ED]) = { + def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { Iterator((edge.dstId, edge.srcAttr)) } else if (edge.srcAttr > edge.dstAttr) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 82e9e0651517..a3ad6bed1c99 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]) = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) @@ -54,7 +54,7 @@ object LabelPropagation { i -> (count1Val + count2Val) }.toMap } - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]) = { + def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { if (message.isEmpty) attr else message.maxBy(_._2)._1 } val initialMessage = Map[VertexId, Long]() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index e139959c3f5c..8c0a461e99fa 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag +import scala.language.postfixOps import org.apache.spark.Logging import org.apache.spark.graphx._ @@ -25,8 +26,8 @@ import org.apache.spark.graphx._ /** * PageRank algorithm implementation. There are two implementations of PageRank implemented. * - * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number - * of iterations: + * The first implementation uses the standalone [[Graph]] interface and runs PageRank + * for a fixed number of iterations: * {{{ * var PR = Array.fill(n)( 1.0 ) * val oldPR = Array.fill(n)( 1.0 ) @@ -38,7 +39,7 @@ import org.apache.spark.graphx._ * } * }}} * - * The second implementation uses the standalone [[Graph]] interface and runs PageRank until + * The second implementation uses the [[Pregel]] interface and runs PageRank until * convergence: * * {{{ @@ -60,6 +61,7 @@ import org.apache.spark.graphx._ */ object PageRank extends Logging { + /** * Run PageRank for a fixed number of iterations returning a graph * with vertex attributes containing the PageRank and edge @@ -74,10 +76,33 @@ object PageRank extends Logging { * * @return the graph containing with each vertex containing the PageRank and each edge * containing the normalized weight. + */ + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, + resetProb: Double = 0.15): Graph[Double, Double] = + { + runWithOptions(graph, numIter, resetProb) + } + + /** + * Run PageRank for a fixed number of iterations returning a graph + * with vertex attributes containing the PageRank and edge + * attributes the normalized edge weight. + * + * @tparam VD the original vertex attribute (not used) + * @tparam ED the original edge attribute (not used) + * + * @param graph the graph on which to compute PageRank + * @param numIter the number of iterations of PageRank to run + * @param resetProb the random reset probability (alpha) + * @param srcId the source vertex for a Personalized Page Rank (optional) + * + * @return the graph containing with each vertex containing the PageRank and each edge + * containing the normalized weight. * */ - def run[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = + def runWithOptions[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, + srcId: Option[VertexId] = None): Graph[Double, Double] = { // Initialize the PageRank graph with each edge attribute having // weight 1/outDegree and each vertex with attribute 1.0. @@ -89,6 +114,10 @@ object PageRank extends Logging { // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) + val personalized = srcId isDefined + val src: VertexId = srcId.getOrElse(-1L) + def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } + var iteration = 0 var prevRankGraph: Graph[Double, Double] = null while (iteration < numIter) { @@ -103,8 +132,14 @@ object PageRank extends Logging { // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the // edge partitions. prevRankGraph = rankGraph + val rPrb = if (personalized) { + (src: VertexId , id: VertexId) => resetProb * delta(src, id) + } else { + (src: VertexId, id: VertexId) => resetProb + } + rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -133,7 +168,29 @@ object PageRank extends Logging { * containing the normalized weight. */ def runUntilConvergence[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = + { + runUntilConvergenceWithOptions(graph, tol, resetProb) + } + + /** + * Run a dynamic version of PageRank returning a graph with vertex attributes containing the + * PageRank and edge attributes containing the normalized edge weight. + * + * @tparam VD the original vertex attribute (not used) + * @tparam ED the original edge attribute (not used) + * + * @param graph the graph on which to compute PageRank + * @param tol the tolerance allowed at convergence (smaller => more accurate). + * @param resetProb the random reset probability (alpha) + * @param srcId the source vertex for a Personalized Page Rank (optional) + * + * @return the graph containing with each vertex containing the PageRank and each edge + * containing the normalized weight. + */ + def runUntilConvergenceWithOptions[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, + srcId: Option[VertexId] = None): Graph[Double, Double] = { // Initialize the pagerankGraph with each edge attribute // having weight 1/outDegree and each vertex with attribute 1.0. @@ -148,6 +205,10 @@ object PageRank extends Logging { .mapVertices( (id, attr) => (0.0, 0.0) ) .cache() + val personalized = srcId.isDefined + val src: VertexId = srcId.getOrElse(-1L) + + // Define the three functions needed to implement PageRank in the GraphX // version of Pregel def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { @@ -156,6 +217,17 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } + def personalizedVertexProgram(id: VertexId, attr: (Double, Double), + msgSum: Double): (Double, Double) = { + val (oldPR, lastDelta) = attr + var teleport = oldPR + val delta = if (src==id) 1.0 else 0.0 + teleport = oldPR*delta + + val newPR = teleport + (1.0 - resetProb) * msgSum + (newPR, newPR - oldPR) + } + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) @@ -170,8 +242,17 @@ object PageRank extends Logging { val initialMessage = resetProb / (1.0 - resetProb) // Execute a dynamic version of Pregel. + val vp = if (personalized) { + (id: VertexId, attr: (Double, Double), msgSum: Double) => + personalizedVertexProgram(id, attr, msgSum) + } else { + (id: VertexId, attr: (Double, Double), msgSum: Double) => + vertexProgram(id, attr, msgSum) + } + Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + vp, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) } // end of deltaPageRank + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index f58587e10a82..9cb24ed080e1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -18,7 +18,9 @@ package org.apache.spark.graphx.lib import scala.util.Random -import org.jblas.DoubleMatrix + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + import org.apache.spark.rdd._ import org.apache.spark.graphx._ @@ -37,12 +39,23 @@ object SVDPlusPlus { var gamma7: Double) extends Serializable + /** + * This method is now replaced by the updated version of `run()` and returns exactly + * the same result. + */ + @deprecated("Call run()", "1.4.0") + def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = + { + run(edges, conf) + } + /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]]. * - * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), + * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^^-0.5^^*sum(y)), * see the details on page 6. * * @param edges edges for constructing the graph @@ -52,16 +65,13 @@ object SVDPlusPlus { * @return a graph with vertex attributes containing the trained model */ def run(edges: RDD[Edge[Double]], conf: Conf) - : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) = + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = { // Generate default vertex attribute - def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = { - val v1 = new DoubleMatrix(rank) - val v2 = new DoubleMatrix(rank) - for (i <- 0 until rank) { - v1.put(i, Random.nextDouble()) - v2.put(i, Random.nextDouble()) - } + def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = { + // TODO: use a fixed random seed + val v1 = Array.fill(rank)(Random.nextDouble()) + val v2 = Array.fill(rank)(Random.nextDouble()) (v1, v2, 0.0, 0.0) } @@ -72,38 +82,47 @@ object SVDPlusPlus { // construct graph var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() + materialize(g) + edges.unpersist() // Calculate initial bias and norm val t0 = g.aggregateMessages[(Long, Double)]( ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVertices(t0) { - (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), + val gJoinT0 = g.outerJoinVertices(t0) { + (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) => - (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) - } + (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1)) + }.cache() + materialize(gJoinT0) + g.unpersist() + g = gJoinT0 def sendMsgTrainF(conf: Conf, u: Double) (ctx: EdgeContext[ - (DoubleMatrix, DoubleMatrix, Double, Double), + (Array[Double], Array[Double], Double, Double), Double, - (DoubleMatrix, DoubleMatrix, Double)]) { + (Array[Double], Array[Double], Double)]) { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + q.dot(usr._2) + val rank = p.length + var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = ctx.attr - pred - val updateP = q.mul(err) - .subColumnVector(p.mul(conf.gamma7)) - .mul(conf.gamma2) - val updateQ = usr._2.mul(err) - .subColumnVector(q.mul(conf.gamma7)) - .mul(conf.gamma2) - val updateY = q.mul(err * usr._4) - .subColumnVector(itm._2.mul(conf.gamma7)) - .mul(conf.gamma2) + // updateP = (err * q - conf.gamma7 * p) * conf.gamma2 + val updateP = q.clone() + blas.dscal(rank, err * conf.gamma2, updateP, 1) + blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) + // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2 + val updateQ = usr._2.clone() + blas.dscal(rank, err * conf.gamma2, updateQ, 1) + blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) + // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2 + val updateY = q.clone() + blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) + blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } @@ -111,49 +130,89 @@ object SVDPlusPlus { for (i <- 0 until conf.maxIters) { // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes g.cache() - val t1 = g.aggregateMessages[DoubleMatrix]( + val t1 = g.aggregateMessages[Array[Double]]( ctx => ctx.sendToSrc(ctx.dstAttr._2), - (g1, g2) => g1.addColumnVector(g2)) - g = g.outerJoinVertices(t1) { - (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), - msg: Option[DoubleMatrix]) => - if (msg.isDefined) (vd._1, vd._1 - .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd - } + (g1, g2) => { + val out = g1.clone() + blas.daxpy(out.length, 1.0, g2, 1, out, 1) + out + }) + val gJoinT1 = g.outerJoinVertices(t1) { + (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), + msg: Option[Array[Double]]) => + if (msg.isDefined) { + val out = vd._1.clone() + blas.daxpy(out.length, vd._4, msg.get, 1, out, 1) + (vd._1, out, vd._3, vd._4) + } else { + vd + } + }.cache() + materialize(gJoinT1) + g.unpersist() + g = gJoinT1 // Phase 2, update p for user nodes and q, y for item nodes g.cache() val t2 = g.aggregateMessages( sendMsgTrainF(conf, u), - (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => - (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVertices(t2) { + (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) => + { + val out1 = g1._1.clone() + blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) + val out2 = g2._2.clone() + blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) + (out1, out2, g1._3 + g2._3) + }) + val gJoinT2 = g.outerJoinVertices(t2) { (vid: VertexId, - vd: (DoubleMatrix, DoubleMatrix, Double, Double), - msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => - (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), - vd._3 + msg.get._3, vd._4) - } + vd: (Array[Double], Array[Double], Double, Double), + msg: Option[(Array[Double], Array[Double], Double)]) => { + val out1 = vd._1.clone() + blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) + val out2 = vd._2.clone() + blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) + (out1, out2, vd._3 + msg.get._3, vd._4) + } + }.cache() + materialize(gJoinT2) + g.unpersist() + g = gJoinT2 } // calculate error on training set def sendMsgTestF(conf: Conf, u: Double) - (ctx: EdgeContext[(DoubleMatrix, DoubleMatrix, Double, Double), Double, Double]) { + (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]) { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + q.dot(usr._2) + var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = (ctx.attr - pred) * (ctx.attr - pred) ctx.sendToDst(err) } + g.cache() val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) - g = g.outerJoinVertices(t3) { - (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => + val gJoinT3 = g.outerJoinVertices(t3) { + (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd - } + }.cache() + materialize(gJoinT3) + g.unpersist() + g = gJoinT3 - (g, u) + // Convert DoubleMatrix to Array[Double]: + val newVertices = g.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4)) + (Graph(newVertices, g.edges), u) } + + /** + * Forces materialization of a Graph by count()ing its RDDs. + */ + private def materialize(g: Graph[_, _]): Unit = { + g.vertices.count() + g.edges.count() + } + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 590f0474957d..179f2843818e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -61,8 +61,8 @@ object ShortestPaths { } def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = { - val newAttr = incrementMap(edge.srcAttr) - if (edge.dstAttr != addMaps(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr)) + val newAttr = incrementMap(edge.dstAttr) + if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr)) else Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index daf162085e3e..a5d598053f9c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ */ object TriangleCount { - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Remove redundant edges val g = graph.groupEdges((a, b) => a).cache() @@ -49,7 +49,7 @@ object TriangleCount { var i = 0 while (i < nbrs.size) { // prevent self cycle - if(nbrs(i) != vid) { + if (nbrs(i) != vid) { set.add(nbrs(i)) } i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932..74a7de18d416 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { @@ -122,7 +121,7 @@ private[graphx] object BytecodeUtils { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { - methodsInvoked.add((Class.forName(owner.replace("/", ".")), name)) + methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 2d6a825b6172..989e22630526 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. */ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } @@ -243,14 +243,15 @@ object GraphGenerators { * @return A graph containing vertices with the row and column ids * as their attributes and edge values as 1.0. */ - def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = { + def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = { // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): VertexId = r * cols + c - val vertices: RDD[(VertexId, (Int,Int))] = - sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) ) + val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r => + (0 until cols).map( c => (sub2ind(r, c), (r, c)) ) + } val edges: RDD[Edge[Double]] = - vertices.flatMap{ case (vid, (r,c)) => + vertices.flatMap{ case (vid, (r, c)) => (if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++ (if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty }) }.map{ case (src, dst) => Edge(src, dst, 1.0) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 57b01b6f2e1f..e2754ea699da 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -56,7 +56,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = keySet.size + override def size: Int = keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -112,7 +112,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -128,9 +128,9 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 287c8e356350..eb3b1999eb99 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fd..f1ecc9e2219d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b41427..094a63472eaa 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 9bc8007ce49c..57a8b95dd12e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { withSpark { sc => @@ -59,7 +58,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test ("filter") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( @@ -67,11 +66,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, - vpred = (vid: VertexId, deg:Int) => deg > 0 + vpred = (vid: VertexId, deg: Int) => deg > 0 ).cache() val v = filteredGraph.vertices.collect().toSet - assert(v === Set((0,0))) + assert(v === Set((0, 0))) // the map is necessary because of object-reuse in the edge iterator val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index b61d9f0fbe5e..1f5e27d5508b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import com.google.common.io.Files - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") @@ -39,12 +36,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { val doubleRing = ring ++ ring val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1) assert(graph.edges.count() === doubleRing.size) - assert(graph.edges.collect.forall(e => e.attr == 1)) + assert(graph.edges.collect().forall(e => e.attr == 1)) // uniqueEdges option should uniquify edges and store duplicate count in edge attributes val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut)) assert(uniqueGraph.edges.count() === ring.size) - assert(uniqueGraph.edges.collect.forall(e => e.attr == 2)) + assert(uniqueGraph.edges.collect().forall(e => e.attr == 2)) } } @@ -65,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.collect.map { et => + graph.triplets.collect().map { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } @@ -76,15 +73,17 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet === - (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) + assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect().toSet + === (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) } } test("partitionBy") { withSpark { sc => - def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) - def nonemptyParts(graph: Graph[Int, Int]) = { + def mkGraph(edges: List[(Long, Long)]): Graph[Int, Int] = { + Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) + } + def nonemptyParts(graph: Graph[Int, Int]): RDD[List[Edge[Int]]] = { graph.edges.partitionsRDD.mapPartitions { iter => Iterator(iter.next()._2.iterator.toList) }.filter(_.nonEmpty) @@ -103,7 +102,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1) // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into // the same partition - assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) + assert( + nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) // partitionBy(EdgePartition2D) puts identical edges in the same partition assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1) @@ -141,10 +141,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val g = Graph( sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))), sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2)) - assert(g.triplets.collect.map(_.toTuple).toSet === + assert(g.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) val gPart = g.partitionBy(EdgePartition2D) - assert(gPart.triplets.collect.map(_.toTuple).toSet === + assert(gPart.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) } } @@ -155,10 +155,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val star = starGraph(sc, n) // mapVertices preserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") - assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) + assert(mappedVAttrs.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length) - assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) + assert(mappedVAttrs2.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -178,12 +178,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Trigger initial vertex replication graph0.triplets.foreach(x => {}) // Change type of replicated vertices, but preserve erased type - val graph1 = graph0.mapVertices { - case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + val graph1 = graph0.mapVertices { case (vid, integerOpt) => + integerOpt.map((x: java.lang.Integer) => x.toDouble: java.lang.Double) } // Access replicated vertices, exposing the erased type val graph2 = graph1.mapTriplets(t => t.srcAttr.get) - assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + assert(graph2.edges.map(_.attr).collect().toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) } } @@ -203,7 +203,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet === + assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect().toSet === (1L to n).map(x => Edge(0, x, "vv")).toSet) } } @@ -212,7 +212,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) + assert(star.reverse.outDegrees.collect().toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -222,7 +222,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(result.collect.toSet === Set((1L, 2))) + assert(result.collect().toSet === Set((1L, 2))) } } @@ -238,14 +238,15 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet) // And 4 edges. - assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet) + assert(subgraph.edges.map(_.copy()).collect().toSet === + (2 to n by 2).map(x => Edge(0, x, 1)).toSet) } } test("mask") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() @@ -257,11 +258,11 @@ class GraphSuite extends FunSuite with LocalSparkContext { val projectedGraph = graph.mask(subgraph) val v = projectedGraph.vertices.collect().toSet - assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5))) + assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5))) // the map is necessary because of object-reuse in the edge iterator val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet - assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5))) + assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5))) } } @@ -274,9 +275,9 @@ class GraphSuite extends FunSuite with LocalSparkContext { sc.parallelize((1 to n).flatMap(x => List((0: VertexId, x: VertexId), (0: VertexId, x: VertexId))), 1), "v") val star2 = doubleStar.groupEdges { (a, b) => a} - assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) === - star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int])) - assert(star2.vertices.collect.toSet === star.vertices.collect.toSet) + assert(star2.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]) === + star.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int])) + assert(star2.vertices.collect().toSet === star.vertices.collect().toSet) } } @@ -301,21 +302,23 @@ class GraphSuite extends FunSuite with LocalSparkContext { throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) } Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet + }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x+1) % n: VertexId)), 3) + val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) } + val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => + newOpt.getOrElse(old) + } val numOddNeighbors = changedGraph.mapReduceTriplets(et => { // Map function should only run on edges with source in the active set if (et.srcId % 2 != 1) { throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) } Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet + }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) } @@ -341,17 +344,18 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type - val reverseStarDegrees = - reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } + val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { + (vid, a, bOpt) => bOpt.getOrElse(0) + } val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), - (a: Int, b: Int) => a + b).collect.toSet + (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } - assert(newReverseStar.vertices.map(_._2).collect.toSet === + assert(newReverseStar.vertices.map(_._2).collect().toSet === (0 to n).map(x => "v%d".format(x)).toSet) } } @@ -362,15 +366,14 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) val graph = Graph(verts, edges) val triplets = graph.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)) - .collect.toSet + .collect().toSet assert(triplets === Set((1: VertexId, 2: VertexId, "a", "b"), (2: VertexId, 1: VertexId, "b", "a"))) } } test("checkpoint") { - val checkpointDir = Files.createTempDir() - checkpointDir.deleteOnExit() + val checkpointDir = Utils.createTempDir() withSpark { sc => sc.setCheckpointDir(checkpointDir.getAbsolutePath) val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)} @@ -419,7 +422,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val neighborAttrSums = graph.mapReduceTriplets[Int]( et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) } finally { sc.stop() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index a3e28efc75a9..d2ad9be55577 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext */ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ - def withSpark[T](f: SparkContext => T) = { + def withSpark[T](f: SparkContext => T): T = { val conf = new SparkConf() GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1..8afa2d403b53 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 131959cea3ef..f1aa685a79c9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { - def vertices(sc: SparkContext, n: Int) = { + private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) } @@ -46,6 +45,35 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("minus") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) + } + } + + test("minus with RDD[(VertexId, VD)]") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB: RDD[(VertexId, Int)] = + sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) + } + } + + test("minus with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 5).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect().toSet === (0 until 50).toSet) + } + } + test("diff") { withSpark { sc => val n = 100 @@ -58,42 +86,90 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("diff with RDD[(VertexId, VD)]") { + withSpark { sc => + val n = 100 + val verts = vertices(sc, n).cache() + val flipEvens: RDD[(VertexId, Int)] = + sc.parallelize(0L to 100L) + .map(id => if (id % 2 == 0) (id, -id.toInt) else (id, id.toInt)).cache() + // diff should keep only the changed vertices + assert(verts.diff(flipEvens).map(_._2).collect().toSet === (2 to n by 2).map(-_).toSet) + } + } + + test("diff vertices with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 24, 3).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.diff(vertexB) + assert(vertexC.map(_._1).collect().toSet === (8 until 16).toSet) + } + } + test("leftJoin") { withSpark { sc => val n = 100 val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // leftJoin with another VertexRDD - assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) // leftJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) } } + test("leftJoin vertices with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) + val vertexB = VertexRDD( + vertexA.filter(v => v._1 % 2 == 0).partitionBy(new HashPartitioner(3))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.leftJoin(vertexB) { (vid, old, newOpt) => + old - newOpt.getOrElse(0) + } + assert(vertexC.filter(v => v._2 != 0).map(_._1).collect().toSet == (1 to 99 by 2).toSet) + } + } + test("innerJoin") { withSpark { sc => val n = 100 val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // innerJoin with another VertexRDD - assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) // innerJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) } } + test("innerJoin vertices with the non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) + val vertexB = VertexRDD( + vertexA.filter(v => v._1 % 2 == 0).partitionBy(new HashPartitioner(3))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.innerJoin(vertexB) { (vid, old, newVal) => + old - newVal + } + assert(vertexC.filter(v => v._2 == 0).map(_._1).collect().toSet == (0 to 98 by 2).toSet) + } + } + test("aggregateUsingIndex") { withSpark { sc => val n = 100 val verts = vertices(sc, n) val messageTargets = (0 to n) ++ (0 to n by 2) val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1))) - assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet === + assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect().toSet === (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet) } } @@ -105,7 +181,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) // test merge function - assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9))) + assert(rdd.collect().toSet == Set((0L, 0), (1L, 3), (2L, 9))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02e..7435647c6d9e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc3..1203f8959f50 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 3915be15b343..c965a6eb8df1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,22 +17,20 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -42,7 +40,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -50,15 +48,18 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - if(id < 10) { assert(cc === 0) } - else { assert(cc === 10) } + if (id < 10) { + assert(cc === 0) + } else { + assert(cc === 10) + } } val ccMap = vertices.toMap for (id <- 0 until 20) { @@ -73,12 +74,12 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Reverse Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { if (id < 10) { assert(cc === 0) @@ -106,9 +107,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { (4L, ("peter", "student")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = - sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), + sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"), - Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) + Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) // Edges are: // 2 ---> 5 ---> 3 // | \ @@ -120,9 +121,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { // Build the initial Graph val graph = Graph(users, relationships, defaultUser) val ccGraph = graph.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - assert(cc == 0) + assert(cc === 0) } } } // end of toy connected components diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c460556..808877f0590f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index fc491ae327c2..45f1e3011035 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,31 +17,27 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ + object GridPageRank { - def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = { + def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double): Seq[(VertexId, Double)] = { val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int]) val outDegree = Array.fill(nRows * nCols)(0) // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): Int = r * nCols + c // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { - val ind = sub2ind(r,c) - if (r+1 < nRows) { + val ind = sub2ind(r, c) + if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r+1,c)) += ind + inNbrs(sub2ind(r + 1, c)) += ind } - if (c+1 < nCols) { + if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c+1)) += ind + inNbrs(sub2ind(r, c + 1)) += ind } } // compute the pagerank @@ -60,11 +56,11 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } - .map { case (id, error) => error }.sum + .map { case (id, error) => error }.sum() } test("Star PageRank") { @@ -80,12 +76,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext { // Static PageRank should only take 2 iterations to converge val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum + }.map { case (vid, test) => test }.sum() assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5) + val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) + val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) @@ -95,7 +91,35 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank + test("Star PersonalPageRank") { + withSpark { sc => + val nVertices = 100 + val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() + val resetProb = 0.15 + val errorTol = 1.0e-5 + val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) + .vertices.cache() + + // Static PageRank should only take 2 iterations to converge + val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => + if (pr1 != pr2) 1 else 0 + }.map { case (vid, test) => test }.sum + assert(notMatching === 0) + + val staticErrors = staticRanks2.map { case (vid, pr) => + val correct = (vid > 0 && pr == resetProb) || + (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * + (nVertices - 1)) )) < 1.0E-5) + if (!correct) 1 else 0 + } + assert(staticErrors.sum === 0) + + val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() + assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + } + } // end of test Star PageRank test("Grid PageRank") { withSpark { sc => @@ -109,18 +133,18 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache() val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache() - val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() + val referenceRanks = VertexRDD( + sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() assert(compareRanks(staticRanks, referenceRanks) < errorTol) assert(compareRanks(dynamicRanks, referenceRanks) < errorTol) } } // end of Grid PageRank - test("Chain PageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 @@ -133,4 +157,21 @@ class PageRankSuite extends FunSuite with LocalSparkContext { assert(compareRanks(staticRanks, dynamicRanks) < errorTol) } } + + test("Chain PersonalizedPageRank") { + withSpark { sc => + val chain1 = (0 until 9).map(x => (x, x + 1) ) + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } + val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 10 + val errorTol = 1.0e-1 + + val staticRanks = chain.staticPersonalizedPageRank(4, numIter, resetProb).vertices + val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices + + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + } + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index e01df56e94de..2991438f5e57 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => @@ -32,11 +31,11 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations - var (graph, u) = SVDPlusPlus.run(edges, conf) + val (graph, _) = SVDPlusPlus.run(edges, conf) graph.cache() - val err = graph.vertices.collect().map{ case (vid, vd) => + val err = graph.vertices.map { case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 - }.reduce(_ + _) / graph.triplets.collect().size + }.reduce(_ + _) / graph.numEdges assert(err <= svdppErr) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index 265827b3341c..d7eaa70ce640 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => @@ -40,7 +38,7 @@ class ShortestPathsSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val landmarks = Seq(1, 4).map(_.toLong) val results = ShortestPaths.run(graph, landmarks).vertices.collect.map { - case (v, spMap) => (v, spMap.mapValues(_.get)) + case (v, spMap) => (v, spMap.mapValues(i => i)) } assert(results.toSet === shortestPaths) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index df54aa37cad6..d6b03208180d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => @@ -34,8 +32,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(Seq.empty[Edge[Int]]) val graph = Graph(vertices, edges) val sccGraph = graph.stronglyConnectedComponents(5) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(id === scc) } } } @@ -45,8 +43,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7))) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(0L == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(0L === scc) } } } @@ -60,13 +58,14 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize(edges) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - if (id < 3) - assert(0L == scc) - else if (id < 6) - assert(3L == scc) - else - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + if (id < 3) { + assert(0L === scc) + } else if (id < 6) { + assert(3L === scc) + } else { + assert(id === scc) + } } } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 293c7f3ba4c2..c47552cf3a3b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => @@ -58,7 +57,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext { val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) - val revTriangles = triangles.map { case (a,b) => (b,a) } + val revTriangles = triangles.map { case (a, b) => (b, a) } val rawEdges = sc.parallelize(triangles ++ revTriangles, 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0da..61e44dcab578 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +// scalastyle:off println +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends FunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3..32e0c841c699 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/launcher/pom.xml b/launcher/pom.xml new file mode 100644 index 000000000000..2fd768d8119c --- /dev/null +++ b/launcher/pom.xml @@ -0,0 +1,78 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-launcher_2.10 + jar + Spark Project Launcher + http://spark.apache.org/ + + launcher + + + + + + log4j + log4j + test + + + junit + junit + test + + + org.mockito + mockito-core + test + + + org.slf4j + slf4j-api + test + + + org.slf4j + slf4j-log4j12 + test + + + + + org.apache.hadoop + hadoop-client + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java new file mode 100644 index 000000000000..5e793a5c4877 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileFilter; +import java.io.FileInputStream; +import java.io.InputStreamReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.jar.JarFile; +import java.util.regex.Pattern; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Abstract Spark command builder that defines common functionality. + */ +abstract class AbstractCommandBuilder { + + boolean verbose; + String appName; + String appResource; + String deployMode; + String javaHome; + String mainClass; + String master; + String propertiesFile; + final List appArgs; + final List jars; + final List files; + final List pyFiles; + final Map childEnv; + final Map conf; + + public AbstractCommandBuilder() { + this.appArgs = new ArrayList(); + this.childEnv = new HashMap(); + this.conf = new HashMap(); + this.files = new ArrayList(); + this.jars = new ArrayList(); + this.pyFiles = new ArrayList(); + } + + /** + * Builds the command to execute. + * + * @param env A map containing environment variables for the child process. It may already contain + * entries defined by the user (such as SPARK_HOME, or those defined by the + * SparkLauncher constructor that takes an environment), and may be modified to + * include other variables needed by the process to be executed. + */ + abstract List buildCommand(Map env) throws IOException; + + /** + * Builds a list of arguments to run java. + * + * This method finds the java executable to use and appends JVM-specific options for running a + * class with Spark in the classpath. It also loads options from the "java-opts" file in the + * configuration directory being used. + * + * Callers should still add at least the class to run, as well as any arguments to pass to the + * class. + */ + List buildJavaCommand(String extraClassPath) throws IOException { + List cmd = new ArrayList(); + String envJavaHome; + + if (javaHome != null) { + cmd.add(join(File.separator, javaHome, "bin", "java")); + } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) { + cmd.add(join(File.separator, envJavaHome, "bin", "java")); + } else { + cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); + } + + // Load extra JAVA_OPTS from conf/java-opts, if it exists. + File javaOpts = new File(join(File.separator, getConfDir(), "java-opts")); + if (javaOpts.isFile()) { + BufferedReader br = new BufferedReader(new InputStreamReader( + new FileInputStream(javaOpts), "UTF-8")); + try { + String line; + while ((line = br.readLine()) != null) { + addOptionString(cmd, line); + } + } finally { + br.close(); + } + } + + cmd.add("-cp"); + cmd.add(join(File.pathSeparator, buildClassPath(extraClassPath))); + return cmd; + } + + /** + * Adds the default perm gen size option for Spark if the VM requires it and the user hasn't + * set it. + */ + void addPermGenSizeOpt(List cmd) { + // Don't set MaxPermSize for IBM Java, or Oracle Java 8 and later. + if (getJavaVendor() == JavaVendor.IBM) { + return; + } + String[] version = System.getProperty("java.version").split("\\."); + if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { + return; + } + + for (String arg : cmd) { + if (arg.startsWith("-XX:MaxPermSize=")) { + return; + } + } + + cmd.add("-XX:MaxPermSize=256m"); + } + + void addOptionString(List cmd, String options) { + if (!isEmpty(options)) { + for (String opt : parseOptionString(options)) { + cmd.add(opt); + } + } + } + + /** + * Builds the classpath for the application. Returns a list with one classpath entry per element; + * each entry is formatted in the way expected by java.net.URLClassLoader (more + * specifically, with trailing slashes for directories). + */ + List buildClassPath(String appClassPath) throws IOException { + String sparkHome = getSparkHome(); + + List cp = new ArrayList(); + addToClassPath(cp, getenv("SPARK_CLASSPATH")); + addToClassPath(cp, appClassPath); + + addToClassPath(cp, getConfDir()); + + boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES")); + boolean isTesting = "1".equals(getenv("SPARK_TESTING")); + if (prependClasses || isTesting) { + String scala = getScalaVersion(); + List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", + "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", + "yarn", "launcher"); + if (prependClasses) { + System.err.println( + "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + + "assembly."); + for (String project : projects) { + addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, + scala)); + } + } + if (isTesting) { + for (String project : projects) { + addToClassPath(cp, String.format("%s/%s/target/scala-%s/test-classes", sparkHome, + project, scala)); + } + } + + // Add this path to include jars that are shaded in the final deliverable created during + // the maven build. These jars are copied to this directory during the build. + addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); + } + + // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as + // when running unit tests, or user code that embeds Spark and creates a SparkContext + // with a local or local-cluster master, will cause this code to be called from an + // environment where that env variable is not guaranteed to exist. + // + // For the testing case, we rely on the test code to set and propagate the test classpath + // appropriately. + // + // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. + // That duplicates some of the code in the shell scripts that look for the assembly, though. + String assembly = getenv(ENV_SPARK_ASSEMBLY); + if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + assembly = findAssembly(); + } + addToClassPath(cp, assembly); + + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + } else { + libdir = new File(sparkHome, "lib_managed/jars"); + } + + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); + } + } + + addToClassPath(cp, getenv("HADOOP_CONF_DIR")); + addToClassPath(cp, getenv("YARN_CONF_DIR")); + addToClassPath(cp, getenv("SPARK_DIST_CLASSPATH")); + return cp; + } + + /** + * Adds entries to the classpath. + * + * @param cp List to which the new entries are appended. + * @param entries New classpath entries (separated by File.pathSeparator). + */ + private void addToClassPath(List cp, String entries) { + if (isEmpty(entries)) { + return; + } + String[] split = entries.split(Pattern.quote(File.pathSeparator)); + for (String entry : split) { + if (!isEmpty(entry)) { + if (new File(entry).isDirectory() && !entry.endsWith(File.separator)) { + entry += File.separator; + } + cp.add(entry); + } + } + } + + String getScalaVersion() { + String scala = getenv("SPARK_SCALA_VERSION"); + if (scala != null) { + return scala; + } + String sparkHome = getSparkHome(); + File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); + File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); + checkState(!scala210.isDirectory() || !scala211.isDirectory(), + "Presence of build for both scala versions (2.10 and 2.11) detected.\n" + + "Either clean one of them or set SPARK_SCALA_VERSION in your environment."); + if (scala210.isDirectory()) { + return "2.10"; + } else { + checkState(scala211.isDirectory(), "Cannot find any assembly build directories."); + return "2.11"; + } + } + + String getSparkHome() { + String path = getenv(ENV_SPARK_HOME); + checkState(path != null, + "Spark home not found; set it explicitly or use the SPARK_HOME environment variable."); + return path; + } + + /** + * Loads the configuration file for the application, if it exists. This is either the + * user-specified properties file, or the spark-defaults.conf file under the Spark configuration + * directory. + */ + Properties loadPropertiesFile() throws IOException { + Properties props = new Properties(); + File propsFile; + if (propertiesFile != null) { + propsFile = new File(propertiesFile); + checkArgument(propsFile.isFile(), "Invalid properties file '%s'.", propertiesFile); + } else { + propsFile = new File(getConfDir(), DEFAULT_PROPERTIES_FILE); + } + + if (propsFile.isFile()) { + FileInputStream fd = null; + try { + fd = new FileInputStream(propsFile); + props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } + } finally { + if (fd != null) { + try { + fd.close(); + } catch (IOException e) { + // Ignore. + } + } + } + } + + return props; + } + + String getenv(String key) { + return firstNonEmpty(childEnv.get(key), System.getenv(key)); + } + + private String findAssembly() { + String sparkHome = getSparkHome(); + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); + } + + final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); + FileFilter filter = new FileFilter() { + @Override + public boolean accept(File file) { + return file.isFile() && re.matcher(file.getName()).matches(); + } + }; + File[] assemblies = libdir.listFiles(filter); + checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); + checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); + return assemblies[0].getAbsolutePath(); + } + + private String getConfDir() { + String confDir = getenv("SPARK_CONF_DIR"); + return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java new file mode 100644 index 000000000000..a16c0d2b5ca0 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Helper methods for command builders. + */ +class CommandBuilderUtils { + + static final String DEFAULT_MEM = "1g"; + static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; + static final String ENV_SPARK_HOME = "SPARK_HOME"; + static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; + + /** The set of known JVM vendors. */ + static enum JavaVendor { + Oracle, IBM, OpenJDK, Unknown + }; + + /** Returns whether the given string is null or empty. */ + static boolean isEmpty(String s) { + return s == null || s.isEmpty(); + } + + /** Joins a list of strings using the given separator. */ + static String join(String sep, String... elements) { + StringBuilder sb = new StringBuilder(); + for (String e : elements) { + if (e != null) { + if (sb.length() > 0) { + sb.append(sep); + } + sb.append(e); + } + } + return sb.toString(); + } + + /** Joins a list of strings using the given separator. */ + static String join(String sep, Iterable elements) { + StringBuilder sb = new StringBuilder(); + for (String e : elements) { + if (e != null) { + if (sb.length() > 0) { + sb.append(sep); + } + sb.append(e); + } + } + return sb.toString(); + } + + /** + * Returns the first non-empty value mapped to the given key in the given maps, or null otherwise. + */ + static String firstNonEmptyValue(String key, Map... maps) { + for (Map map : maps) { + String value = (String) map.get(key); + if (!isEmpty(value)) { + return value; + } + } + return null; + } + + /** Returns the first non-empty, non-null string in the given list, or null otherwise. */ + static String firstNonEmpty(String... candidates) { + for (String s : candidates) { + if (!isEmpty(s)) { + return s; + } + } + return null; + } + + /** Returns the name of the env variable that holds the native library path. */ + static String getLibPathEnvName() { + if (isWindows()) { + return "PATH"; + } + + String os = System.getProperty("os.name"); + if (os.startsWith("Mac OS X")) { + return "DYLD_LIBRARY_PATH"; + } else { + return "LD_LIBRARY_PATH"; + } + } + + /** Returns whether the OS is Windows. */ + static boolean isWindows() { + String os = System.getProperty("os.name"); + return os.startsWith("Windows"); + } + + /** Returns an enum value indicating whose JVM is being used. */ + static JavaVendor getJavaVendor() { + String vendorString = System.getProperty("java.vendor"); + if (vendorString.contains("Oracle")) { + return JavaVendor.Oracle; + } + if (vendorString.contains("IBM")) { + return JavaVendor.IBM; + } + if (vendorString.contains("OpenJDK")) { + return JavaVendor.OpenJDK; + } + return JavaVendor.Unknown; + } + + /** + * Updates the user environment, appending the given pathList to the existing value of the given + * environment variable (or setting it if it hasn't yet been set). + */ + static void mergeEnvPathList(Map userEnv, String envKey, String pathList) { + if (!isEmpty(pathList)) { + String current = firstNonEmpty(userEnv.get(envKey), System.getenv(envKey)); + userEnv.put(envKey, join(File.pathSeparator, current, pathList)); + } + } + + /** + * Parse a string as if it were a list of arguments, following bash semantics. + * For example: + * + * Input: "\"ab cd\" efgh 'i \" j'" + * Output: [ "ab cd", "efgh", "i \" j" ] + */ + static List parseOptionString(String s) { + List opts = new ArrayList(); + StringBuilder opt = new StringBuilder(); + boolean inOpt = false; + boolean inSingleQuote = false; + boolean inDoubleQuote = false; + boolean escapeNext = false; + + // This is needed to detect when a quoted empty string is used as an argument ("" or ''). + boolean hasData = false; + + for (int i = 0; i < s.length(); i++) { + int c = s.codePointAt(i); + if (escapeNext) { + opt.appendCodePoint(c); + escapeNext = false; + } else if (inOpt) { + switch (c) { + case '\\': + if (inSingleQuote) { + opt.appendCodePoint(c); + } else { + escapeNext = true; + } + break; + case '\'': + if (inDoubleQuote) { + opt.appendCodePoint(c); + } else { + inSingleQuote = !inSingleQuote; + } + break; + case '"': + if (inSingleQuote) { + opt.appendCodePoint(c); + } else { + inDoubleQuote = !inDoubleQuote; + } + break; + default: + if (!Character.isWhitespace(c) || inSingleQuote || inDoubleQuote) { + opt.appendCodePoint(c); + } else { + opts.add(opt.toString()); + opt.setLength(0); + inOpt = false; + hasData = false; + } + } + } else { + switch (c) { + case '\'': + inSingleQuote = true; + inOpt = true; + hasData = true; + break; + case '"': + inDoubleQuote = true; + inOpt = true; + hasData = true; + break; + case '\\': + escapeNext = true; + inOpt = true; + hasData = true; + break; + default: + if (!Character.isWhitespace(c)) { + inOpt = true; + hasData = true; + opt.appendCodePoint(c); + } + } + } + } + + checkArgument(!inSingleQuote && !inDoubleQuote && !escapeNext, "Invalid option string: %s", s); + if (hasData) { + opts.add(opt.toString()); + } + return opts; + } + + /** Throws IllegalArgumentException if the given object is null. */ + static void checkNotNull(Object o, String arg) { + if (o == null) { + throw new IllegalArgumentException(String.format("'%s' must not be null.", arg)); + } + } + + /** Throws IllegalArgumentException with the given message if the check is false. */ + static void checkArgument(boolean check, String msg, Object... args) { + if (!check) { + throw new IllegalArgumentException(String.format(msg, args)); + } + } + + /** Throws IllegalStateException with the given message if the check is false. */ + static void checkState(boolean check, String msg, Object... args) { + if (!check) { + throw new IllegalStateException(String.format(msg, args)); + } + } + + /** + * Quote a command argument for a command to be run by a Windows batch script, if the argument + * needs quoting. Arguments only seem to need quotes in batch scripts if they have certain + * special characters, some of which need extra (and different) escaping. + * + * For example: + * original single argument: ab="cde fgh" + * quoted: "ab^=""cde fgh""" + */ + static String quoteForBatchScript(String arg) { + + boolean needsQuotes = false; + for (int i = 0; i < arg.length(); i++) { + int c = arg.codePointAt(i); + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { + needsQuotes = true; + break; + } + } + if (!needsQuotes) { + return arg; + } + StringBuilder quoted = new StringBuilder(); + quoted.append("\""); + for (int i = 0; i < arg.length(); i++) { + int cp = arg.codePointAt(i); + switch (cp) { + case '"': + quoted.append('"'); + break; + + default: + break; + } + quoted.appendCodePoint(cp); + } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } + quoted.append("\""); + return quoted.toString(); + } + + /** + * Quotes a string so that it can be used in a command string. + * Basically, just add simple escapes. E.g.: + * original single argument : ab "cd" ef + * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. + */ + static String quoteForCommandString(String s) { + StringBuilder quoted = new StringBuilder().append('"'); + for (int i = 0; i < s.length(); i++) { + int cp = s.codePointAt(i); + if (cp == '"' || cp == '\\') { + quoted.appendCodePoint('\\'); + } + quoted.appendCodePoint(cp); + } + return quoted.append('"').toString(); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java new file mode 100644 index 000000000000..a4e3acc674f3 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Command line interface for the Spark launcher. Used internally by Spark scripts. + */ +class Main { + + /** + * Usage: Main [class] [class args] + *

    + * This CLI works in two different modes: + *

      + *
    • "spark-submit": if class is "org.apache.spark.deploy.SparkSubmit", the + * {@link SparkLauncher} class is used to launch a Spark application.
    • + *
    • "spark-class": if another class is provided, an internal Spark class is run.
    • + *
    + * + * This class works in tandem with the "bin/spark-class" script on Unix-like systems, and + * "bin/spark-class2.cmd" batch script on Windows to execute the final command. + *

    + * On Unix-like systems, the output is a list of command arguments, separated by the NULL + * character. On Windows, the output is a command line suitable for direct execution from the + * script. + */ + public static void main(String[] argsArray) throws Exception { + checkArgument(argsArray.length > 0, "Not enough arguments: missing class name."); + + List args = new ArrayList(Arrays.asList(argsArray)); + String className = args.remove(0); + + boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); + AbstractCommandBuilder builder; + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + try { + builder = new SparkSubmitCommandBuilder(args); + } catch (IllegalArgumentException e) { + printLaunchCommand = false; + System.err.println("Error: " + e.getMessage()); + System.err.println(); + + MainClassOptionParser parser = new MainClassOptionParser(); + try { + parser.parse(args); + } catch (Exception ignored) { + // Ignore parsing exceptions. + } + + List help = new ArrayList(); + if (parser.className != null) { + help.add(parser.CLASS); + help.add(parser.className); + } + help.add(parser.USAGE_ERROR); + builder = new SparkSubmitCommandBuilder(help); + } + } else { + builder = new SparkClassCommandBuilder(className, args); + } + + Map env = new HashMap(); + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + + if (isWindows()) { + System.out.println(prepareWindowsCommand(cmd, env)); + } else { + // In bash, use NULL as the arg separator since it cannot be used in an argument. + List bashCmd = prepareBashCommand(cmd, env); + for (String c : bashCmd) { + System.out.print(c); + System.out.print('\0'); + } + } + } + + /** + * Prepare a command line for execution from a Windows batch script. + * + * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments + * are "double quoted" (which is batch for escaping a quote). This page has more details about + * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html + */ + private static String prepareWindowsCommand(List cmd, Map childEnv) { + StringBuilder cmdline = new StringBuilder(); + for (Map.Entry e : childEnv.entrySet()) { + cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); + cmdline.append(" && "); + } + for (String arg : cmd) { + cmdline.append(quoteForBatchScript(arg)); + cmdline.append(" "); + } + return cmdline.toString(); + } + + /** + * Prepare the command for execution from a bash script. The final command will have commands to + * set up any needed environment variables needed by the child process. + */ + private static List prepareBashCommand(List cmd, Map childEnv) { + if (childEnv.isEmpty()) { + return cmd; + } + + List newCmd = new ArrayList(); + newCmd.add("env"); + + for (Map.Entry e : childEnv.entrySet()) { + newCmd.add(String.format("%s=%s", e.getKey(), e.getValue())); + } + newCmd.addAll(cmd); + return newCmd; + } + + /** + * A parser used when command line parsing fails for spark-submit. It's used as a best-effort + * at trying to identify the class the user wanted to invoke, since that may require special + * usage strings (handled by SparkSubmitArguments). + */ + private static class MainClassOptionParser extends SparkSubmitOptionParser { + + String className; + + @Override + protected boolean handle(String opt, String value) { + if (opt == CLASS) { + className = value; + } + return false; + } + + @Override + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java new file mode 100644 index 000000000000..931a24cfd4b1 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Command builder for internal Spark classes. + *

    + * This class handles building the command to launch all internal Spark classes except for + * SparkSubmit (which is handled by {@link SparkSubmitCommandBuilder} class. + */ +class SparkClassCommandBuilder extends AbstractCommandBuilder { + + private final String className; + private final List classArgs; + + SparkClassCommandBuilder(String className, List classArgs) { + this.className = className; + this.classArgs = classArgs; + } + + @Override + public List buildCommand(Map env) throws IOException { + List javaOptsKeys = new ArrayList(); + String memKey = null; + String extraClassPath = null; + + // Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + + // SPARK_DAEMON_MEMORY. + if (className.equals("org.apache.spark.deploy.master.Master")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_MASTER_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.worker.Worker")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_WORKER_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.history.HistoryServer")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_HISTORY_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.equals("org.apache.spark.executor.CoarseGrainedExecutorBackend")) { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); + memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { + javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); + memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") || + className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.startsWith("org.apache.spark.tools.")) { + String sparkHome = getSparkHome(); + File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", + "scala-" + getScalaVersion())); + checkState(toolsDir.isDirectory(), "Cannot find tools build directory."); + + Pattern re = Pattern.compile("spark-tools_.*\\.jar"); + for (File f : toolsDir.listFiles()) { + if (re.matcher(f.getName()).matches()) { + extraClassPath = f.getAbsolutePath(); + break; + } + } + + checkState(extraClassPath != null, + "Failed to find Spark Tools Jar in %s.\n" + + "You need to run \"build/sbt tools/package\" before running %s.", + toolsDir.getAbsolutePath(), className); + + javaOptsKeys.add("SPARK_JAVA_OPTS"); + } else { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + memKey = "SPARK_DRIVER_MEMORY"; + } + + List cmd = buildJavaCommand(extraClassPath); + for (String key : javaOptsKeys) { + addOptionString(cmd, System.getenv(key)); + } + + String mem = firstNonEmpty(memKey != null ? System.getenv(memKey) : null, DEFAULT_MEM); + cmd.add("-Xms" + mem); + cmd.add("-Xmx" + mem); + addPermGenSizeOpt(cmd); + cmd.add(className); + cmd.addAll(classArgs); + return cmd; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java new file mode 100644 index 000000000000..57993405e47b --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Launcher for Spark applications. + *

    + * Use this class to start Spark applications programmatically. The class uses a builder pattern + * to allow clients to configure the Spark application and launch it as a child process. + *

    + */ +public class SparkLauncher { + + /** The Spark master. */ + public static final String SPARK_MASTER = "spark.master"; + + /** Configuration key for the driver memory. */ + public static final String DRIVER_MEMORY = "spark.driver.memory"; + /** Configuration key for the driver class path. */ + public static final String DRIVER_EXTRA_CLASSPATH = "spark.driver.extraClassPath"; + /** Configuration key for the driver VM options. */ + public static final String DRIVER_EXTRA_JAVA_OPTIONS = "spark.driver.extraJavaOptions"; + /** Configuration key for the driver native library path. */ + public static final String DRIVER_EXTRA_LIBRARY_PATH = "spark.driver.extraLibraryPath"; + + /** Configuration key for the executor memory. */ + public static final String EXECUTOR_MEMORY = "spark.executor.memory"; + /** Configuration key for the executor class path. */ + public static final String EXECUTOR_EXTRA_CLASSPATH = "spark.executor.extraClassPath"; + /** Configuration key for the executor VM options. */ + public static final String EXECUTOR_EXTRA_JAVA_OPTIONS = "spark.executor.extraJavaOptions"; + /** Configuration key for the executor native library path. */ + public static final String EXECUTOR_EXTRA_LIBRARY_PATH = "spark.executor.extraLibraryPath"; + /** Configuration key for the number of executor CPU cores. */ + public static final String EXECUTOR_CORES = "spark.executor.cores"; + + // Visible for testing. + final SparkSubmitCommandBuilder builder; + + public SparkLauncher() { + this(null); + } + + /** + * Creates a launcher that will set the given environment variables in the child. + * + * @param env Environment variables to set. + */ + public SparkLauncher(Map env) { + this.builder = new SparkSubmitCommandBuilder(); + if (env != null) { + this.builder.childEnv.putAll(env); + } + } + + /** + * Set a custom JAVA_HOME for launching the Spark application. + * + * @param javaHome Path to the JAVA_HOME to use. + * @return This launcher. + */ + public SparkLauncher setJavaHome(String javaHome) { + checkNotNull(javaHome, "javaHome"); + builder.javaHome = javaHome; + return this; + } + + /** + * Set a custom Spark installation location for the application. + * + * @param sparkHome Path to the Spark installation to use. + * @return This launcher. + */ + public SparkLauncher setSparkHome(String sparkHome) { + checkNotNull(sparkHome, "sparkHome"); + builder.childEnv.put(ENV_SPARK_HOME, sparkHome); + return this; + } + + /** + * Set a custom properties file with Spark configuration for the application. + * + * @param path Path to custom properties file to use. + * @return This launcher. + */ + public SparkLauncher setPropertiesFile(String path) { + checkNotNull(path, "path"); + builder.propertiesFile = path; + return this; + } + + /** + * Set a single configuration value for the application. + * + * @param key Configuration key. + * @param value The value to use. + * @return This launcher. + */ + public SparkLauncher setConf(String key, String value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + checkArgument(key.startsWith("spark."), "'key' must start with 'spark.'"); + builder.conf.put(key, value); + return this; + } + + /** + * Set the application name. + * + * @param appName Application name. + * @return This launcher. + */ + public SparkLauncher setAppName(String appName) { + checkNotNull(appName, "appName"); + builder.appName = appName; + return this; + } + + /** + * Set the Spark master for the application. + * + * @param master Spark master. + * @return This launcher. + */ + public SparkLauncher setMaster(String master) { + checkNotNull(master, "master"); + builder.master = master; + return this; + } + + /** + * Set the deploy mode for the application. + * + * @param mode Deploy mode. + * @return This launcher. + */ + public SparkLauncher setDeployMode(String mode) { + checkNotNull(mode, "mode"); + builder.deployMode = mode; + return this; + } + + /** + * Set the main application resource. This should be the location of a jar file for Scala/Java + * applications, or a python script for PySpark applications. + * + * @param resource Path to the main application resource. + * @return This launcher. + */ + public SparkLauncher setAppResource(String resource) { + checkNotNull(resource, "resource"); + builder.appResource = resource; + return this; + } + + /** + * Sets the application class name for Java/Scala applications. + * + * @param mainClass Application's main class. + * @return This launcher. + */ + public SparkLauncher setMainClass(String mainClass) { + checkNotNull(mainClass, "mainClass"); + builder.mainClass = mainClass; + return this; + } + + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *

    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *

    + * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)} - the last invocation will be the one to take effect. + *

    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + + /** + * Adds command line arguments for the application. + * + * @param args Arguments to pass to the application's main class. + * @return This launcher. + */ + public SparkLauncher addAppArgs(String... args) { + for (String arg : args) { + checkNotNull(arg, "arg"); + builder.appArgs.add(arg); + } + return this; + } + + /** + * Adds a jar file to be submitted with the application. + * + * @param jar Path to the jar file. + * @return This launcher. + */ + public SparkLauncher addJar(String jar) { + checkNotNull(jar, "jar"); + builder.jars.add(jar); + return this; + } + + /** + * Adds a file to be submitted with the application. + * + * @param file Path to the file. + * @return This launcher. + */ + public SparkLauncher addFile(String file) { + checkNotNull(file, "file"); + builder.files.add(file); + return this; + } + + /** + * Adds a python file / zip / egg to be submitted with the application. + * + * @param file Path to the file. + * @return This launcher. + */ + public SparkLauncher addPyFile(String file) { + checkNotNull(file, "file"); + builder.pyFiles.add(file); + return this; + } + + /** + * Enables verbose reporting for SparkSubmit. + * + * @param verbose Whether to enable verbose output. + * @return This launcher. + */ + public SparkLauncher setVerbose(boolean verbose) { + builder.verbose = verbose; + return this; + } + + /** + * Launches a sub-process that will start the configured Spark application. + * + * @return A process handle for the Spark app. + */ + public Process launch() throws IOException { + List cmd = new ArrayList(); + String script = isWindows() ? "spark-submit.cmd" : "spark-submit"; + cmd.add(join(File.separator, builder.getSparkHome(), "bin", script)); + cmd.addAll(builder.buildSparkSubmitArgs()); + + // Since the child process is a batch script, let's quote things so that special characters are + // preserved, otherwise the batch interpreter will mess up the arguments. Batch scripts are + // weird. + if (isWindows()) { + List winCmd = new ArrayList(); + for (String arg : cmd) { + winCmd.add(quoteForBatchScript(arg)); + } + cmd = winCmd; + } + + ProcessBuilder pb = new ProcessBuilder(cmd.toArray(new String[cmd.size()])); + for (Map.Entry e : builder.childEnv.entrySet()) { + pb.environment().put(e.getKey(), e.getValue()); + } + return pb.start(); + } + + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. + } + + }; + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java new file mode 100644 index 000000000000..fc87814a59ed --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Special command builder for handling a CLI invocation of SparkSubmit. + *

    + * This builder adds command line parsing compatible with SparkSubmit. It handles setting + * driver-side options and special parsing behavior needed for the special-casing certain internal + * Spark applications. + *

    + * This class has also some special features to aid launching pyspark. + */ +class SparkSubmitCommandBuilder extends AbstractCommandBuilder { + + /** + * Name of the app resource used to identify the PySpark shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "pyspark-shell" since that identifies the PySpark shell to SparkSubmit + * (see java_gateway.py), and can cause this code to enter into an infinite loop. + */ + static final String PYSPARK_SHELL = "pyspark-shell-main"; + + /** + * This is the actual resource name that identifies the PySpark shell to SparkSubmit. + */ + static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + + /** + * This map must match the class names for available special classes, since this modifies the way + * command line parsing works. This maps the class name to the resource to use when calling + * spark-submit. + */ + private static final Map specialClasses = new HashMap(); + static { + specialClasses.put("org.apache.spark.repl.Main", "spark-shell"); + specialClasses.put("org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver", + "spark-internal"); + specialClasses.put("org.apache.spark.sql.hive.thriftserver.HiveThriftServer2", + "spark-internal"); + } + + final List sparkArgs; + private final boolean printHelp; + + /** + * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed + * to parse the command lines for things like bin/spark-shell, which allows users to mix and + * match arguments (e.g. "bin/spark-shell SparkShellArg --master foo"). + */ + private boolean allowsMixedArguments; + + SparkSubmitCommandBuilder() { + this.sparkArgs = new ArrayList(); + this.printHelp = false; + } + + SparkSubmitCommandBuilder(List args) { + this.sparkArgs = new ArrayList(); + List submitArgs = args; + if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { + this.allowsMixedArguments = true; + appResource = PYSPARK_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); + } else { + this.allowsMixedArguments = false; + } + + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.printHelp = parser.helpRequested; + } + + @Override + public List buildCommand(Map env) throws IOException { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { + return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { + return buildSparkRCommand(env); + } else { + return buildSparkSubmitCommand(env); + } + } + + List buildSparkSubmitArgs() { + List args = new ArrayList(); + SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + + if (verbose) { + args.add(parser.VERBOSE); + } + + if (master != null) { + args.add(parser.MASTER); + args.add(master); + } + + if (deployMode != null) { + args.add(parser.DEPLOY_MODE); + args.add(deployMode); + } + + if (appName != null) { + args.add(parser.NAME); + args.add(appName); + } + + for (Map.Entry e : conf.entrySet()) { + args.add(parser.CONF); + args.add(String.format("%s=%s", e.getKey(), e.getValue())); + } + + if (propertiesFile != null) { + args.add(parser.PROPERTIES_FILE); + args.add(propertiesFile); + } + + if (!jars.isEmpty()) { + args.add(parser.JARS); + args.add(join(",", jars)); + } + + if (!files.isEmpty()) { + args.add(parser.FILES); + args.add(join(",", files)); + } + + if (!pyFiles.isEmpty()) { + args.add(parser.PY_FILES); + args.add(join(",", pyFiles)); + } + + if (mainClass != null) { + args.add(parser.CLASS); + args.add(mainClass); + } + + args.addAll(sparkArgs); + if (appResource != null) { + args.add(appResource); + } + args.addAll(appArgs); + + return args; + } + + private List buildSparkSubmitCommand(Map env) throws IOException { + // Load the properties file and check whether spark-submit will be running the app's driver + // or just launching a cluster app. When running the driver, the JVM's argument will be + // modified to cover the driver's configuration. + Properties props = loadPropertiesFile(); + boolean isClientMode = isClientMode(props); + String extraClassPath = isClientMode ? + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null; + + List cmd = buildJavaCommand(extraClassPath); + // Take Thrift Server as daemon + if (isThriftServer(mainClass)) { + addOptionString(cmd, System.getenv("SPARK_DAEMON_JAVA_OPTS")); + } + addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); + addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); + + if (isClientMode) { + // Figuring out where the memory value come from is a little tricky due to precedence. + // Precedence is observed in the following order: + // - explicit configuration (setConf()), which also covers --driver-memory cli argument. + // - properties file. + // - SPARK_DRIVER_MEMORY env variable + // - SPARK_MEM env variable + // - default value (1g) + // Take Thrift Server as daemon + String tsMemory = + isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; + String memory = firstNonEmpty(tsMemory, + firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), + System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); + cmd.add("-Xms" + memory); + cmd.add("-Xmx" + memory); + addOptionString(cmd, firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, conf, props)); + mergeEnvPathList(env, getLibPathEnvName(), + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + } + + addPermGenSizeOpt(cmd); + cmd.add("org.apache.spark.deploy.SparkSubmit"); + cmd.addAll(buildSparkSubmitArgs()); + return cmd; + } + + private List buildPySparkShellCommand(Map env) throws IOException { + // For backwards compatibility, if a script is specified in + // the pyspark command line, then run it using spark-submit. + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".py")) { + System.err.println( + "WARNING: Running python applications through 'pyspark' is deprecated as of Spark 1.0.\n" + + "Use ./bin/spark-submit "); + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + + checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); + + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. + constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. + List pyargs = new ArrayList(); + pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); + String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); + if (!isEmpty(pyOpts)) { + pyargs.addAll(parseOptionString(pyOpts)); + } + + return pyargs; + } + + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { + Properties props = loadPropertiesFile(); + mergeEnvPathList(env, getLibPathEnvName(), + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + + StringBuilder submitArgs = new StringBuilder(); + for (String arg : buildSparkSubmitArgs()) { + if (submitArgs.length() > 0) { + submitArgs.append(" "); + } + submitArgs.append(quoteForCommandString(arg)); + } + env.put(submitArgsEnvVariable, submitArgs.toString()); + } + + + private boolean isClientMode(Properties userProps) { + String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); + // Default master is "local[*]", so assume client mode in that case. + return userMaster == null || + "client".equals(deployMode) || + (!userMaster.equals("yarn-cluster") && deployMode == null); + } + + /** + * Return whether the given main class represents a thrift server. + */ + private boolean isThriftServer(String mainClass) { + return (mainClass != null && + mainClass.equals("org.apache.spark.sql.hive.thriftserver.HiveThriftServer2")); + } + + + private class OptionParser extends SparkSubmitOptionParser { + + boolean helpRequested = false; + + @Override + protected boolean handle(String opt, String value) { + if (opt.equals(MASTER)) { + master = value; + } else if (opt.equals(DEPLOY_MODE)) { + deployMode = value; + } else if (opt.equals(PROPERTIES_FILE)) { + propertiesFile = value; + } else if (opt.equals(DRIVER_MEMORY)) { + conf.put(SparkLauncher.DRIVER_MEMORY, value); + } else if (opt.equals(DRIVER_JAVA_OPTIONS)) { + conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, value); + } else if (opt.equals(DRIVER_LIBRARY_PATH)) { + conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, value); + } else if (opt.equals(DRIVER_CLASS_PATH)) { + conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, value); + } else if (opt.equals(CONF)) { + String[] setConf = value.split("=", 2); + checkArgument(setConf.length == 2, "Invalid argument to %s: %s", CONF, value); + conf.put(setConf[0], setConf[1]); + } else if (opt.equals(CLASS)) { + // The special classes require some special command line handling, since they allow + // mixing spark-submit arguments with arguments that should be propagated to the shell + // itself. Note that for this to work, the "--class" argument must come before any + // non-spark-submit arguments. + mainClass = value; + if (specialClasses.containsKey(value)) { + allowsMixedArguments = true; + appResource = specialClasses.get(value); + } + } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { + helpRequested = true; + sparkArgs.add(opt); + } else { + sparkArgs.add(opt); + if (value != null) { + sparkArgs.add(value); + } + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // When mixing arguments, add unrecognized parameters directly to the user arguments list. In + // normal mode, any unrecognized parameter triggers the end of command line parsing, and the + // parameter itself will be interpreted by SparkSubmit as the application resource. The + // remaining params will be appended to the list of SparkSubmit arguments. + if (allowsMixedArguments) { + appArgs.add(opt); + return true; + } else { + checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); + sparkArgs.add(opt); + return false; + } + } + + @Override + protected void handleExtraArgs(List extra) { + for (String arg : extra) { + sparkArgs.add(arg); + } + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java new file mode 100644 index 000000000000..6767cc507964 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Parser for spark-submit command line options. + *

    + * This class encapsulates the parsing code for spark-submit command line options, so that there + * is a single list of options that needs to be maintained (well, sort of, but it makes it harder + * to break things). + */ +class SparkSubmitOptionParser { + + // The following constants define the "main" name for the available options. They're defined + // to avoid copy & paste of the raw strings where they're needed. + // + // The fields are not static so that they're exposed to Scala code that uses this class. See + // SparkSubmitArguments.scala. That is also why this class is not abstract - to allow code to + // easily use these constants without having to create dummy implementations of this class. + protected final String CLASS = "--class"; + protected final String CONF = "--conf"; + protected final String DEPLOY_MODE = "--deploy-mode"; + protected final String DRIVER_CLASS_PATH = "--driver-class-path"; + protected final String DRIVER_CORES = "--driver-cores"; + protected final String DRIVER_JAVA_OPTIONS = "--driver-java-options"; + protected final String DRIVER_LIBRARY_PATH = "--driver-library-path"; + protected final String DRIVER_MEMORY = "--driver-memory"; + protected final String EXECUTOR_MEMORY = "--executor-memory"; + protected final String FILES = "--files"; + protected final String JARS = "--jars"; + protected final String KILL_SUBMISSION = "--kill"; + protected final String MASTER = "--master"; + protected final String NAME = "--name"; + protected final String PACKAGES = "--packages"; + protected final String PACKAGES_EXCLUDE = "--exclude-packages"; + protected final String PROPERTIES_FILE = "--properties-file"; + protected final String PROXY_USER = "--proxy-user"; + protected final String PY_FILES = "--py-files"; + protected final String REPOSITORIES = "--repositories"; + protected final String STATUS = "--status"; + protected final String TOTAL_EXECUTOR_CORES = "--total-executor-cores"; + + // Options that do not take arguments. + protected final String HELP = "--help"; + protected final String SUPERVISE = "--supervise"; + protected final String USAGE_ERROR = "--usage-error"; + protected final String VERBOSE = "--verbose"; + protected final String VERSION = "--version"; + + // Standalone-only options. + + // YARN-only options. + protected final String ARCHIVES = "--archives"; + protected final String EXECUTOR_CORES = "--executor-cores"; + protected final String KEYTAB = "--keytab"; + protected final String NUM_EXECUTORS = "--num-executors"; + protected final String PRINCIPAL = "--principal"; + protected final String QUEUE = "--queue"; + + /** + * This is the canonical list of spark-submit options. Each entry in the array contains the + * different aliases for the same option; the first element of each entry is the "official" + * name of the option, passed to {@link #handle(String, String)}. + *

    + * Options not listed here nor in the "switch" list below will result in a call to + * {@link $#handleUnknown(String)}. + *

    + * These two arrays are visible for tests. + */ + final String[][] opts = { + { ARCHIVES }, + { CLASS }, + { CONF, "-c" }, + { DEPLOY_MODE }, + { DRIVER_CLASS_PATH }, + { DRIVER_CORES }, + { DRIVER_JAVA_OPTIONS }, + { DRIVER_LIBRARY_PATH }, + { DRIVER_MEMORY }, + { EXECUTOR_CORES }, + { EXECUTOR_MEMORY }, + { FILES }, + { JARS }, + { KEYTAB }, + { KILL_SUBMISSION }, + { MASTER }, + { NAME }, + { NUM_EXECUTORS }, + { PACKAGES }, + { PACKAGES_EXCLUDE }, + { PRINCIPAL }, + { PROPERTIES_FILE }, + { PROXY_USER }, + { PY_FILES }, + { QUEUE }, + { REPOSITORIES }, + { STATUS }, + { TOTAL_EXECUTOR_CORES }, + }; + + /** + * List of switches (command line options that do not take parameters) recognized by spark-submit. + */ + final String[][] switches = { + { HELP, "-h" }, + { SUPERVISE }, + { USAGE_ERROR }, + { VERBOSE, "-v" }, + { VERSION }, + }; + + /** + * Parse a list of spark-submit command line options. + *

    + * See SparkSubmitArguments.scala for a more formal description of available options. + * + * @throws IllegalArgumentException If an error is found during parsing. + */ + protected final void parse(List args) { + Pattern eqSeparatedOpt = Pattern.compile("(--[^=]+)=(.+)"); + + int idx = 0; + for (idx = 0; idx < args.size(); idx++) { + String arg = args.get(idx); + String value = null; + + Matcher m = eqSeparatedOpt.matcher(arg); + if (m.matches()) { + arg = m.group(1); + value = m.group(2); + } + + // Look for options with a value. + String name = findCliOption(arg, opts); + if (name != null) { + if (value == null) { + if (idx == args.size() - 1) { + throw new IllegalArgumentException( + String.format("Missing argument for option '%s'.", arg)); + } + idx++; + value = args.get(idx); + } + if (!handle(name, value)) { + break; + } + continue; + } + + // Look for a switch. + name = findCliOption(arg, switches); + if (name != null) { + if (!handle(name, null)) { + break; + } + continue; + } + + if (!handleUnknown(arg)) { + break; + } + } + + if (idx < args.size()) { + idx++; + } + handleExtraArgs(args.subList(idx, args.size())); + } + + /** + * Callback for when an option with an argument is parsed. + * + * @param opt The long name of the cli option (might differ from actual command line). + * @param value The value. This will be null if the option does not take a value. + * @return Whether to continue parsing the argument list. + */ + protected boolean handle(String opt, String value) { + throw new UnsupportedOperationException(); + } + + /** + * Callback for when an unrecognized option is parsed. + * + * @param opt Unrecognized option from the command line. + * @return Whether to continue parsing the argument list. + */ + protected boolean handleUnknown(String opt) { + throw new UnsupportedOperationException(); + } + + /** + * Callback for remaining command line arguments after either {@link #handle(String, String)} or + * {@link #handleUnknown(String)} return "false". This will be called at the end of parsing even + * when there are no remaining arguments. + * + * @param extra List of remaining arguments. + */ + protected void handleExtraArgs(List extra) { + throw new UnsupportedOperationException(); + } + + private String findCliOption(String name, String[][] available) { + for (String[] candidates : available) { + for (String candidate : candidates) { + if (candidate.equals(name)) { + return candidates[0]; + } + } + } + return null; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java new file mode 100644 index 000000000000..7c97dba511b2 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Library for launching Spark applications. + * + *

    + * This library allows applications to launch Spark programmatically. There's only one entry + * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. + *

    + * + *

    + * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} + * and configure the application to run. For example: + *

    + * + *
    + * {@code
    + *   import org.apache.spark.launcher.SparkLauncher;
    + *
    + *   public class MyLauncher {
    + *     public static void main(String[] args) throws Exception {
    + *       Process spark = new SparkLauncher()
    + *         .setAppResource("/my/app.jar")
    + *         .setMainClass("my.spark.app.Main")
    + *         .setMaster("local")
    + *         .setConf(SparkLauncher.DRIVER_MEMORY, "2g")
    + *         .launch();
    + *       spark.waitFor();
    + *     }
    + *   }
    + * }
    + * 
    + */ +package org.apache.spark.launcher; diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java new file mode 100644 index 000000000000..bc513ec9b3d1 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; +import static org.junit.Assert.*; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +public class CommandBuilderUtilsSuite { + + @Test + public void testValidOptionStrings() { + testOpt("a b c d e", Arrays.asList("a", "b", "c", "d", "e")); + testOpt("a 'b c' \"d\" e", Arrays.asList("a", "b c", "d", "e")); + testOpt("a 'b\\\"c' \"'d'\" e", Arrays.asList("a", "b\\\"c", "'d'", "e")); + testOpt("a 'b\"c' \"\\\"d\\\"\" e", Arrays.asList("a", "b\"c", "\"d\"", "e")); + testOpt(" a b c \\\\ ", Arrays.asList("a", "b", "c", "\\")); + + // Following tests ported from UtilsSuite.scala. + testOpt("", new ArrayList()); + testOpt("a", Arrays.asList("a")); + testOpt("aaa", Arrays.asList("aaa")); + testOpt("a b c", Arrays.asList("a", "b", "c")); + testOpt(" a b\t c ", Arrays.asList("a", "b", "c")); + testOpt("a 'b c'", Arrays.asList("a", "b c")); + testOpt("a 'b c' d", Arrays.asList("a", "b c", "d")); + testOpt("'b c'", Arrays.asList("b c")); + testOpt("a \"b c\"", Arrays.asList("a", "b c")); + testOpt("a \"b c\" d", Arrays.asList("a", "b c", "d")); + testOpt("\"b c\"", Arrays.asList("b c")); + testOpt("a 'b\" c' \"d' e\"", Arrays.asList("a", "b\" c", "d' e")); + testOpt("a\t'b\nc'\nd", Arrays.asList("a", "b\nc", "d")); + testOpt("a \"b\\\\c\"", Arrays.asList("a", "b\\c")); + testOpt("a \"b\\\"c\"", Arrays.asList("a", "b\"c")); + testOpt("a 'b\\\"c'", Arrays.asList("a", "b\\\"c")); + testOpt("'a'b", Arrays.asList("ab")); + testOpt("'a''b'", Arrays.asList("ab")); + testOpt("\"a\"b", Arrays.asList("ab")); + testOpt("\"a\"\"b\"", Arrays.asList("ab")); + testOpt("''", Arrays.asList("")); + testOpt("\"\"", Arrays.asList("")); + } + + @Test + public void testInvalidOptionStrings() { + testInvalidOpt("\\"); + testInvalidOpt("\"abcde"); + testInvalidOpt("'abcde"); + } + + @Test + public void testWindowsBatchQuoting() { + assertEquals("abc", quoteForBatchScript("abc")); + assertEquals("\"a b c\"", quoteForBatchScript("a b c")); + assertEquals("\"a \"\"b\"\" c\"", quoteForBatchScript("a \"b\" c")); + assertEquals("\"a\"\"b\"\"c\"", quoteForBatchScript("a\"b\"c")); + assertEquals("\"ab=\"\"cd\"\"\"", quoteForBatchScript("ab=\"cd\"")); + assertEquals("\"a,b,c\"", quoteForBatchScript("a,b,c")); + assertEquals("\"a;b;c\"", quoteForBatchScript("a;b;c")); + assertEquals("\"a,b,c\\\\\"", quoteForBatchScript("a,b,c\\")); + } + + @Test + public void testPythonArgQuoting() { + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); + } + + private void testOpt(String opts, List expected) { + assertEquals(String.format("test string failed to parse: [[ %s ]]", opts), + expected, parseOptionString(opts)); + } + + private void testInvalidOpt(String opts) { + try { + parseOptionString(opts); + fail("Expected exception for invalid option string."); + } catch (IllegalArgumentException e) { + // pass. + } + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java new file mode 100644 index 000000000000..d0c26dd05679 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +/** + * These tests require the Spark assembly to be built before they can be run. + */ +public class SparkLauncherSuite { + + private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } + + @Test + public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + Map env = new HashMap(); + env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); + + SparkLauncher launcher = new SparkLauncher(env) + .setSparkHome(System.getProperty("spark.test.home")) + .setMaster("local") + .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) + .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, + "-Dfoo=bar -Dtest.name=-testChildProcLauncher") + .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") + .setMainClass(SparkLauncherTestApp.class.getName()) + .addAppArgs("proc"); + final Process app = launcher.launch(); + new Redirector("stdout", app.getInputStream()).start(); + new Redirector("stderr", app.getErrorStream()).start(); + assertEquals(0, app.waitFor()); + } + + public static class SparkLauncherTestApp { + + public static void main(String[] args) throws Exception { + assertEquals(1, args.length); + assertEquals("proc", args[0]); + assertEquals("bar", System.getProperty("foo")); + assertEquals("local", System.getProperty(SparkLauncher.SPARK_MASTER)); + } + + } + + private static class Redirector extends Thread { + + private final InputStream in; + + Redirector(String name, InputStream in) { + this.in = in; + setName(name); + setDaemon(true); + } + + @Override + public void run() { + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + String line; + while ((line = reader.readLine()) != null) { + LOG.warn(line); + } + } catch (Exception e) { + LOG.error("Error reading process output.", e); + } + } + + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java new file mode 100644 index 000000000000..7329ac9f7fb8 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +public class SparkSubmitCommandBuilderSuite { + + private static File dummyPropsFile; + private static SparkSubmitOptionParser parser; + + @BeforeClass + public static void setUp() throws Exception { + dummyPropsFile = File.createTempFile("spark", "properties"); + parser = new SparkSubmitOptionParser(); + } + + @AfterClass + public static void cleanUp() throws Exception { + dummyPropsFile.delete(); + } + + @Test + public void testDriverCmdBuilder() throws Exception { + testCmdBuilder(true); + } + + @Test + public void testClusterCmdBuilder() throws Exception { + testCmdBuilder(false); + } + + @Test + public void testCliParser() throws Exception { + List sparkSubmitArgs = Arrays.asList( + parser.MASTER, + "local", + parser.DRIVER_MEMORY, + "42g", + parser.DRIVER_CLASS_PATH, + "/driverCp", + parser.DRIVER_JAVA_OPTIONS, + "extraJavaOpt", + parser.CONF, + "spark.randomOption=foo", + parser.CONF, + SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH + "=/driverLibPath"); + Map env = new HashMap(); + List cmd = buildCommand(sparkSubmitArgs, env); + + assertTrue(findInStringList(env.get(CommandBuilderUtils.getLibPathEnvName()), + File.pathSeparator, "/driverLibPath")); + assertTrue(findInStringList(findArgValue(cmd, "-cp"), File.pathSeparator, "/driverCp")); + assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms42g")); + assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx42g")); + assertTrue("Command should contain user-defined conf.", + Collections.indexOfSubList(cmd, Arrays.asList(parser.CONF, "spark.randomOption=foo")) > 0); + } + + @Test + public void testShellCliParser() throws Exception { + List sparkSubmitArgs = Arrays.asList( + parser.CLASS, + "org.apache.spark.repl.Main", + parser.MASTER, + "foo", + "--app-arg", + "bar", + "--app-switch", + parser.FILES, + "baz", + parser.NAME, + "appName"); + + List args = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List expected = Arrays.asList("spark-shell", "--app-arg", "bar", "--app-switch"); + assertEquals(expected, args.subList(args.size() - expected.size(), args.size())); + } + + @Test + public void testAlternateSyntaxParsing() throws Exception { + List sparkSubmitArgs = Arrays.asList( + parser.CLASS + "=org.my.Class", + parser.MASTER + "=foo", + parser.DEPLOY_MODE + "=bar"); + + List cmd = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + assertEquals("org.my.Class", findArgValue(cmd, parser.CLASS)); + assertEquals("foo", findArgValue(cmd, parser.MASTER)); + assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); + } + + @Test + public void testPySparkLauncher() throws Exception { + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.PYSPARK_SHELL, + "--master=foo", + "--deploy-mode=bar"); + + Map env = new HashMap(); + List cmd = buildCommand(sparkSubmitArgs, env); + assertEquals("python", cmd.get(cmd.size() - 1)); + assertEquals( + String.format("\"%s\" \"foo\" \"%s\" \"bar\" \"%s\"", + parser.MASTER, parser.DEPLOY_MODE, SparkSubmitCommandBuilder.PYSPARK_SHELL_RESOURCE), + env.get("PYSPARK_SUBMIT_ARGS")); + } + + @Test + public void testPySparkFallback() throws Exception { + List sparkSubmitArgs = Arrays.asList( + "--master=foo", + "--deploy-mode=bar", + "script.py", + "arg1"); + + Map env = new HashMap(); + List cmd = buildCommand(sparkSubmitArgs, env); + + assertEquals("foo", findArgValue(cmd, "--master")); + assertEquals("bar", findArgValue(cmd, "--deploy-mode")); + assertEquals("script.py", cmd.get(cmd.size() - 2)); + assertEquals("arg1", cmd.get(cmd.size() - 1)); + } + + private void testCmdBuilder(boolean isDriver) throws Exception { + String deployMode = isDriver ? "client" : "cluster"; + + SparkSubmitCommandBuilder launcher = + newCommandBuilder(Collections.emptyList()); + launcher.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, + System.getProperty("spark.test.home")); + launcher.master = "yarn"; + launcher.deployMode = deployMode; + launcher.appResource = "/foo"; + launcher.appName = "MyApp"; + launcher.mainClass = "my.Class"; + launcher.propertiesFile = dummyPropsFile.getAbsolutePath(); + launcher.appArgs.add("foo"); + launcher.appArgs.add("bar"); + launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver -XX:MaxPermSize=256m"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); + launcher.conf.put("spark.foo", "foo"); + + Map env = new HashMap(); + List cmd = launcher.buildCommand(env); + + // Checks below are different for driver and non-driver mode. + + if (isDriver) { + assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms1g")); + assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx1g")); + } else { + boolean found = false; + for (String arg : cmd) { + if (arg.startsWith("-Xms") || arg.startsWith("-Xmx")) { + found = true; + break; + } + } + assertFalse("Memory arguments should not be set.", found); + } + + for (String arg : cmd) { + if (arg.startsWith("-XX:MaxPermSize=")) { + if (isDriver) { + assertEquals("-XX:MaxPermSize=256m", arg); + } else { + assertEquals("-XX:MaxPermSize=256m", arg); + } + } + } + + String[] cp = findArgValue(cmd, "-cp").split(Pattern.quote(File.pathSeparator)); + if (isDriver) { + assertTrue("Driver classpath should contain provided entry.", contains("/driver", cp)); + } else { + assertFalse("Driver classpath should not be in command.", contains("/driver", cp)); + } + + String libPath = env.get(CommandBuilderUtils.getLibPathEnvName()); + if (isDriver) { + assertNotNull("Native library path should be set.", libPath); + assertTrue("Native library path should contain provided entry.", + contains("/native", libPath.split(Pattern.quote(File.pathSeparator)))); + } else { + assertNull("Native library should not be set.", libPath); + } + + // Checks below are the same for both driver and non-driver mode. + assertEquals(dummyPropsFile.getAbsolutePath(), findArgValue(cmd, parser.PROPERTIES_FILE)); + assertEquals("yarn", findArgValue(cmd, parser.MASTER)); + assertEquals(deployMode, findArgValue(cmd, parser.DEPLOY_MODE)); + assertEquals("my.Class", findArgValue(cmd, parser.CLASS)); + assertEquals("MyApp", findArgValue(cmd, parser.NAME)); + + boolean appArgsOk = false; + for (int i = 0; i < cmd.size(); i++) { + if (cmd.get(i).equals("/foo")) { + assertEquals("foo", cmd.get(i + 1)); + assertEquals("bar", cmd.get(i + 2)); + assertEquals(cmd.size(), i + 3); + appArgsOk = true; + break; + } + } + assertTrue("App resource and args should be added to command.", appArgsOk); + + Map conf = parseConf(cmd, parser); + assertEquals("foo", conf.get("spark.foo")); + } + + private boolean contains(String needle, String[] haystack) { + for (String entry : haystack) { + if (entry.equals(needle)) { + return true; + } + } + return false; + } + + private Map parseConf(List cmd, SparkSubmitOptionParser parser) { + Map conf = new HashMap(); + for (int i = 0; i < cmd.size(); i++) { + if (cmd.get(i).equals(parser.CONF)) { + String[] val = cmd.get(i + 1).split("=", 2); + conf.put(val[0], val[1]); + i += 1; + } + } + return conf; + } + + private String findArgValue(List cmd, String name) { + for (int i = 0; i < cmd.size(); i++) { + if (cmd.get(i).equals(name)) { + return cmd.get(i + 1); + } + } + fail(String.format("arg '%s' not found", name)); + return null; + } + + private boolean findInStringList(String list, String sep, String needle) { + return contains(needle, list.split(sep)); + } + + private SparkSubmitCommandBuilder newCommandBuilder(List args) { + SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); + return builder; + } + + private List buildCommand(List args, Map env) throws Exception { + return newCommandBuilder(args).buildCommand(env); + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java new file mode 100644 index 000000000000..f3d210991705 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import static org.apache.spark.launcher.SparkSubmitOptionParser.*; + +public class SparkSubmitOptionParserSuite { + + private SparkSubmitOptionParser parser; + + @Before + public void setUp() { + parser = spy(new DummyParser()); + } + + @Test + public void testAllOptions() { + int count = 0; + for (String[] optNames : parser.opts) { + for (String optName : optNames) { + String value = optName + "-value"; + parser.parse(Arrays.asList(optName, value)); + count++; + verify(parser).handle(eq(optNames[0]), eq(value)); + verify(parser, times(count)).handle(anyString(), anyString()); + verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); + } + } + + for (String[] switchNames : parser.switches) { + int switchCount = 0; + for (String name : switchNames) { + parser.parse(Arrays.asList(name)); + count++; + switchCount++; + verify(parser, times(switchCount)).handle(eq(switchNames[0]), same((String) null)); + verify(parser, times(count)).handle(anyString(), any(String.class)); + verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); + } + } + } + + @Test + public void testExtraOptions() { + List args = Arrays.asList(parser.MASTER, parser.MASTER, "foo", "bar"); + parser.parse(args); + verify(parser).handle(eq(parser.MASTER), eq(parser.MASTER)); + verify(parser).handleUnknown(eq("foo")); + verify(parser).handleExtraArgs(eq(Arrays.asList("bar"))); + } + + @Test(expected=IllegalArgumentException.class) + public void testMissingArg() { + parser.parse(Arrays.asList(parser.MASTER)); + } + + @Test + public void testEqualSeparatedOption() { + List args = Arrays.asList(parser.MASTER + "=" + parser.MASTER); + parser.parse(args); + verify(parser).handle(eq(parser.MASTER), eq(parser.MASTER)); + verify(parser).handleExtraArgs(eq(Collections.emptyList())); + } + + private static class DummyParser extends SparkSubmitOptionParser { + + @Override + protected boolean handle(String opt, String value) { + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + + } + + } + +} diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties new file mode 100644 index 000000000000..67a6a9821711 --- /dev/null +++ b/launcher/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false + +# Some tests will set "test.name" to avoid overwriting the main log file. +log4j.appender.file.file=target/unit-tests${test.name}.log + +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/make-distribution.sh b/make-distribution.sh index 051c87c0894a..04ad0052eb24 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -26,13 +26,14 @@ set -o pipefail set -e +set -x # Figure out where the Spark framework is installed SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.5.0" +TACHYON_VERSION="0.7.1" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" @@ -57,7 +58,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-0.23, hdaoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) @@ -98,9 +99,9 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found if [ $(command -v rpm) ]; then - RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" if [ "$RPM_JAVA_HOME" != "%java_home" ]; then - JAVA_HOME=$RPM_JAVA_HOME + JAVA_HOME="$RPM_JAVA_HOME" echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" fi fi @@ -113,46 +114,33 @@ fi if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) - if [ ! -z $GITREV ]; then + if [ ! -z "$GITREV" ]; then GITREVSTRING=" (git revision $GITREV)" fi unset GITREV fi -if [ ! $(command -v $MVN) ] ; then +if [ ! $(command -v "$MVN") ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$($MVN help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$($MVN help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) +SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ +SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | tail -n 1) +SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ # because we use "set -o pipefail" echo -n) -JAVA_CMD="$JAVA_HOME"/bin/java -JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) -if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then - echo "***NOTE***: JAVA_HOME is not set to a JDK 6 installation. The resulting" - echo " distribution may not work well with PySpark and will not run" - echo " with Java 6 (See SPARK-1703 and SPARK-1911)." - echo " This test can be disabled by adding --skip-java-test." - echo "Output from 'java -version' was:" - echo "$JAVA_VERSION" - read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! $REPLY =~ ^[Yy]$ ]]; then - echo "Okay, exiting." - exit 1 - fi -fi - if [ "$NAME" == "none" ]; then NAME=$SPARK_HADOOP_VERSION fi @@ -227,12 +215,17 @@ cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" cp -r "$SPARK_HOME/ec2" "$DISTDIR" +# Copy SparkR if it exists +if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then + mkdir -p "$DISTDIR"/R/lib + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib +fi # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - pushd $TMPD > /dev/null + pushd "$TMPD" > /dev/null echo "Fetching tachyon tgz" TACHYON_DL="${TACHYON_TGZ}.part" @@ -259,7 +252,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi popd > /dev/null - rm -rf $TMPD + rm -rf "$TMPD" fi if [ "$MAKE_TGZ" == "true" ]; then diff --git a/mllib/pom.xml b/mllib/pom.xml index a8cee3d51a78..a5db14407b4f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version}
    + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -59,11 +66,12 @@ org.jblas jblas ${jblas.version} + test org.scalanlp breeze_${scala.binary.version} - 0.10 + 0.11.2 @@ -98,7 +106,7 @@ org.mockito - mockito-all + mockito-core test @@ -108,6 +116,21 @@ test-jar test + + org.jpmml + pmml-model + 1.1.15 + + + com.sun.xml.fastinfoset + FastInfoset + + + com.sun.istack + istack-commons-runtime + + + @@ -116,7 +139,7 @@ com.github.fommil.netlib all - 1.1.2 + ${netlib.java.version} pom @@ -125,16 +148,5 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - ../python - - pyspark/mllib/*.py - pyspark/mllib/stat/*.py - pyspark/ml/*.py - pyspark/ml/param/*.py - - - diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index bc3defe968af..57e416591de6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,27 +19,31 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for estimators that fit models to data. */ -@AlphaComponent -abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { +@DeveloperApi +abstract class Estimator[M <: Model[M]] extends PipelineStage { /** * Fits a single model to the input data with optional parameters. * * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) + * @param firstParamPair the first param pair, overrides embedded params + * @param otherParamPairs other param pairs. These values override any specified in this + * Estimator's embedded ParamMap. * @return fitted model */ @varargs - def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + val map = new ParamMap() + .put(firstParamPair) + .put(otherParamPairs: _*) fit(dataset, map) } @@ -47,21 +51,32 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with provided parameter map. * * @param dataset input dataset - * @param paramMap parameter map + * @param paramMap Parameter map. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M = { + copy(paramMap).fit(dataset) + } + + /** + * Fits a model to the input data. + */ + def fit(dataset: DataFrame): M /** * Fits multiple models to the input data with multiple sets of parameters. * The default implementation uses a for loop on each parameter map. - * Subclasses could overwrite this to optimize multi-model training. + * Subclasses could override this to optimize multi-model training. * * @param dataset input dataset - * @param paramMaps an array of parameter maps + * @param paramMaps An array of parameter maps. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } + + override def copy(extra: ParamMap): Estimator[M] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala deleted file mode 100644 index d2ca2e6871e6..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml - -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.DataFrame - -/** - * :: AlphaComponent :: - * Abstract class for evaluators that compute metrics from predictions. - */ -@AlphaComponent -abstract class Evaluator extends Identifiable { - - /** - * Evaluates the output. - * - * @param dataset a dataset that contains labels/observations and predictions. - * @param paramMap parameter map that specifies the input columns and output metrics - * @return metric - */ - def evaluate(dataset: DataFrame, paramMap: ParamMap): Double -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala deleted file mode 100644 index cd84b05bfb49..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml - -import java.util.UUID - -/** - * Object with a unique id. - */ -private[ml] trait Identifiable extends Serializable { - - /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 - * random hex chars. - */ - private[ml] val uid: String = - this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index cae5082b5119..252acc156583 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -17,24 +17,33 @@ package org.apache.spark.ml -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. * * @tparam M model type */ -@AlphaComponent +@DeveloperApi abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. + * Note: For ensembles' component Models, this value can be null. */ - val parent: Estimator[M] + @transient var parent: Estimator[M] = _ /** - * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. + * Sets the parent of this model (Java API). */ - val fittingParamMap: ParamMap + def setParent(parent: Estimator[M]): M = { + this.parent = parent + this.asInstanceOf[M] + } + + /** Indicates whether this [[Model]] has a corresponding parent. */ + def hasParent: Boolean = parent != null + + override def copy(extra: ParamMap): M } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index bb291e6e1fd7..a3e59401c5cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -17,46 +17,60 @@ package org.apache.spark.ml +import java.{util => ju} + +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. */ -@AlphaComponent -abstract class PipelineStage extends Serializable with Logging { +@DeveloperApi +abstract class PipelineStage extends Params with Logging { /** - * Derives the output schema from the input schema and parameters. + * :: DeveloperApi :: + * + * Derives the output schema from the input schema. */ - private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + @DeveloperApi + def transformSchema(schema: StructType): StructType /** + * :: DeveloperApi :: + * * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. */ + @DeveloperApi protected def transformSchema( schema: StructType, - paramMap: ParamMap, logging: Boolean): StructType = { if (logging) { logDebug(s"Input schema: ${schema.json}") } - val outputSchema = transformSchema(schema, paramMap) + val outputSchema = transformSchema(schema) if (logging) { logDebug(s"Expected output schema: ${outputSchema.json}") } outputSchema } + + override def copy(extra: ParamMap): PipelineStage } /** - * :: AlphaComponent :: + * :: Experimental :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will @@ -67,13 +81,29 @@ abstract class PipelineStage extends Serializable with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ -@AlphaComponent -class Pipeline extends Estimator[PipelineModel] { +@Experimental +class Pipeline(override val uid: String) extends Estimator[PipelineModel] { + + def this() = this(Identifiable.randomUID("pipeline")) - /** param for pipeline stages */ + /** + * param for pipeline stages + * @group param + */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") + + /** @group setParam */ def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + + // Below, we clone stages so that modifications to the list of stages will not change + // the Param value in the Pipeline. + /** @group getParam */ + def getStages: Array[PipelineStage] = $(stages).clone() + + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) + } /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an @@ -85,13 +115,11 @@ class Pipeline extends Estimator[PipelineModel] { * pipeline stages. If there are no stages, the output model acts as an identity transformer. * * @param dataset input dataset - * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val theStages = map(stages) + override def fit(dataset: DataFrame): PipelineModel = { + transformSchema(dataset.schema, logging = true) + val theStages = $(stages) // Search for the last estimator. var indexOfLastEstimator = -1 theStages.view.zipWithIndex.foreach { case (stage, index) => @@ -107,71 +135,69 @@ class Pipeline extends Estimator[PipelineModel] { if (index <= indexOfLastEstimator) { val transformer = stage match { case estimator: Estimator[_] => - estimator.fit(curDataset, paramMap) + estimator.fit(curDataset) case t: Transformer => t case _ => throw new IllegalArgumentException( s"Do not support stage $stage of type ${stage.getClass}") } - curDataset = transformer.transform(curDataset, paramMap) + if (index < indexOfLastEstimator) { + curDataset = transformer.transform(curDataset) + } transformers += transformer } else { transformers += stage.asInstanceOf[Transformer] } } - new PipelineModel(this, map, transformers.toArray) + new PipelineModel(uid, transformers.toArray).setParent(this) + } + + override def copy(extra: ParamMap): Pipeline = { + val map = extractParamMap(extra) + val newStages = map(stages).map(_.copy(extra)) + new Pipeline().setStages(newStages) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - val theStages = map(stages) - require(theStages.toSet.size == theStages.size, + override def transformSchema(schema: StructType): StructType = { + val theStages = $(stages) + require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") - theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) + theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } } /** - * :: AlphaComponent :: - * Represents a compiled pipeline. + * :: Experimental :: + * Represents a fitted pipeline. */ -@AlphaComponent +@Experimental class PipelineModel private[ml] ( - override val parent: Pipeline, - override val fittingParamMap: ParamMap, - private[ml] val stages: Array[Transformer]) + override val uid: String, + val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { - /** - * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input - * estimator does not exist in the pipeline. - */ - def getModel[M <: Model[M]](stage: Estimator[M]): M = { - val matched = stages.filter { - case m: Model[_] => m.parent.eq(stage) - case _ => false - } - if (matched.isEmpty) { - throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.") - } else if (matched.size > 1) { - throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.") - } else { - matched.head.asInstanceOf[M] - } + /** A Java/Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, stages: ju.List[Transformer]) = { + this(uid, stages.asScala.toArray) + } + + override def validateParams(): Unit = { + super.validateParams() + stages.foreach(_.validateParams()) + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) } - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap - transformSchema(dataset.schema, map, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) + override def transformSchema(schema: StructType): StructType = { + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap - stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) + override def copy(extra: ParamMap): PipelineModel = { + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala new file mode 100644 index 000000000000..19fe039b8fd0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +/** + * (private[ml]) Trait for parameters for prediction (regression and classification). + */ +private[ml] trait PredictorParams extends Params + with HasLabelCol with HasFeaturesCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector + SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) + if (fitting) { + // TODO: Allow other numeric types + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: DeveloperApi :: + * Abstraction for prediction problems (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam Learner Specialization of this class. If you subclass this type, use this type + * parameter to specify the concrete type. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + */ +@DeveloperApi +abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] with PredictorParams { + + /** @group setParam */ + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + override def fit(dataset: DataFrame): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). + transformSchema(dataset.schema, logging = true) + copyValues(train(dataset).setParent(this)) + } + + override def copy(extra: ParamMap): Learner + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + protected def train(dataset: DataFrame): M + + /** + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + private[ml] def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true, featuresDataType) + } + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + */ + protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { + dataset.select($(labelCol), $(featuresCol)) + .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } + } +} + +/** + * :: DeveloperApi :: + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + */ +@DeveloperApi +abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] with PredictorParams { + + /** @group setParam */ + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + /** @group setParam */ + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false, featuresDataType) + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * the predictions as a new column [[predictionCol]]. + * + * @param dataset input dataset + * @return transformed dataset with [[predictionCol]] of type [[Double]] + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + if ($(predictionCol).nonEmpty) { + transformImpl(dataset) + } else { + this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + + " since no output columns were set.") + dataset + } + } + + protected def transformImpl(dataset: DataFrame): DataFrame = { + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + protected def predict(features: FeaturesType): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index cd95c16aa768..3c7bcf7590e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -20,29 +20,35 @@ package org.apache.spark.ml import scala.annotation.varargs import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for transformers that transform one dataset into another. */ -@AlphaComponent -abstract class Transformer extends PipelineStage with Params { +@DeveloperApi +abstract class Transformer extends PipelineStage { /** * Transforms the dataset with optional parameters * @param dataset input dataset - * @param paramPairs optional list of param pairs, overwrite embedded params + * @param firstParamPair the first param pair, overwrite embedded params + * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ @varargs - def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { + def transform( + dataset: DataFrame, + firstParamPair: ParamPair[_], + otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() - paramPairs.foreach(map.put(_)) + .put(firstParamPair) + .put(otherParamPairs: _*) transform(dataset, map) } @@ -52,17 +58,31 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + this.copy(paramMap).transform(dataset) + } + + /** + * Transforms the input dataset. + */ + def transform(dataset: DataFrame): DataFrame + + override def copy(extra: ParamMap): Transformer } /** + * :: DeveloperApi :: * Abstract class for transformers that take one input column, apply transformation, and output the * result as a new column. */ -private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] +@DeveloperApi +abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { + /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + + /** @group setParam */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** @@ -70,7 +90,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O * account of the embedded param map. So the param values should be determined solely by the input * param map. */ - protected def createTransformFunc(paramMap: ParamMap): IN => OUT + protected def createTransformFunc: IN => OUT /** * Returns the data type of the output column. @@ -82,22 +102,22 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O */ protected def validateInputType(inputType: DataType): Unit = {} - override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - val inputType = schema(map(inputCol)).dataType + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType validateInputType(inputType) - if (schema.fieldNames.contains(map(outputCol))) { - throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") } val outputFields = schema.fields :+ - StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive) + StructField($(outputCol), outputDataType, nullable = false) StructType(outputFields) } - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - dataset.select($"*", callUDF( - this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol))) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + dataset.withColumn($(outputCol), + callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) } + + override def copy(extra: ParamMap): T = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala new file mode 100644 index 000000000000..7429f9d652ac --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +/** + * In-place DGEMM and DGEMV for Breeze + */ +private[ann] object BreezeUtil { + + // TODO: switch to MLlib BLAS interface + private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" + + /** + * DGEMM: C := alpha * A * B + beta * C + * @param alpha alpha + * @param a A + * @param b B + * @param beta beta + * @param c C + */ + def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = { + // TODO: add code if matrices isTranspose!!! + require(a.cols == b.rows, "A & B Dimension mismatch!") + require(a.rows == c.rows, "A & C Dimension mismatch!") + require(b.cols == c.cols, "A & C Dimension mismatch!") + NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, + alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, + beta, c.data, c.offset, c.rows) + } + + /** + * DGEMV: y := alpha * A * x + beta * y + * @param alpha alpha + * @param a A + * @param x x + * @param beta beta + * @param y y + */ + def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { + require(a.cols == x.length, "A & b Dimension mismatch!") + NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, + alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, + beta, y.data, y.offset, y.stride) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala new file mode 100644 index 000000000000..b5258ff34847 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, + sum => Bsum} +import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * Trait that holds Layer properties, that are needed to instantiate it. + * Implements Layer instantiation. + * + */ +private[ann] trait Layer extends Serializable { + /** + * Returns the instance of the layer based on weights provided + * @param weights vector with layer weights + * @param position position of weights in the vector + * @return the layer model + */ + def getInstance(weights: Vector, position: Int): LayerModel + + /** + * Returns the instance of the layer with random generated weights + * @param seed seed + * @return the layer model + */ + def getInstance(seed: Long): LayerModel +} + +/** + * Trait that holds Layer weights (or parameters). + * Implements functions needed for forward propagation, computing delta and gradient. + * Can return weights in Vector format. + */ +private[ann] trait LayerModel extends Serializable { + /** + * number of weights + */ + val size: Int + + /** + * Evaluates the data (process the data through the layer) + * @param data data + * @return processed data + */ + def eval(data: BDM[Double]): BDM[Double] + + /** + * Computes the delta for back propagation + * @param nextDelta delta of the next layer + * @param input input data + * @return delta + */ + def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + + /** + * Computes the gradient + * @param delta delta for this layer + * @param input input data + * @return gradient + */ + def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] + + /** + * Returns weights for the layer in a single vector + * @return layer weights + */ + def weights(): Vector +} + +/** + * Layer properties of affine transformations, that is y=A*x+b + * @param numIn number of inputs + * @param numOut number of outputs + */ +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { + + override def getInstance(weights: Vector, position: Int): LayerModel = { + AffineLayerModel(this, weights, position) + } + + override def getInstance(seed: Long = 11L): LayerModel = { + AffineLayerModel(this, seed) + } +} + +/** + * Model of Affine layer y=A*x+b + * @param w weights (matrix A) + * @param b bias (vector b) + */ +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { + val size = w.size + b.length + val gwb = new Array[Double](size) + private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) + private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) + private var z: BDM[Double] = null + private var d: BDM[Double] = null + private var ones: BDV[Double] = null + + override def eval(data: BDM[Double]): BDM[Double] = { + if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) + z(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, z) + z + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) + BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) + gwb + } + + override def weights(): Vector = AffineLayerModel.roll(w, b) +} + +/** + * Fabric for Affine layer models + */ +private[ann] object AffineLayerModel { + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param weights vector with weights + * @param position position of weights in the vector + * @return model of Affine layer + */ + def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { + val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) + new AffineLayerModel(w, b) + } + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param seed seed + * @return model of Affine layer + */ + def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { + val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) + new AffineLayerModel(w, b) + } + + /** + * Unrolls the weights from the vector + * @param weights vector with weights + * @param position position of weights for this layer + * @param numIn number of layer inputs + * @param numOut number of layer outputs + * @return matrix A and vector b + */ + def unroll( + weights: Vector, + position: Int, + numIn: Int, + numOut: Int): (BDM[Double], BDV[Double]) = { + val weightsCopy = weights.toArray + // TODO: the array is not copied to BDMs, make sure this is OK! + val a = new BDM[Double](numOut, numIn, weightsCopy, position) + val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) + (a, b) + } + + /** + * Roll the layer weights into a vector + * @param a matrix A + * @param b vector b + * @return vector of weights + */ + def roll(a: BDM[Double], b: BDV[Double]): Vector = { + val result = new Array[Double](a.size + b.length) + // TODO: make sure that we need to copy! + System.arraycopy(a.toArray, 0, result, 0, a.size) + System.arraycopy(b.toArray, 0, result, a.size, b.length) + Vectors.dense(result) + } + + /** + * Generate random weights for the layer + * @param numIn number of inputs + * @param numOut number of outputs + * @param seed seed + * @return (matrix A, vector b) + */ + def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { + val rand: XORShiftRandom = new XORShiftRandom(seed) + val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } + val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } + (weights, bias) + } +} + +/** + * Trait for functions and their derivatives for functional layers + */ +private[ann] trait ActivationFunction extends Serializable { + + /** + * Implements a function + * @param x input data + * @param y output data + */ + def eval(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a derivative of a function (needed for the back propagation) + * @param x input data + * @param y output data + */ + def derivative(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a cross entropy error of a function. + * Needed if the functional layer that contains this function is the output layer + * of the network. + * @param target target output + * @param output computed output + * @param result intermediate result + * @return cross-entropy + */ + def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + + /** + * Implements a mean squared error of a function + * @param target target output + * @param output computed output + * @param result intermediate result + * @return mean squared error + */ + def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double +} + +/** + * Implements in-place application of functions + */ +private[ann] object ActivationFunction { + + def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { + var i = 0 + while (i < x.rows) { + var j = 0 + while (j < x.cols) { + y(i, j) = func(x(i, j)) + j += 1 + } + i += 1 + } + } + + def apply( + x1: BDM[Double], + x2: BDM[Double], + y: BDM[Double], + func: (Double, Double) => Double): Unit = { + var i = 0 + while (i < x1.rows) { + var j = 0 + while (j < x1.cols) { + y(i, j) = func(x1(i, j), x2(i, j)) + j += 1 + } + i += 1 + } + } +} + +/** + * Implements SoftMax activation function + */ +private[ann] class SoftmaxFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < x.cols) { + var i = 0 + var max = Double.MinValue + while (i < x.rows) { + if (x(i, j) > max) { + max = x(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < x.rows) { + val res = Math.exp(x(i, j) - max) + y(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < x.rows) { + y(i, j) /= sum + i += 1 + } + j += 1 + } + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") + } +} + +/** + * Implements Sigmoid activation function + */ +private[ann] class SigmoidFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + def s(z: Double): Double = Bsigmoid(z) + ActivationFunction(x, y, s) + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum(target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + // TODO: make it readable + def m(o: Double, t: Double): Double = (o - t) + ActivationFunction(output, target, result, m) + val e = Bsum(result :* result) / 2 / output.cols + def m2(x: Double, o: Double) = x * (o - o * o) + ActivationFunction(result, output, result, m2) + e + } +} + +/** + * Functional layer properties, y = f(x) + * @param activationFunction activation function + */ +private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { + override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) + + override def getInstance(seed: Long): LayerModel = + FunctionalLayerModel(this) +} + +/** + * Functional layer model. Holds no weights. + * @param activationFunction activation function + */ +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) + extends LayerModel { + val size = 0 + // matrices for in-place computations + // outputs + private var f: BDM[Double] = null + // delta + private var d: BDM[Double] = null + // matrix for error computation + private var e: BDM[Double] = null + // delta gradient + private lazy val dg = new Array[Double](0) + + override def eval(data: BDM[Double]): BDM[Double] = { + if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) + activationFunction.eval(data, f) + f + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) + activationFunction.derivative(input, d) + d :*= nextDelta + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg + + override def weights(): Vector = Vectors.dense(new Array[Double](0)) + + def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.crossEntropy(output, target, e) + (e, error) + } + + def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.squared(output, target, e) + (e, error) + } + + def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + // TODO: allow user pick error + activationFunction match { + case sigmoid: SigmoidFunction => squared(output, target) + case softmax: SoftmaxFunction => crossEntropy(output, target) + } + } +} + +/** + * Fabric of functional layer models + */ +private[ann] object FunctionalLayerModel { + def apply(layer: FunctionalLayer): FunctionalLayerModel = + new FunctionalLayerModel(layer.activationFunction) +} + +/** + * Trait for the artificial neural network (ANN) topology properties + */ +private[ann] trait Topology extends Serializable{ + def getInstance(weights: Vector): TopologyModel + def getInstance(seed: Long): TopologyModel +} + +/** + * Trait for ANN topology model + */ +private[ann] trait TopologyModel extends Serializable{ + /** + * Forward propagation + * @param data input data + * @return array of outputs for each of the layers + */ + def forward(data: BDM[Double]): Array[BDM[Double]] + + /** + * Prediction of the model + * @param data input data + * @return prediction + */ + def predict(data: Vector): Vector + + /** + * Computes gradient for the network + * @param data input data + * @param target target output + * @param cumGradient cumulative gradient + * @param blockSize block size + * @return error + */ + def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + blockSize: Int): Double + + /** + * Returns the weights of the ANN + * @return weights + */ + def weights(): Vector +} + +/** + * Feed forward ANN + * @param layers + */ +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { + override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + + override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) +} + +/** + * Factory for some of the frequently-used topologies + */ +private[ml] object FeedForwardTopology { + /** + * Creates a feed forward topology from the array of layers + * @param layers array of layers + * @return feed forward topology + */ + def apply(layers: Array[Layer]): FeedForwardTopology = { + new FeedForwardTopology(layers) + } + + /** + * Creates a multi-layer perceptron + * @param layerSizes sizes of layers including input and output size + * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * Softmax is default + * @return multilayer perceptron topology + */ + def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + val layers = new Array[Layer]((layerSizes.length - 1) * 2) + for(i <- 0 until layerSizes.length - 1){ + layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) + layers(i * 2 + 1) = + if (softmax && i == layerSizes.length - 2) { + new FunctionalLayer(new SoftmaxFunction()) + } else { + new FunctionalLayer(new SigmoidFunction()) + } + } + FeedForwardTopology(layers) + } +} + +/** + * Model of Feed Forward Neural Network. + * Implements forward, gradient computation and can return weights in vector format. + * @param layerModels models of layers + * @param topology topology of the network + */ +private[ml] class FeedForwardModel private( + val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { + override def forward(data: BDM[Double]): Array[BDM[Double]] = { + val outputs = new Array[BDM[Double]](layerModels.length) + outputs(0) = layerModels(0).eval(data) + for (i <- 1 until layerModels.length) { + outputs(i) = layerModels(i).eval(outputs(i-1)) + } + outputs + } + + override def computeGradient( + data: BDM[Double], + target: BDM[Double], + cumGradient: Vector, + realBatchSize: Int): Double = { + val outputs = forward(data) + val deltas = new Array[BDM[Double]](layerModels.length) + val L = layerModels.length - 1 + val (newE, newError) = layerModels.last match { + case flm: FunctionalLayerModel => flm.error(outputs.last, target) + case _ => + throw new UnsupportedOperationException("Non-functional layer not supported at the top") + } + deltas(L) = new BDM[Double](0, 0) + deltas(L - 1) = newE + for (i <- (L - 2) to (0, -1)) { + deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) + } + val grads = new Array[Array[Double]](layerModels.length) + for (i <- 0 until layerModels.length) { + val input = if (i==0) data else outputs(i - 1) + grads(i) = layerModels(i).grad(deltas(i), input) + } + // update cumGradient + val cumGradientArray = cumGradient.toArray + var offset = 0 + // TODO: extract roll + for (i <- 0 until grads.length) { + val gradArray = grads(i) + var k = 0 + while (k < gradArray.length) { + cumGradientArray(offset + k) += gradArray(k) + k += 1 + } + offset += gradArray.length + } + newError + } + + // TODO: do we really need to copy the weights? they should be read-only + override def weights(): Vector = { + // TODO: extract roll + var size = 0 + for (i <- 0 until layerModels.length) { + size += layerModels(i).size + } + val array = new Array[Double](size) + var offset = 0 + for (i <- 0 until layerModels.length) { + val layerWeights = layerModels(i).weights().toArray + System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) + offset += layerWeights.length + } + Vectors.dense(array) + } + + override def predict(data: Vector): Vector = { + val size = data.size + val result = forward(new BDM[Double](size, 1, data.toArray)) + Vectors.dense(result.last.toArray) + } +} + +/** + * Fabric for feed forward ANN models + */ +private[ann] object FeedForwardModel { + + /** + * Creates a model from a topology and weights + * @param topology topology + * @param weights weights + * @return model + */ + def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for (i <- 0 until layers.length) { + layerModels(i) = layers(i).getInstance(weights, offset) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } + + /** + * Creates a model given a topology and seed + * @param topology topology + * @param seed seed for generating the weights + * @return model + */ + def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for(i <- 0 until layers.length){ + layerModels(i) = layers(i).getInstance(seed) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } +} + +/** + * Neural network gradient. Does nothing but calling Model's gradient + * @param topology topology + * @param dataStacker data stacker + */ +private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val (input, target, realBatchSize) = dataStacker.unstack(data) + val model = topology.getInstance(weights) + model.computeGradient(input, target, cumGradient, realBatchSize) + } +} + +/** + * Stacks pairs of training samples (input, output) in one vector allowing them to pass + * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks + * or matrices of inputs and outputs and then stack them in one vector. + * This can be used for further batch computations after unstacking. + * @param stackSize stack size + * @param inputSize size of the input vectors + * @param outputSize size of the output vectors + */ +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) + extends Serializable { + + /** + * Stacks the data + * @param data RDD of vector pairs + * @return RDD of double (always zero) and vector that contains the stacked vectors + */ + def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { + val stackedData = if (stackSize == 1) { + data.map { v => + (0.0, + Vectors.fromBreeze(BDV.vertcat( + v._1.toBreeze.toDenseVector, + v._2.toBreeze.toDenseVector)) + ) } + } else { + data.mapPartitions { it => + it.grouped(stackSize).map { seq => + val size = seq.size + val bigVector = new Array[Double](inputSize * size + outputSize * size) + var i = 0 + seq.foreach { case (in, out) => + System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize) + System.arraycopy(out.toArray, 0, bigVector, + inputSize * size + i * outputSize, outputSize) + i += 1 + } + (0.0, Vectors.dense(bigVector)) + } + } + } + stackedData + } + + /** + * Unstack the stacked vectors into matrices for batch operations + * @param data stacked vector + * @return pair of matrices holding input and output data and the real stack size + */ + def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = { + val arrData = data.toArray + val realStackSize = arrData.length / (inputSize + outputSize) + val input = new BDM(inputSize, realStackSize, arrData) + val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize) + (input, target, realStackSize) + } +} + +/** + * Simple updater + */ +private[ann] class ANNUpdater extends Updater { + + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { + val thisIterStepSize = stepSize + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + (Vectors.fromBreeze(brzWeights), 0) + } +} + +/** + * MLlib-style trainer class that trains a network given the data and topology + * @param topology topology of ANN + * @param inputSize input size + * @param outputSize output size + */ +private[ml] class FeedForwardTrainer( + topology: Topology, + val inputSize: Int, + val outputSize: Int) extends Serializable { + + // TODO: what if we need to pass random seed? + private var _weights = topology.getInstance(11L).weights() + private var _stackSize = 128 + private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) + private var _gradient: Gradient = new ANNGradient(topology, dataStacker) + private var _updater: Updater = new ANNUpdater() + private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + + /** + * Returns weights + * @return weights + */ + def getWeights: Vector = _weights + + /** + * Sets weights + * @param value weights + * @return trainer + */ + def setWeights(value: Vector): FeedForwardTrainer = { + _weights = value + this + } + + /** + * Sets the stack size + * @param value stack size + * @return trainer + */ + def setStackSize(value: Int): FeedForwardTrainer = { + _stackSize = value + dataStacker = new DataStacker(value, inputSize, outputSize) + this + } + + /** + * Sets the SGD optimizer + * @return SGD optimizer + */ + def SGDOptimizer: GradientDescent = { + val sgd = new GradientDescent(_gradient, _updater) + optimizer = sgd + sgd + } + + /** + * Sets the LBFGS optimizer + * @return LBGS optimizer + */ + def LBFGSOptimizer: LBFGS = { + val lbfgs = new LBFGS(_gradient, _updater) + optimizer = lbfgs + lbfgs + } + + /** + * Sets the updater + * @param value updater + * @return trainer + */ + def setUpdater(value: Updater): FeedForwardTrainer = { + _updater = value + updateUpdater(value) + this + } + + /** + * Sets the gradient + * @param value gradient + * @return trainer + */ + def setGradient(value: Gradient): FeedForwardTrainer = { + _gradient = value + updateGradient(value) + this + } + + private[this] def updateGradient(gradient: Gradient): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setGradient(gradient) + case sgd: GradientDescent => sgd.setGradient(gradient) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + private[this] def updateUpdater(updater: Updater): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setUpdater(updater) + case sgd: GradientDescent => sgd.setUpdater(updater) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + /** + * Trains the ANN + * @param data RDD of input and output vector pairs + * @return model + */ + def train(data: RDD[(Vector, Vector)]): TopologyModel = { + val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) + topology.getInstance(newWeights) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala new file mode 100644 index 000000000000..457c15830fd3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} + +/** + * :: DeveloperApi :: + * Attributes that describe a vector ML column. + * + * @param name name of the attribute group (the ML column name) + * @param numAttributes optional number of attributes. At most one of `numAttributes` and `attrs` + * can be defined. + * @param attrs optional array of attributes. Attribute will be copied with their corresponding + * indices in the array. + */ +@DeveloperApi +class AttributeGroup private ( + val name: String, + val numAttributes: Option[Int], + attrs: Option[Array[Attribute]]) extends Serializable { + + require(name.nonEmpty, "Cannot have an empty string for name.") + require(!(numAttributes.isDefined && attrs.isDefined), + "Cannot have both numAttributes and attrs defined.") + + /** + * Creates an attribute group without attribute info. + * @param name name of the attribute group + */ + def this(name: String) = this(name, None, None) + + /** + * Creates an attribute group knowing only the number of attributes. + * @param name name of the attribute group + * @param numAttributes number of attributes + */ + def this(name: String, numAttributes: Int) = this(name, Some(numAttributes), None) + + /** + * Creates an attribute group with attributes. + * @param name name of the attribute group + * @param attrs array of attributes. Attributes will be copied with their corresponding indices in + * the array. + */ + def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs)) + + /** + * Optional array of attributes. At most one of `numAttributes` and `attributes` can be defined. + */ + val attributes: Option[Array[Attribute]] = attrs.map(_.view.zipWithIndex.map { case (attr, i) => + attr.withIndex(i) + }.toArray) + + private lazy val nameToIndex: Map[String, Int] = { + attributes.map(_.view.flatMap { attr => + attr.name.map(_ -> attr.index.get) + }.toMap).getOrElse(Map.empty) + } + + /** Size of the attribute group. Returns -1 if the size is unknown. */ + def size: Int = { + if (numAttributes.isDefined) { + numAttributes.get + } else if (attributes.isDefined) { + attributes.get.length + } else { + -1 + } + } + + /** Test whether this attribute group contains a specific attribute. */ + def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName) + + /** Index of an attribute specified by name. */ + def indexOf(attrName: String): Int = nameToIndex(attrName) + + /** Gets an attribute by its name. */ + def apply(attrName: String): Attribute = { + attributes.get(indexOf(attrName)) + } + + /** Gets an attribute by its name. */ + def getAttr(attrName: String): Attribute = this(attrName) + + /** Gets an attribute by its index. */ + def apply(attrIndex: Int): Attribute = attributes.get(attrIndex) + + /** Gets an attribute by its index. */ + def getAttr(attrIndex: Int): Attribute = this(attrIndex) + + /** Converts to metadata without name. */ + private[attribute] def toMetadataImpl: Metadata = { + import AttributeKeys._ + val bldr = new MetadataBuilder() + if (attributes.isDefined) { + val numericMetadata = ArrayBuffer.empty[Metadata] + val nominalMetadata = ArrayBuffer.empty[Metadata] + val binaryMetadata = ArrayBuffer.empty[Metadata] + attributes.get.foreach { + case numeric: NumericAttribute => + // Skip default numeric attributes. + if (numeric.withoutIndex != NumericAttribute.defaultAttr) { + numericMetadata += numeric.toMetadataImpl(withType = false) + } + case nominal: NominalAttribute => + nominalMetadata += nominal.toMetadataImpl(withType = false) + case binary: BinaryAttribute => + binaryMetadata += binary.toMetadataImpl(withType = false) + case UnresolvedAttribute => + } + val attrBldr = new MetadataBuilder + if (numericMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Numeric.name, numericMetadata.toArray) + } + if (nominalMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Nominal.name, nominalMetadata.toArray) + } + if (binaryMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Binary.name, binaryMetadata.toArray) + } + bldr.putMetadata(ATTRIBUTES, attrBldr.build()) + bldr.putLong(NUM_ATTRIBUTES, attributes.get.length) + } else if (numAttributes.isDefined) { + bldr.putLong(NUM_ATTRIBUTES, numAttributes.get) + } + bldr.build() + } + + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata)) + } + + /** Converts to a StructField. */ + def toStructField(): StructField = toStructField(Metadata.empty) + + override def equals(other: Any): Boolean = { + other match { + case o: AttributeGroup => + (name == o.name) && + (numAttributes == o.numAttributes) && + (attributes.map(_.toSeq) == o.attributes.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + numAttributes.hashCode + sum = 37 * sum + attributes.map(_.toSeq).hashCode + sum + } +} + +/** + * :: DeveloperApi :: + * Factory methods to create attribute groups. + */ +@DeveloperApi +object AttributeGroup { + + import AttributeKeys._ + + /** Creates an attribute group from a [[Metadata]] instance with name. */ + private[attribute] def fromMetadata(metadata: Metadata, name: String): AttributeGroup = { + import org.apache.spark.ml.attribute.AttributeType._ + if (metadata.contains(ATTRIBUTES)) { + val numAttrs = metadata.getLong(NUM_ATTRIBUTES).toInt + val attributes = new Array[Attribute](numAttrs) + val attrMetadata = metadata.getMetadata(ATTRIBUTES) + if (attrMetadata.contains(Numeric.name)) { + attrMetadata.getMetadataArray(Numeric.name) + .map(NumericAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + if (attrMetadata.contains(Nominal.name)) { + attrMetadata.getMetadataArray(Nominal.name) + .map(NominalAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + if (attrMetadata.contains(Binary.name)) { + attrMetadata.getMetadataArray(Binary.name) + .map(BinaryAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + var i = 0 + while (i < numAttrs) { + if (attributes(i) == null) { + attributes(i) = NumericAttribute.defaultAttr + } + i += 1 + } + new AttributeGroup(name, attributes) + } else if (metadata.contains(NUM_ATTRIBUTES)) { + new AttributeGroup(name, metadata.getLong(NUM_ATTRIBUTES).toInt) + } else { + new AttributeGroup(name) + } + } + + /** Creates an attribute group from a [[StructField]] instance. */ + def fromStructField(field: StructField): AttributeGroup = { + require(field.dataType == new VectorUDT) + if (field.metadata.contains(ML_ATTR)) { + fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name) + } else { + new AttributeGroup(field.name) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala new file mode 100644 index 000000000000..f714f7becc7e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +/** + * Keys used to store attributes. + */ +private[attribute] object AttributeKeys { + val ML_ATTR: String = "ml_attr" + val TYPE: String = "type" + val NAME: String = "name" + val INDEX: String = "idx" + val MIN: String = "min" + val MAX: String = "max" + val STD: String = "std" + val SPARSITY: String = "sparsity" + val ORDINAL: String = "ord" + val VALUES: String = "vals" + val NUM_VALUES: String = "num_vals" + val ATTRIBUTES: String = "attrs" + val NUM_ATTRIBUTES: String = "num_attrs" +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala new file mode 100644 index 000000000000..5c7089b49167 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], + * and [[AttributeType$#Binary]]. + */ +@DeveloperApi +sealed abstract class AttributeType(val name: String) + +@DeveloperApi +object AttributeType { + + /** Numeric type. */ + val Numeric: AttributeType = { + case object Numeric extends AttributeType("numeric") + Numeric + } + + /** Nominal type. */ + val Nominal: AttributeType = { + case object Nominal extends AttributeType("nominal") + Nominal + } + + /** Binary type. */ + val Binary: AttributeType = { + case object Binary extends AttributeType("binary") + Binary + } + + /** Unresolved type. */ + val Unresolved: AttributeType = { + case object Unresolved extends AttributeType("unresolved") + Unresolved + } + + /** + * Gets the [[AttributeType]] object from its name. + * @param name attribute type name: "numeric", "nominal", or "binary" + */ + def fromName(name: String): AttributeType = { + if (name == Numeric.name) { + Numeric + } else if (name == Nominal.name) { + Nominal + } else if (name == Binary.name) { + Binary + } else if (name == Unresolved.name) { + Unresolved + } else { + throw new IllegalArgumentException(s"Cannot recognize type $name.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala new file mode 100644 index 000000000000..e479f169021d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -0,0 +1,597 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import scala.annotation.varargs + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} + +/** + * :: DeveloperApi :: + * Abstract class for ML attributes. + */ +@DeveloperApi +sealed abstract class Attribute extends Serializable { + + name.foreach { n => + require(n.nonEmpty, "Cannot have an empty string for name.") + } + index.foreach { i => + require(i >= 0, s"Index cannot be negative but got $i") + } + + /** Attribute type. */ + def attrType: AttributeType + + /** Name of the attribute. None if it is not set. */ + def name: Option[String] + + /** Copy with a new name. */ + def withName(name: String): Attribute + + /** Copy without the name. */ + def withoutName: Attribute + + /** Index of the attribute. None if it is not set. */ + def index: Option[Int] + + /** Copy with a new index. */ + def withIndex(index: Int): Attribute + + /** Copy without the index. */ + def withoutIndex: Attribute + + /** + * Tests whether this attribute is numeric, true for [[NumericAttribute]] and [[BinaryAttribute]]. + */ + def isNumeric: Boolean + + /** + * Tests whether this attribute is nominal, true for [[NominalAttribute]] and [[BinaryAttribute]]. + */ + def isNominal: Boolean + + /** + * Converts this attribute to [[Metadata]]. + * @param withType whether to include the type info + */ + private[attribute] def toMetadataImpl(withType: Boolean): Metadata + + /** + * Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to + * save space, because numeric type is the default attribute type. For nominal and binary + * attributes, the type info is included. + */ + private[attribute] def toMetadataImpl(): Metadata = { + if (attrType == AttributeType.Numeric) { + toMetadataImpl(withType = false) + } else { + toMetadataImpl(withType = true) + } + } + + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl()) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + + /** + * Converts to a [[StructField]] with some existing metadata. + * @param existingMetadata existing metadata to carry over + */ + def toStructField(existingMetadata: Metadata): StructField = { + val newMetadata = new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl()) + .build() + StructField(name.get, DoubleType, nullable = false, newMetadata) + } + + /** Converts to a [[StructField]]. */ + def toStructField(): StructField = toStructField(Metadata.empty) + + override def toString: String = toMetadataImpl(withType = true).toString +} + +/** Trait for ML attribute factories. */ +private[attribute] trait AttributeFactory { + + /** + * Creates an [[Attribute]] from a [[Metadata]] instance. + */ + private[attribute] def fromMetadata(metadata: Metadata): Attribute + + /** + * Creates an [[Attribute]] from a [[StructField]] instance. + */ + def fromStructField(field: StructField): Attribute = { + require(field.dataType.isInstanceOf[NumericType]) + val metadata = field.metadata + val mlAttr = AttributeKeys.ML_ATTR + if (metadata.contains(mlAttr)) { + fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name) + } else { + UnresolvedAttribute + } + } +} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Attribute extends AttributeFactory { + + private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val attrType = if (metadata.contains(TYPE)) { + metadata.getString(TYPE) + } else { + AttributeType.Numeric.name + } + getFactory(attrType).fromMetadata(metadata) + } + + /** Gets the attribute factory given the attribute type name. */ + private def getFactory(attrType: String): AttributeFactory = { + if (attrType == AttributeType.Numeric.name) { + NumericAttribute + } else if (attrType == AttributeType.Nominal.name) { + NominalAttribute + } else if (attrType == AttributeType.Binary.name) { + BinaryAttribute + } else { + throw new IllegalArgumentException(s"Cannot recognize type $attrType.") + } + } +} + + +/** + * :: DeveloperApi :: + * A numeric attribute with optional summary statistics. + * @param name optional name + * @param index optional index + * @param min optional min value + * @param max optional max value + * @param std optional standard deviation + * @param sparsity optional sparsity (ratio of zeros) + */ +@DeveloperApi +class NumericAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val min: Option[Double] = None, + val max: Option[Double] = None, + val std: Option[Double] = None, + val sparsity: Option[Double] = None) extends Attribute { + + std.foreach { s => + require(s >= 0.0, s"Standard deviation cannot be negative but got $s.") + } + sparsity.foreach { s => + require(s >= 0.0 && s <= 1.0, s"Sparsity must be in [0, 1] but got $s.") + } + + override def attrType: AttributeType = AttributeType.Numeric + + override def withName(name: String): NumericAttribute = copy(name = Some(name)) + override def withoutName: NumericAttribute = copy(name = None) + + override def withIndex(index: Int): NumericAttribute = copy(index = Some(index)) + override def withoutIndex: NumericAttribute = copy(index = None) + + /** Copy with a new min value. */ + def withMin(min: Double): NumericAttribute = copy(min = Some(min)) + + /** Copy without the min value. */ + def withoutMin: NumericAttribute = copy(min = None) + + + /** Copy with a new max value. */ + def withMax(max: Double): NumericAttribute = copy(max = Some(max)) + + /** Copy without the max value. */ + def withoutMax: NumericAttribute = copy(max = None) + + /** Copy with a new standard deviation. */ + def withStd(std: Double): NumericAttribute = copy(std = Some(std)) + + /** Copy without the standard deviation. */ + def withoutStd: NumericAttribute = copy(std = None) + + /** Copy with a new sparsity. */ + def withSparsity(sparsity: Double): NumericAttribute = copy(sparsity = Some(sparsity)) + + /** Copy without the sparsity. */ + def withoutSparsity: NumericAttribute = copy(sparsity = None) + + /** Copy without summary statistics. */ + def withoutSummary: NumericAttribute = copy(min = None, max = None, std = None, sparsity = None) + + override def isNumeric: Boolean = true + + override def isNominal: Boolean = false + + /** Convert this attribute to metadata. */ + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder() + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + min.foreach(bldr.putDouble(MIN, _)) + max.foreach(bldr.putDouble(MAX, _)) + std.foreach(bldr.putDouble(STD, _)) + sparsity.foreach(bldr.putDouble(SPARSITY, _)) + bldr.build() + } + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + min: Option[Double] = min, + max: Option[Double] = max, + std: Option[Double] = std, + sparsity: Option[Double] = sparsity): NumericAttribute = { + new NumericAttribute(name, index, min, max, std, sparsity) + } + + override def equals(other: Any): Boolean = { + other match { + case o: NumericAttribute => + (name == o.name) && + (index == o.index) && + (min == o.min) && + (max == o.max) && + (std == o.std) && + (sparsity == o.sparsity) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + min.hashCode + sum = 37 * sum + max.hashCode + sum = 37 * sum + std.hashCode + sum = 37 * sum + sparsity.hashCode + sum + } +} + +/** + * :: DeveloperApi :: + * Factory methods for numeric attributes. + */ +@DeveloperApi +object NumericAttribute extends AttributeFactory { + + /** The default numeric attribute. */ + val defaultAttr: NumericAttribute = new NumericAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): NumericAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val min = if (metadata.contains(MIN)) Some(metadata.getDouble(MIN)) else None + val max = if (metadata.contains(MAX)) Some(metadata.getDouble(MAX)) else None + val std = if (metadata.contains(STD)) Some(metadata.getDouble(STD)) else None + val sparsity = if (metadata.contains(SPARSITY)) Some(metadata.getDouble(SPARSITY)) else None + new NumericAttribute(name, index, min, max, std, sparsity) + } +} + +/** + * :: DeveloperApi :: + * A nominal attribute. + * @param name optional name + * @param index optional index + * @param isOrdinal whether this attribute is ordinal (optional) + * @param numValues optional number of values. At most one of `numValues` and `values` can be + * defined. + * @param values optional values. At most one of `numValues` and `values` can be defined. + */ +@DeveloperApi +class NominalAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val isOrdinal: Option[Boolean] = None, + val numValues: Option[Int] = None, + val values: Option[Array[String]] = None) extends Attribute { + + numValues.foreach { n => + require(n >= 0, s"numValues cannot be negative but got $n.") + } + require(!(numValues.isDefined && values.isDefined), + "Cannot have both numValues and values defined.") + + override def attrType: AttributeType = AttributeType.Nominal + + override def isNumeric: Boolean = false + + override def isNominal: Boolean = true + + private lazy val valueToIndex: Map[String, Int] = { + values.map(_.zipWithIndex.toMap).getOrElse(Map.empty) + } + + /** Index of a specific value. */ + def indexOf(value: String): Int = { + valueToIndex(value) + } + + /** Tests whether this attribute contains a specific value. */ + def hasValue(value: String): Boolean = valueToIndex.contains(value) + + /** Gets a value given its index. */ + def getValue(index: Int): String = values.get(index) + + override def withName(name: String): NominalAttribute = copy(name = Some(name)) + override def withoutName: NominalAttribute = copy(name = None) + + override def withIndex(index: Int): NominalAttribute = copy(index = Some(index)) + override def withoutIndex: NominalAttribute = copy(index = None) + + /** Copy with new values and empty `numValues`. */ + def withValues(values: Array[String]): NominalAttribute = { + copy(numValues = None, values = Some(values)) + } + + /** Copy with new values and empty `numValues`. */ + @varargs + def withValues(first: String, others: String*): NominalAttribute = { + copy(numValues = None, values = Some((first +: others).toArray)) + } + + /** Copy without the values. */ + def withoutValues: NominalAttribute = { + copy(values = None) + } + + /** Copy with a new `numValues` and empty `values`. */ + def withNumValues(numValues: Int): NominalAttribute = { + copy(numValues = Some(numValues), values = None) + } + + /** Copy without the `numValues`. */ + def withoutNumValues: NominalAttribute = copy(numValues = None) + + /** + * Get the number of values, either from `numValues` or from `values`. + * Return None if unknown. + */ + def getNumValues: Option[Int] = { + if (numValues.nonEmpty) { + numValues + } else if (values.nonEmpty) { + Some(values.get.length) + } else { + None + } + } + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + isOrdinal: Option[Boolean] = isOrdinal, + numValues: Option[Int] = numValues, + values: Option[Array[String]] = values): NominalAttribute = { + new NominalAttribute(name, index, isOrdinal, numValues, values) + } + + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder() + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + isOrdinal.foreach(bldr.putBoolean(ORDINAL, _)) + numValues.foreach(bldr.putLong(NUM_VALUES, _)) + values.foreach(v => bldr.putStringArray(VALUES, v)) + bldr.build() + } + + override def equals(other: Any): Boolean = { + other match { + case o: NominalAttribute => + (name == o.name) && + (index == o.index) && + (isOrdinal == o.isOrdinal) && + (numValues == o.numValues) && + (values.map(_.toSeq) == o.values.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + isOrdinal.hashCode + sum = 37 * sum + numValues.hashCode + sum = 37 * sum + values.map(_.toSeq).hashCode + sum + } +} + +/** + * :: DeveloperApi :: + * Factory methods for nominal attributes. + */ +@DeveloperApi +object NominalAttribute extends AttributeFactory { + + /** The default nominal attribute. */ + final val defaultAttr: NominalAttribute = new NominalAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): NominalAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val isOrdinal = if (metadata.contains(ORDINAL)) Some(metadata.getBoolean(ORDINAL)) else None + val numValues = + if (metadata.contains(NUM_VALUES)) Some(metadata.getLong(NUM_VALUES).toInt) else None + val values = + if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None + new NominalAttribute(name, index, isOrdinal, numValues, values) + } +} + +/** + * :: DeveloperApi :: + * A binary attribute. + * @param name optional name + * @param index optional index + * @param values optionla values. If set, its size must be 2. + */ +@DeveloperApi +class BinaryAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val values: Option[Array[String]] = None) + extends Attribute { + + values.foreach { v => + require(v.length == 2, s"Number of values must be 2 for a binary attribute but got ${v.toSeq}.") + } + + override def attrType: AttributeType = AttributeType.Binary + + override def isNumeric: Boolean = true + + override def isNominal: Boolean = true + + override def withName(name: String): BinaryAttribute = copy(name = Some(name)) + override def withoutName: BinaryAttribute = copy(name = None) + + override def withIndex(index: Int): BinaryAttribute = copy(index = Some(index)) + override def withoutIndex: BinaryAttribute = copy(index = None) + + /** + * Copy with new values. + * @param negative name for negative + * @param positive name for positive + */ + def withValues(negative: String, positive: String): BinaryAttribute = + copy(values = Some(Array(negative, positive))) + + /** Copy without the values. */ + def withoutValues: BinaryAttribute = copy(values = None) + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + values: Option[Array[String]] = values): BinaryAttribute = { + new BinaryAttribute(name, index, values) + } + + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + values.foreach(v => bldr.putStringArray(VALUES, v)) + bldr.build() + } + + override def equals(other: Any): Boolean = { + other match { + case o: BinaryAttribute => + (name == o.name) && + (index == o.index) && + (values.map(_.toSeq) == o.values.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + values.map(_.toSeq).hashCode + sum + } +} + +/** + * :: DeveloperApi :: + * Factory methods for binary attributes. + */ +@DeveloperApi +object BinaryAttribute extends AttributeFactory { + + /** The default binary attribute. */ + final val defaultAttr: BinaryAttribute = new BinaryAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): BinaryAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val values = + if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None + new BinaryAttribute(name, index, values) + } +} + +/** + * :: DeveloperApi :: + * An unresolved attribute. + */ +@DeveloperApi +object UnresolvedAttribute extends Attribute { + + override def attrType: AttributeType = AttributeType.Unresolved + + override def withIndex(index: Int): Attribute = this + + override def isNumeric: Boolean = false + + override def withoutIndex: Attribute = this + + override def isNominal: Boolean = false + + override def name: Option[String] = None + + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + Metadata.empty + } + + override def withoutName: Attribute = this + + override def index: Option[Int] = None + + override def withName(name: String): Attribute = this + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java new file mode 100644 index 000000000000..e3474f3c1d3f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The content here should be in sync with `package.scala`. + +/** + *

    ML attributes

    + * + * The ML pipeline API uses {@link org.apache.spark.sql.DataFrame}s as ML datasets. + * Each dataset consists of typed columns, e.g., string, double, vector, etc. + * However, knowing only the column type may not be sufficient to handle the data properly. + * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, + * which cannot be treated as numeric values in ML algorithms, and, for another instance, we may + * want to know the names and types of features stored in a vector column. + * ML attributes are used to provide additional information to describe columns in a dataset. + * + *

    ML columns

    + * + * A column with ML attributes attached is called an ML column. + * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column + * of double values or a vector column. + * Columns of other types must be encoded into ML columns using transformers. + * We use {@link org.apache.spark.ml.attribute.Attribute} to describe a scalar ML column, and + * {@link org.apache.spark.ml.attribute.AttributeGroup} to describe a vector ML column. + * ML attributes are stored in the metadata field of the column schema. + */ +package org.apache.spark.ml.attribute; diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala new file mode 100644 index 000000000000..7ac21d7d563f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} + +/** + * ==ML attributes== + * + * The ML pipeline API uses [[DataFrame]]s as ML datasets. + * Each dataset consists of typed columns, e.g., string, double, vector, etc. + * However, knowing only the column type may not be sufficient to handle the data properly. + * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, + * which cannot be treated as numeric values in ML algorithms, and, for another instance, we may + * want to know the names and types of features stored in a vector column. + * ML attributes are used to provide additional information to describe columns in a dataset. + * + * ===ML columns=== + * + * A column with ML attributes attached is called an ML column. + * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column + * of double values or a vector column. + * Columns of other types must be encoded into ML columns using transformers. + * We use [[Attribute]] to describe a scalar ML column, and [[AttributeGroup]] to describe a vector + * ML column. + * ML attributes are stored in the metadata field of the column schema. + */ +package object attribute diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala new file mode 100644 index 000000000000..45df557a8990 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, StructType} + + +/** + * (private[spark]) Params for classification. + */ +private[spark] trait ClassifierParams + extends PredictorParams with HasRawPredictionCol { + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT) + } +} + +/** + * :: DeveloperApi :: + * + * Single-label binary or multiclass classification. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + */ +@DeveloperApi +abstract class Classifier[ + FeaturesType, + E <: Classifier[FeaturesType, E, M], + M <: ClassificationModel[FeaturesType, M]] + extends Predictor[FeaturesType, E, M] with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: DeveloperApi :: + * + * Model produced by a [[Classifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + */ +@DeveloperApi +abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + + /** Number of classes (values which the label can take). */ + def numClasses: Int + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @return transformed dataset + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + var outputData = dataset + var numColsOutput = 0 + if (getRawPredictionCol != "") { + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) + numColsOutput += 1 + } + if (getPredictionCol != "") { + val predUDF = if (getRawPredictionCol != "") { + udf(raw2prediction _).apply(col(getRawPredictionCol)) + } else { + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col(getFeaturesCol)) + } + outputData = outputData.withColumn(getPredictionCol, predUDF) + numColsOutput += 1 + } + + if (numColsOutput == 0) { + logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + * + * This default implementation for classification predicts the index of the maximum value + * from [[predictRaw()]]. + */ + override protected def predict(features: FeaturesType): Double = { + raw2prediction(predictRaw(features)) + } + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + protected def predictRaw(features: FeaturesType): Vector + + /** + * Given a vector of raw predictions, select the predicted label. + * This may be overridden to support thresholds which favor particular labels. + * @return predicted label + */ + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala new file mode 100644 index 000000000000..6f70b96b17ec --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} +import org.apache.spark.ml.tree.impl.RandomForest +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@Experimental +final class DecisionTreeClassifier(override val uid: String) + extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + with DecisionTreeParams with TreeClassifierParams { + + def this() = this(Identifiable.randomUID("dtc")) + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + + s" with invalid label column ${$(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = getOldStrategy(categoricalFeatures, numClasses) + val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeClassificationModel] + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, + subsamplingRate = 1.0) + } + + override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) +} + +@Experimental +object DecisionTreeClassifier { + /** Accessor for supported impurities: entropy, gini */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities +} + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@Experimental +final class DecisionTreeClassificationModel private[ml] ( + override val uid: String, + override val rootNode: Node, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + /** + * Construct a decision tree classification model. + * @param rootNode Root node of tree, with other nodes attached. + */ + private[ml] def this(rootNode: Node, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numClasses) + + override protected def predict(features: Vector): Double = { + rootNode.predictImpl(features).prediction + } + + override protected def predictRaw(features: Vector): Vector = { + Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + + override def copy(extra: ParamMap): DecisionTreeClassificationModel = { + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + .setParent(parent) + } + + override def toString: String = { + s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) + } +} + +private[ml] object DecisionTreeClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeClassifier, + categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, + s"Cannot convert non-classification DecisionTreeModel (old API) to" + + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") + new DecisionTreeClassificationModel(uid, rootNode, -1) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala new file mode 100644 index 000000000000..3073a2a61ce8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + */ +@Experimental +final class GBTClassifier(override val uid: String) + extends Predictor[Vector, GBTClassifier, GBTClassificationModel] + with GBTParams with TreeClassifierParams with Logging { + + def this() = this(Identifiable.randomUID("gbtc")) + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." + */ + override def setImpurity(value: String): this.type = { + logWarning("GBTClassifier.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTClassifier: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group setParam */ + def setLossType(value: String): this.type = set(lossType, value) + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } + + override protected def train(dataset: DataFrame): GBTClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("GBTClassifier was given input" + + s" with invalid label column ${$(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + require(numClasses == 2, + s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) + } + + override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) +} + +@Experimental +object GBTClassifier { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ +@Experimental +final class GBTClassificationModel( + override val uid: String, + private val _trees: Array[DecisionTreeRegressionModel], + private val _treeWeights: Array[Double]) + extends PredictionModel[Vector, GBTClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.") + require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + override def treeWeights: Array[Double] = _treeWeights + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override protected def predict(features: Vector): Double = { + // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 + // Classifies by thresholding sum of weighted tree predictions + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + if (prediction > 0.0) 1.0 else 0.0 + } + + override def copy(extra: ParamMap): GBTClassificationModel = { + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) + } + + override def toString: String = { + s"GBTClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldGBTModel = { + new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) + } +} + +private[ml] object GBTClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldGBTModel, + parent: GBTClassifier, + categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) + } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") + new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 18be35ad5945..21fbe38ca823 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -17,129 +17,868 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql._ -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** - * :: AlphaComponent :: * Params for logistic regression. */ -@AlphaComponent -private[classification] trait LogisticRegressionParams extends Params - with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol - with HasScoreCol with HasPredictionCol { - - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param paramMap additional parameters - * @param fitting whether this is in fitting - * @return output schema - */ - protected def validateAndTransformSchema( - schema: StructType, - paramMap: ParamMap, - fitting: Boolean): StructType = { - val map = this.paramMap ++ paramMap - val featuresType = schema(map(featuresCol)).dataType - // TODO: Support casting Array[Double] and Array[Float] to Vector. - require(featuresType.isInstanceOf[VectorUDT], - s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") - if (fitting) { - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") - } - val fieldNames = schema.fieldNames - require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") - require(!fieldNames.contains(map(predictionCol)), - s"Prediction column ${map(predictionCol)} already exists.") - val outputFields = schema.fields ++ Seq( - StructField(map(scoreCol), DoubleType, false), - StructField(map(predictionCol), DoubleType, false)) - StructType(outputFields) +private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams + with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol + with HasStandardization with HasThreshold { + + /** + * Set threshold in binary classification, in range [0, 1]. + * + * If the estimated probability of class label 1 is > threshold, then predict 1, else 0. + * A high threshold encourages the model to predict 0 more often; + * a low threshold encourages the model to predict 1 more often. + * + * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`. + * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * Default is 0.5. + * @group setParam + */ + def setThreshold(value: Double): this.type = { + if (isSet(thresholds)) clear(thresholds) + set(threshold, value) + } + + /** + * Get threshold for binary classification. + * + * If [[threshold]] is set, returns that value. + * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. + * Otherwise, returns [[threshold]] default value. + * + * @group getParam + * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. + */ + override def getThreshold: Double = { + checkThresholdConsistency() + if (isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(",")) + 1.0 / (1.0 + ts(0) / ts(1)) + } else { + $(threshold) + } + } + + /** + * Set thresholds in multiclass (or binary) classification to adjust the probability of + * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * The class with largest value p/t is predicted, where p is the original probability of that + * class and t is the class' threshold. + * + * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * @group setParam + */ + def setThresholds(value: Array[Double]): this.type = { + if (isSet(threshold)) clear(threshold) + set(thresholds, value) + } + + /** + * Get thresholds for binary or multiclass classification. + * + * If [[thresholds]] is set, return its value. + * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary + * classification: (1-threshold, threshold). + * If neither are set, throw an exception. + * + * @group getParam + */ + override def getThresholds: Array[Double] = { + checkThresholdConsistency() + if (!isSet(thresholds) && isSet(threshold)) { + val t = $(threshold) + Array(1-t, t) + } else { + $(thresholds) + } + } + + /** + * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent. + * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent + */ + protected def checkThresholdConsistency(): Unit = { + if (isSet(threshold) && isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + + s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + + s" classification, but Param thresholds is set with length ${ts.length}." + + " Clear one Param value to fix this problem.") + val t = 1.0 / (1.0 + ts(0) / ts(1)) + require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" + + s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)") + } + } + + override def validateParams(): Unit = { + checkThresholdConsistency() } } /** + * :: Experimental :: * Logistic regression. + * Currently, this class only supports binary classification. It will support multiclass + * in the future. */ -class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { +@Experimental +class LogisticRegression(override val uid: String) + extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] + with LogisticRegressionParams with Logging { - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) + def this() = this(Identifiable.randomUID("logreg")) + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) - def setLabelCol(value: String): this.type = set(labelCol, value) - def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - - override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol), map(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - }.persist(StorageLevel.MEMORY_AND_DISK) - val lr = new LogisticRegressionWithLBFGS - lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) - val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) - instances.unpersist() - // copy model params - Params.inheritValues(map, this, lrm) - lrm - } - - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = true) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Whether to fit an intercept term. + * Default is true. + * @group setParam + */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Whether to standardize the training features before fitting the model. + * The coefficients of models will be always returned on the original scale, + * so it will be transparent for users. Note that with/without standardization, + * the models should be always converged to the same solution when no regularization + * is applied. In R's GLMNET package, the default behavior is true as well. + * Default is true. + * @group setParam + */ + def setStandardization(value: Boolean): this.type = set(standardization, value) + setDefault(standardization -> true) + + override def setThreshold(value: Double): this.type = super.setThreshold(value) + + override def getThreshold: Double = super.getThreshold + + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + + override protected def train(dataset: DataFrame): LogisticRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractLabeledPoints(dataset).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val (summarizer, labelSummarizer) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( + seqOp = (c, v) => (c, v) match { + case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer), + (label: Double, features: Vector)) => + (summarizer.add(features), labelSummarizer.add(label)) + }, + combOp = (c1, c2) => (c1, c2) match { + case ((summarizer1: MultivariateOnlineSummarizer, + classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer, + classSummarizer2: MultiClassSummarizer)) => + (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2)) + }) + + val histogram = labelSummarizer.histogram + val numInvalid = labelSummarizer.countInvalid + val numClasses = histogram.length + val numFeatures = summarizer.mean.size + + if (numInvalid != 0) { + val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + + s"Found $numInvalid invalid labels." + logError(msg) + throw new SparkException(msg) + } + + if (numClasses > 2) { + val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " + + s"binary classification. Found $numClasses in the input dataset." + logError(msg) + throw new SparkException(msg) + } + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + val regParamL1 = $(elasticNetParam) * $(regParam) + val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) + + val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization), + featuresStd, featuresMean, regParamL2) + + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + } else { + def regParamL1Fun = (index: Int) => { + // Remove the L1 penalization on the intercept + if (index == numFeatures) { + 0.0 + } else { + if ($(standardization)) { + regParamL1 + } else { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0 + } + } + } + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) + } + + val initialWeightsWithIntercept = + Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) + + if ($(fitIntercept)) { + /* + For binary logistic regression, when we initialize the weights as zeros, + it will converge faster if we initialize the intercept such that + it follows the distribution of the labels. + + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} + */ + initialWeightsWithIntercept.toArray(numFeatures) + = math.log(histogram(1).toDouble / histogram(0).toDouble) + } + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeightsWithIntercept.toBreeze.toDenseVector) + + val (weights, intercept, objectiveHistory) = { + /* + Note that in Logistic Regression, the objective history (loss + regularization) + is log-likelihood which is invariance under feature standardization. As a result, + the objective history from optimizer is the same as the one in the original space. + */ + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + logError(msg) + throw new SparkException(msg) + } + + /* + The weights are trained in the scaled space; we're converting them back to + the original space. + Note that the intercept in scaled space and original space is the same; + as a result, no scaling is needed. + */ + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < numFeatures) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } + + if ($(fitIntercept)) { + (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result()) + } else { + (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result()) + } + } + + if (handlePersistence) instances.unpersist() + + val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val logRegSummary = new BinaryLogisticRegressionTrainingSummary( + model.transform(dataset), + $(probabilityCol), + $(labelCol), + objectiveHistory) + model.setSummary(logRegSummary) } + + override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } /** - * :: AlphaComponent :: + * :: Experimental :: * Model produced by [[LogisticRegression]]. */ -@AlphaComponent +@Experimental class LogisticRegressionModel private[ml] ( - override val parent: LogisticRegression, - override val fittingParamMap: ParamMap, - weights: Vector) - extends Model[LogisticRegressionModel] with LogisticRegressionParams { - - def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = false) - } - - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val scoreFunction: Vector => Double = (v) => { - val margin = BLAS.dot(v, weights) - 1.0 / (1.0 + math.exp(-margin)) - } - val t = map(threshold) - val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } - dataset - .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol))) - .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol))) + override val uid: String, + val weights: Vector, + val intercept: Double) + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] + with LogisticRegressionParams { + + override def setThreshold(value: Double): this.type = super.setThreshold(value) + + override def getThreshold: Double = super.getThreshold + + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + + /** Margin (rawPrediction) for class label 1. For binary classification only. */ + private val margin: Vector => Double = (features) => { + BLAS.dot(features, weights) + intercept + } + + /** Score (probability) for class label 1. For binary classification only. */ + private val score: Vector => Double = (features) => { + val m = margin(features) + 1.0 / (1.0 + math.exp(-m)) + } + + override val numClasses: Int = 2 + + private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LogisticRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LogisticRegressionModel", + new NullPointerException()) + } + + private[classification] def setSummary( + summary: LogisticRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + } + + /** + * Predict label for the given feature vector. + * The behavior of this can be adjusted using [[thresholds]]. + */ + override protected def predict(features: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. + if (score(features) > getThreshold) 1 else 0 + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + while (i < size) { + dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i))) + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in LogisticRegressionModel:" + + " raw2probabilitiesInPlace encountered SparseVector") + } + } + + override protected def predictRaw(features: Vector): Vector = { + val m = margin(features) + Vectors.dense(-m, m) + } + + override def copy(extra: ParamMap): LogisticRegressionModel = { + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) + } + + override protected def raw2prediction(rawPrediction: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. + val t = getThreshold + val rawThreshold = if (t == 0.0) { + Double.NegativeInfinity + } else if (t == 1.0) { + Double.PositiveInfinity + } else { + math.log(t / (1.0 - t)) + } + if (rawPrediction(1) > rawThreshold) 1 else 0 + } + + override protected def probability2prediction(probability: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. + if (probability(1) > getThreshold) 1 else 0 + } +} + +/** + * MultiClassSummarizer computes the number of distinct labels and corresponding counts, + * and validates the data to see if the labels used for k class multi-label classification + * are in the range of {0, 1, ..., k - 1} in a online fashion. + * + * Two MultilabelSummarizer can be merged together to have a statistical summary of the + * corresponding joint dataset. + */ +private[classification] class MultiClassSummarizer extends Serializable { + private val distinctMap = new mutable.HashMap[Int, Long] + private var totalInvalidCnt: Long = 0L + + /** + * Add a new label into this MultilabelSummarizer, and update the distinct map. + * @param label The label for this data point. + * @return This MultilabelSummarizer + */ + def add(label: Double): this.type = { + if (label - label.toInt != 0.0 || label < 0) { + totalInvalidCnt += 1 + this + } + else { + val counts: Long = distinctMap.getOrElse(label.toInt, 0L) + distinctMap.put(label.toInt, counts + 1) + this + } + } + + /** + * Merge another MultilabelSummarizer, and update the distinct map. + * (Note that it will merge the smaller distinct map into the larger one using in-place + * merging, so either `this` or `other` object will be modified and returned.) + * + * @param other The other MultilabelSummarizer to be merged. + * @return Merged MultilabelSummarizer object. + */ + def merge(other: MultiClassSummarizer): MultiClassSummarizer = { + val (largeMap, smallMap) = if (this.distinctMap.size > other.distinctMap.size) { + (this, other) + } else { + (other, this) + } + smallMap.distinctMap.foreach { + case (key, value) => + val counts = largeMap.distinctMap.getOrElse(key, 0L) + largeMap.distinctMap.put(key, counts + value) + } + largeMap.totalInvalidCnt += smallMap.totalInvalidCnt + largeMap + } + + /** @return The total invalid input counts. */ + def countInvalid: Long = totalInvalidCnt + + /** @return The number of distinct labels in the input dataset. */ + def numClasses: Int = distinctMap.keySet.max + 1 + + /** @return The counts of each label in the input dataset. */ + def histogram: Array[Long] = { + val result = Array.ofDim[Long](numClasses) + var i = 0 + val len = result.length + while (i < len) { + result(i) = distinctMap.getOrElse(i, 0L) + i += 1 + } + result + } +} + +/** + * Abstraction for multinomial Logistic Regression Training results. + */ +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { + + /** objective function (scaled loss + regularization) at each iteration. */ + def objectiveHistory: Array[Double] + + /** Number of training iterations until termination */ + def totalIterations: Int = objectiveHistory.length + +} + +/** + * Abstraction for Logistic Regression Results for a given model. + */ +sealed trait LogisticRegressionSummary extends Serializable { + + /** Dataframe outputted by the model's `transform` method. */ + def predictions: DataFrame + + /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + def probabilityCol: String + + /** Field in "predictions" which gives the the true label of each sample. */ + def labelCol: String + +} + +/** + * :: Experimental :: + * Logistic regression training results. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each sample as a vector. + * @param labelCol field in "predictions" which gives the true label of each sample. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +@Experimental +class BinaryLogisticRegressionTrainingSummary private[classification] ( + predictions: DataFrame, + probabilityCol: String, + labelCol: String, + val objectiveHistory: Array[Double]) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + with LogisticRegressionTrainingSummary { + +} + +/** + * :: Experimental :: + * Binary Logistic regression results for a given model. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each sample. + * @param labelCol field in "predictions" which gives the true label of each sample. + */ +@Experimental +class BinaryLogisticRegressionSummary private[classification] ( + @transient override val predictions: DataFrame, + override val probabilityCol: String, + override val labelCol: String) extends LogisticRegressionSummary { + + private val sqlContext = predictions.sqlContext + import sqlContext.implicits._ + + /** + * Returns a BinaryClassificationMetrics object. + */ + // TODO: Allow the user to vary the number of bins using a setBins method in + // BinaryClassificationMetrics. For now the default is set to 100. + @transient private val binaryMetrics = new BinaryClassificationMetrics( + predictions.select(probabilityCol, labelCol).map { + case Row(score: Vector, label: Double) => (score(1), label) + }, 100 + ) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an Dataframe having two fields (FPR, TPR) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() + + /** + * Returns the precision-recall curve, which is an Dataframe containing + * two fields recall, precision with (0.0, 1.0) prepended to it. + */ + @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") + + /** + * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + */ + @transient lazy val fMeasureByThreshold: DataFrame = { + binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") + } + + /** + * Returns a dataframe with two fields (threshold, precision) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the precision. + */ + @transient lazy val precisionByThreshold: DataFrame = { + binaryMetrics.precisionByThreshold().toDF("threshold", "precision") + } + + /** + * Returns a dataframe with two fields (threshold, recall) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the recall. + */ + @transient lazy val recallByThreshold: DataFrame = { + binaryMetrics.recallByThreshold().toDF("threshold", "recall") + } +} + +/** + * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used + * in binary classification for samples in sparse or dense vector in a online fashion. + * + * Note that multinomial logistic loss is not supported yet! + * + * Two LogisticAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * @param weights The weights/coefficients corresponding to the features. + * @param numClasses the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. + * @param fitIntercept Whether to fit an intercept term. + * @param featuresStd The standard deviation values of the features. + * @param featuresMean The mean values of the features. + */ +private class LogisticAggregator( + weights: Vector, + numClasses: Int, + fitIntercept: Boolean, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + + private val weightsArray = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + + private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length + + private val gradientSumArray = Array.ofDim[Double](weightsArray.length) + + /** + * Add a new training data to this LogisticAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LogisticAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val localWeightsArray = weightsArray + val localGradientSumArray = gradientSumArray + + numClasses match { + case 2 => + // For Binary Logistic Regression. + val margin = - { + var sum = 0.0 + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + sum += localWeightsArray(index) * (value / featuresStd(index)) + } + } + sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 } + } + + val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += multiplier * (value / featuresStd(index)) + } + } + + if (fitIntercept) { + localGradientSumArray(dim) += multiplier + } + + if (label > 0) { + // The following is equivalent to log(1 + exp(margin)) but more numerically stable. + lossSum += MLUtils.log1pExp(margin) + } else { + lossSum += MLUtils.log1pExp(margin) - margin + } + case _ => + new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " + + "binary classification for now.") + } + totalCnt += 1 + this + } + + /** + * Merge another LogisticAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LogisticAggregator to be merged. + * @return This LogisticAggregator object. + */ + def merge(other: LogisticAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + val len = localThisGradientSumArray.length + while (i < len) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function, + * as used in multi-class classification (it is also used in binary logistic regression). + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LogisticCostFun( + data: RDD[(Double, Vector)], + numClasses: Int, + fitIntercept: Boolean, + standardization: Boolean, + featuresStd: Array[Double], + featuresMean: Array[Double], + regParamL2: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val numFeatures = featuresStd.length + val w = Vectors.fromBreeze(weights) + + val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept, + featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + val totalGradientArray = logisticAggregator.gradient.toArray + + // regVal is the sum of weight squares excluding intercept for L2 regularization. + val regVal = if (regParamL2 == 0.0) { + 0.0 + } else { + var sum = 0.0 + w.foreachActive { (index, value) => + // If `fitIntercept` is true, the last term which is intercept doesn't + // contribute to the regularization. + if (index != numFeatures) { + // The following code will compute the loss of the regularization; also + // the gradient of the regularization, and add back to totalGradientArray. + sum += { + if (standardization) { + totalGradientArray(index) += regParamL2 * value + value * value + } else { + if (featuresStd(index) != 0.0) { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + val temp = value / (featuresStd(index) * featuresStd(index)) + totalGradientArray(index) += regParamL2 * temp + value * temp + } else { + 0.0 + } + } + } + } + } + 0.5 * regParamL2 * sum + } + + (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala new file mode 100644 index 000000000000..1e5b0bc4453e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.DataFrame + +/** Params for Multilayer Perceptron. */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams + with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = new IntArrayParam(this, "layers", + "Sizes of layers from input layer to output layer" + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "one hidden layer with 100 neurons and output layer of 10 neurons.", + // TODO: how to check ALSO that all elements are greater than 0? + ParamValidators.arrayLengthGt(1) + ) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** + * Block size for stacking input data in matrices to speed up the computation. + * Data is stacked within partitions. If block size is more than remaining data in + * a partition then it is adjusted to the size of this data. + * Recommended size is between 10 and 1000. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices. Data is stacked within partitions." + + " If block size is more than remaining data in a partition then " + + "it is adjusted to the size of this data. Recommended size is between 10 and 1000", + ParamValidators.gt(0)) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) +} + +/** Label to vector converter. */ +private object LabelConverter { + // TODO: Use OneHotEncoder instead + /** + * Encodes a label as a vector. + * Returns a vector of given length with zeroes at all positions + * and value 1.0 at the position that corresponds to the label. + * + * @param labeledPoint labeled point + * @param labelCount total number of labels + * @return pair of features and vector encoding of a label + */ + def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount)(0.0) + output(labeledPoint.label.toInt) = 1.0 + (labeledPoint.features, Vectors.dense(output)) + } + + /** + * Converts a vector to a label. + * Returns the position of the maximal element of a vector. + * + * @param output label encoded with a vector + * @return label + */ + def decodeLabel(output: Vector): Double = { + output.argmax.toDouble + } +} + +/** + * :: Experimental :: + * Classifier trainer based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * Number of inputs has to be equal to the size of feature vectors. + * Number of outputs has to be equal to the total number of labels. + * + */ +@Experimental +class MultilayerPerceptronClassifier(override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] + with MultilayerPerceptronParams { + + def this() = this(Identifiable.randomUID("mlpc")) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { + val myLayers = $(layers) + val labels = myLayers.last + val lpData = extractLabeledPoints(dataset) + val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) + FeedForwardTrainer.setStackSize($(blockSize)) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) + } +} + +/** + * :: Experimental :: + * Classification model based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * @param uid uid + * @param layers array of layer sizes including input and output layers + * @param weights vector of initial weights for the model that consists of the weights of layers + * @return prediction model + */ +@Experimental +class MultilayerPerceptronClassificationModel private[ml] ( + override val uid: String, + val layers: Array[Int], + val weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] + with Serializable { + + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + override protected def predict(features: Vector): Double = { + LabelConverter.decodeLabel(mlpModel.predict(features)) + } + + override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { + copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala new file mode 100644 index 000000000000..97cbaf1fa876 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkException +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * Params for Naive Bayes Classifiers. + */ +private[ml] trait NaiveBayesParams extends PredictorParams { + + /** + * The smoothing parameter. + * (default = 1.0). + * @group param + */ + final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.", + ParamValidators.gtEq(0)) + + /** @group getParam */ + final def getSmoothing: Double = $(smoothing) + + /** + * The model type which is a string (case-sensitive). + * Supported options: "multinomial" and "bernoulli". + * (default = multinomial) + * @group param + */ + final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + + "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.", + ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) + + /** @group getParam */ + final def getModelType: String = $(modelType) +} + +/** + * Naive Bayes Classifiers. + * It supports both Multinomial NB + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) + * which can handle finitely supported discrete data. For example, by converting documents into + * TF-IDF vectors, it can be used for document classification. By making every vector a + * binary (0/1) data, it can also be used as Bernoulli NB + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). + * The input feature values must be nonnegative. + */ +class NaiveBayes(override val uid: String) + extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] + with NaiveBayesParams { + + def this() = this(Identifiable.randomUID("nb")) + + /** + * Set the smoothing parameter. + * Default is 1.0. + * @group setParam + */ + def setSmoothing(value: Double): this.type = set(smoothing, value) + setDefault(smoothing -> 1.0) + + /** + * Set the model type using a string (case-sensitive). + * Supported options: "multinomial" and "bernoulli". + * Default is "multinomial" + * @group setParam + */ + def setModelType(value: String): this.type = set(modelType, value) + setDefault(modelType -> OldNaiveBayes.Multinomial) + + override protected def train(dataset: DataFrame): NaiveBayesModel = { + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) + NaiveBayesModel.fromOld(oldModel, this) + } + + override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) +} + +/** + * Model produced by [[NaiveBayes]] + * @param pi log of class priors, whose dimension is C (number of classes) + * @param theta log of class conditional probabilities, whose dimension is C (number of classes) + * by D (number of features) + */ +class NaiveBayesModel private[ml] ( + override val uid: String, + val pi: Vector, + val theta: Matrix) + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + + import OldNaiveBayes.{Bernoulli, Multinomial} + + /** + * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + * application of this condition (in predict function). + */ + private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match { + case Multinomial => (None, None) + case Bernoulli => + val negTheta = theta.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(theta.numCols){1.0}) + val thetaMinusNegTheta = theta.map { value => + value - math.log(1.0 - math.exp(value)) + } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + + override val numClasses: Int = pi.size + + private def multinomialCalculation(features: Vector) = { + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob + } + + private def bernoulliCalculation(features: Vector) = { + features.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + override protected def predictRaw(features: Vector): Vector = { + $(modelType) match { + case Multinomial => + multinomialCalculation(features) + case Bernoulli => + bernoulliCalculation(features) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val maxLog = dv.values.max + while (i < size) { + dv.values(i) = math.exp(dv.values(i) - maxLog) + i += 1 + } + val probSum = dv.values.sum + i = 0 + while (i < size) { + dv.values(i) = dv.values(i) / probSum + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in NaiveBayesModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + + override def copy(extra: ParamMap): NaiveBayesModel = { + copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) + } + + override def toString: String = { + s"NaiveBayesModel with ${pi.size} classes" + } + +} + +private[ml] object NaiveBayesModel { + + /** Convert a model from the old API */ + def fromOld( + oldModel: OldNaiveBayesModel, + parent: NaiveBayes): NaiveBayesModel = { + val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") + val labels = Vectors.dense(oldModel.labels) + val pi = Vectors.dense(oldModel.pi) + val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, + oldModel.theta.flatten, true) + new NaiveBayesModel(uid, pi, theta) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala new file mode 100644 index 000000000000..c62e132f5d53 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import java.util.UUID + +import scala.language.existentials + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel + +/** + * Params for [[OneVsRest]]. + */ +private[ml] trait OneVsRestParams extends PredictorParams { + + // scalastyle:off structural.type + type ClassifierType = Classifier[F, E, M] forSome { + type F + type M <: ClassificationModel[F, M] + type E <: Classifier[F, E, M] + } + // scalastyle:on structural.type + + /** + * param for the base binary classifier that we reduce multiclass classification into. + * The base classifier input and output columns are ignored in favor of + * the ones specified in [[OneVsRest]]. + * @group param + */ + val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") + + /** @group getParam */ + def getClassifier: ClassifierType = $(classifier) +} + +/** + * :: Experimental :: + * Model produced by [[OneVsRest]]. + * This stores the models resulting from training k binary classifiers: one for each class. + * Each example is scored against all k models, and the model with the highest score + * is picked to label the example. + * + * @param labelMetadata Metadata of label column if it exists, or Nominal attribute + * representing the number of classes in training dataset otherwise. + * @param models The binary classification models for the reduction. + * The i-th model is produced by testing the i-th class (taking label 1) vs the rest + * (taking label 0). + */ +@Experimental +final class OneVsRestModel private[ml] ( + override val uid: String, + labelMetadata: Metadata, + val models: Array[_ <: ClassificationModel[_, _]]) + extends Model[OneVsRestModel] with OneVsRestParams { + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) + } + + override def transform(dataset: DataFrame): DataFrame = { + // Check schema + transformSchema(dataset.schema, logging = true) + + // determine the input columns: these need to be passed through + val origCols = dataset.schema.map(f => col(f.name)) + + // add an accumulator column to store predictions of all the models + val accColName = "mbc$acc" + UUID.randomUUID().toString + val initUDF = udf { () => Map[Int, Double]() } + val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) + val newDataset = dataset.withColumn(accColName, initUDF()) + + // persist if underlying dataset is not persistent. + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + newDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // update the accumulator column with the result of prediction of models + val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) { + case (df, (model, index)) => + val rawPredictionCol = model.getRawPredictionCol + val columns = origCols ++ List(col(rawPredictionCol), col(accColName)) + + // add temporary column to store intermediate scores and update + val tmpColName = "mbc$tmp" + UUID.randomUUID().toString + val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => + predictions + ((index, prediction(1))) + } + val transformedDataset = model.transform(df).select(columns : _*) + val updatedDataset = transformedDataset + .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) + val newColumns = origCols ++ List(col(tmpColName)) + + // switch out the intermediate column with the accumulator column + updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) + } + + if (handlePersistence) { + newDataset.unpersist() + } + + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + + // output label and label metadata as prediction + aggregatedDataset + .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } + + override def copy(extra: ParamMap): OneVsRestModel = { + val copied = new OneVsRestModel( + uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) + copyValues(copied, extra).setParent(parent) + } +} + +/** + * :: Experimental :: + * + * Reduction of Multiclass Classification to Binary Classification. + * Performs reduction using one against all strategy. + * For a multiclass classification with k classes, train k models (one per class). + * Each example is scored against all k models and the model with highest score + * is picked to label the example. + */ +@Experimental +final class OneVsRest(override val uid: String) + extends Estimator[OneVsRestModel] with OneVsRestParams { + + def this() = this(Identifiable.randomUID("oneVsRest")) + + /** @group setParam */ + def setClassifier(value: Classifier[_, _, _]): this.type = { + set(classifier, value.asInstanceOf[ClassifierType]) + } + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) + } + + override def fit(dataset: DataFrame): OneVsRestModel = { + // determine number of classes either from metadata if provided, or via computation. + val labelSchema = dataset.schema($(labelCol)) + val computeNumClasses: () => Int = () => { + val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head() + // classes are assumed to be numbered from 0,...,maxLabelIndex + maxLabelIndex.toInt + 1 + } + val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) + + val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) + + // persist if underlying dataset is not persistent. + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK) + } + + // create k columns, one for each binary classifier. + val models = Range(0, numClasses).par.map { index => + val labelUDF = udf { (label: Double) => + if (label.toInt == index) 1.0 else 0.0 + } + + // generate new label metadata for the binary problem. + // TODO: use when ... otherwise after SPARK-7321 is merged + val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() + val labelColName = "mc2b$" + index + val trainingDataset = + multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) + val classifier = getClassifier + val paramMap = new ParamMap() + paramMap.put(classifier.labelCol -> labelColName) + paramMap.put(classifier.featuresCol -> getFeaturesCol) + paramMap.put(classifier.predictionCol -> getPredictionCol) + classifier.fit(trainingDataset, paramMap) + }.toArray[ClassificationModel[_, _]] + + if (handlePersistence) { + multiclassLabeled.unpersist() + } + + // extract label metadata from label column if present, or create a nominal attribute + // to output the number of labels + val labelAttribute = Attribute.fromStructField(labelSchema) match { + case _: NumericAttribute | UnresolvedAttribute => + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + case attr: Attribute => attr + } + val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) + copyValues(model) + } + + override def copy(extra: ParamMap): OneVsRest = { + val copied = defaultCopy(extra).asInstanceOf[OneVsRest] + if (isDefined(classifier)) { + copied.setClassifier($(classifier).copy(extra)) + } + copied + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala new file mode 100644 index 000000000000..fdd1851ae550 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * (private[classification]) Params for probabilistic classification. + */ +private[classification] trait ProbabilisticClassifierParams + extends ClassifierParams with HasProbabilityCol with HasThresholds { + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + SchemaUtils.appendColumn(parentSchema, $(probabilityCol), new VectorUDT) + } +} + + +/** + * :: DeveloperApi :: + * + * Single-label binary or multiclass classifier which can output class conditional probabilities. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + */ +@DeveloperApi +abstract class ProbabilisticClassifier[ + FeaturesType, + E <: ProbabilisticClassifier[FeaturesType, E, M], + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] + + /** @group setParam */ + def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] +} + + +/** + * :: DeveloperApi :: + * + * Model produced by a [[ProbabilisticClassifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + */ +@DeveloperApi +abstract class ProbabilisticClassificationModel[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + + /** @group setParam */ + def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M] + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]] + * - probability of each class as [[probabilityCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @return transformed dataset + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".transform() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + var outputData = dataset + var numColsOutput = 0 + if ($(rawPredictionCol).nonEmpty) { + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) + numColsOutput += 1 + } + if ($(probabilityCol).nonEmpty) { + val probUDF = if ($(rawPredictionCol).nonEmpty) { + udf(raw2probability _).apply(col($(rawPredictionCol))) + } else { + val probabilityUDF = udf { (features: Any) => + predictProbability(features.asInstanceOf[FeaturesType]) + } + probabilityUDF(col($(featuresCol))) + } + outputData = outputData.withColumn($(probabilityCol), probUDF) + numColsOutput += 1 + } + if ($(predictionCol).nonEmpty) { + val predUDF = if ($(rawPredictionCol).nonEmpty) { + udf(raw2prediction _).apply(col($(rawPredictionCol))) + } else if ($(probabilityCol).nonEmpty) { + udf(probability2prediction _).apply(col($(probabilityCol))) + } else { + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col($(featuresCol))) + } + outputData = outputData.withColumn($(predictionCol), predUDF) + numColsOutput += 1 + } + + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + + /** + * Estimate the probability of each class given the raw prediction, + * doing the computation in-place. + * These predictions are also called class conditional probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. + * + * @return Estimated class conditional probabilities (modified input vector) + */ + protected def raw2probabilityInPlace(rawPrediction: Vector): Vector + + /** Non-in-place version of [[raw2probabilityInPlace()]] */ + protected def raw2probability(rawPrediction: Vector): Vector = { + val probs = rawPrediction.copy + raw2probabilityInPlace(probs) + } + + override protected def raw2prediction(rawPrediction: Vector): Double = { + if (!isDefined(thresholds)) { + rawPrediction.argmax + } else { + probability2prediction(raw2probability(rawPrediction)) + } + } + + /** + * Predict the probability of each class given the features. + * These predictions are also called class conditional probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. + * + * @return Estimated class conditional probabilities + */ + protected def predictProbability(features: FeaturesType): Vector = { + val rawPreds = predictRaw(features) + raw2probabilityInPlace(rawPreds) + } + + /** + * Given a vector of class conditional probabilities, select the predicted label. + * This supports thresholds which favor particular labels. + * @return predicted label + */ + protected def probability2prediction(probability: Vector): Double = { + if (!isDefined(thresholds)) { + probability.argmax + } else { + val thresholds: Array[Double] = getThresholds + val scaledProbability: Array[Double] = + probability.toArray.zip(thresholds).map { case (p, t) => + if (t == 0.0) Double.PositiveInfinity else p / t + } + Vectors.dense(scaledProbability).argmax + } + } +} + +private[ml] object ProbabilisticClassificationModel { + + /** + * Normalize a vector of raw predictions to be a multinomial probability vector, in place. + * + * The input raw predictions should be >= 0. + * The output vector sums to 1, unless the input vector is all-0 (in which case the output is + * all-0 too). + * + * NOTE: This is NOT applicable to all models, only ones which effectively use class + * instance counts for raw predictions. + */ + def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = { + val sum = v.values.sum + if (sum != 0) { + var i = 0 + val size = v.size + while (i < size) { + v.values(i) /= sum + i += 1 + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala new file mode 100644 index 000000000000..11a6d7246833 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.tree.impl.RandomForest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ + + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for + * classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@Experimental +final class RandomForestClassifier(override val uid: String) + extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] + with RandomForestParams with TreeClassifierParams { + + def this() = this(Identifiable.randomUID("rfc")) + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train(dataset: DataFrame): RandomForestClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("RandomForestClassifier was given input" + + s" with invalid label column ${$(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) + val trees = + RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) + .map(_.asInstanceOf[DecisionTreeClassificationModel]) + val numFeatures = oldDataset.first().features.size + new RandomForestClassificationModel(trees, numFeatures, numClasses) + } + + override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) +} + +@Experimental +object RandomForestClassifier { + /** Accessor for supported impurity settings: entropy, gini */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + * @param _trees Decision trees in the ensemble. + * Warning: These have null parents. + * @param numFeatures Number of features used by this model + */ +@Experimental +final class RandomForestClassificationModel private[ml] ( + override val uid: String, + private val _trees: Array[DecisionTreeClassificationModel], + val numFeatures: Int, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + + /** + * Construct a random forest classification model, with all trees weighted equally. + * @param trees Component trees + */ + private[ml] def this( + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override protected def predictRaw(features: Vector): Vector = { + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Classifies using majority votes. + // Ignore the tree weights since all are 1.0 for now. + val votes = Array.fill[Double](numClasses)(0.0) + _trees.view.foreach { tree => + val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats + val total = classCounts.sum + if (total != 0) { + var i = 0 + while (i < numClasses) { + votes(i) += classCounts(i) / total + i += 1 + } + } + } + Vectors.dense(votes) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + + override def copy(extra: ParamMap): RandomForestClassificationModel = { + copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) + } + + override def toString: String = { + s"RandomForestClassificationModel with $numTrees trees" + } + + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): RandomForestClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) + } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") + new RandomForestClassificationModel(uid, newTrees, -1, numClasses) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala new file mode 100644 index 000000000000..47a18cdb31b5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.{DataFrame, Row} + + +/** + * Common params for KMeans and KMeansModel + */ +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + * @group expertParam + */ + final val initMode = new Param[String](this, "initMode", "initialization algorithm", + (value: String) => MLlibKMeans.validateInitMode(value)) + + /** @group expertGetParam */ + def getInitMode: String = $(initMode) + + /** + * Param for the number of steps for the k-means|| initialization mode. This is an advanced + * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5. + * @group expertParam + */ + final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", + (value: Int) => value > 0) + + /** @group expertGetParam */ + def getInitSteps: Int = $(initSteps) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Model fitted by KMeans. + * + * @param parentModel a model trained by spark.mllib.clustering.KMeans. + */ +@Experimental +class KMeansModel private[ml] ( + override val uid: String, + private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + + override def copy(extra: ParamMap): KMeansModel = { + val copied = new KMeansModel(uid, parentModel) + copyValues(copied, extra) + } + + override def transform(dataset: DataFrame): DataFrame = { + val predictUDF = udf((vector: Vector) => predict(vector)) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + def clusterCenters: Array[Vector] = parentModel.clusterCenters +} + +/** + * :: Experimental :: + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. + * + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] + */ +@Experimental +class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> MLlibKMeans.K_MEANS_PARALLEL, + initSteps -> 5, + tol -> 1e-4) + + override def copy(extra: ParamMap): KMeans = defaultCopy(extra) + + def this() = this(Identifiable.randomUID("kmeans")) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group expertSetParam */ + def setInitSteps(value: Int): this.type = set(initSteps, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setTol(value: Double): this.type = set(tol, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + override def fit(dataset: DataFrame): KMeansModel = { + val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + + val algo = new MLlibKMeans() + .setK($(k)) + .setInitializationMode($(initMode)) + .setInitializationSteps($(initSteps)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setEpsilon($(tol)) + val parentModel = algo.run(rdd) + val model = new KMeansModel(uid, parentModel) + copyValues(model) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1979ab9eb651..56419a0a1595 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -17,55 +17,79 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * Evaluator for binary classification, which expects two input columns: score and label. + * :: Experimental :: + * Evaluator for binary classification, which expects two input columns: rawPrediction and label. */ -@AlphaComponent -class BinaryClassificationEvaluator extends Evaluator with Params - with HasScoreCol with HasLabelCol { +@Experimental +class BinaryClassificationEvaluator(override val uid: String) + extends Evaluator with HasRawPredictionCol with HasLabelCol { - /** param for metric name in evaluation */ - val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) - def getMetricName: String = get(metricName) + def this() = this(Identifiable.randomUID("binEval")) + + /** + * param for metric name in evaluation + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR")) + new Param( + this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) + /** @group setParam */ + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + + /** + * @group setParam + * @deprecated use [[setRawPredictionCol()]] instead + */ + @deprecated("use setRawPredictionCol instead", "1.5.0") + def setScoreCol(value: String): this.type = set(rawPredictionCol, value) + + /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + setDefault(metricName -> "areaUnderROC") + override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema - val scoreType = schema(map(scoreCol)).dataType - require(scoreType == DoubleType, - s"Score column ${map(scoreCol)} must be double type but found $scoreType") - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Label column ${map(labelCol)} must be double type but found $labelType") + SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) - .map { case Row(score: Double, label: Double) => - (score, label) + // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. + val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) + .map { case Row(rawPrediction: Vector, label: Double) => + (rawPrediction(1), label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) - val metric = map(metricName) match { - case "areaUnderROC" => - metrics.areaUnderROC() - case "areaUnderPR" => - metrics.areaUnderPR() - case other => - throw new IllegalArgumentException(s"Does not support metric $other.") + val metric = $(metricName) match { + case "areaUnderROC" => metrics.areaUnderROC() + case "areaUnderPR" => metrics.areaUnderPR() } metrics.unpersist() metric } + + override def isLargerBetter: Boolean = $(metricName) match { + case "areaUnderROC" => true + case "areaUnderPR" => true + } + + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala new file mode 100644 index 000000000000..13bd3307f8a2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.DataFrame + +/** + * :: DeveloperApi :: + * Abstract class for evaluators that compute metrics from predictions. + */ +@DeveloperApi +abstract class Evaluator extends Params { + + /** + * Evaluates model output and returns a scalar metric (larger is better). + * + * @param dataset a dataset that contains labels/observations and predictions. + * @param paramMap parameter map that specifies the input columns and output metrics + * @return metric + */ + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + this.copy(paramMap).evaluate(dataset) + } + + /** + * Evaluates the output. + * @param dataset a dataset that contains labels/observations and predictions. + * @return metric + */ + def evaluate(dataset: DataFrame): Double + + /** + * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) + * or minimized (false). + * A given evaluator may support multiple metrics which may be maximized or minimized. + */ + def isLargerBetter: Boolean = true + + override def copy(extra: ParamMap): Evaluator +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala new file mode 100644 index 000000000000..f73d2345078e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for multiclass classification, which expects two input columns: score and label. + */ +@Experimental +class MulticlassClassificationEvaluator (override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("mcEval")) + + /** + * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, + * `"weightedPrecision"`, `"weightedRecall"`) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("f1", "precision", + "recall", "weightedPrecision", "weightedRecall")) + new Param(this, "metricName", "metric name in evaluation " + + "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "f1") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new MulticlassMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "f1" => metrics.weightedFMeasure + case "precision" => metrics.precision + case "recall" => metrics.recall + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall + } + metric + } + + override def isLargerBetter: Boolean = $(metricName) match { + case "f1" => true + case "precision" => true + case "recall" => true + case "weightedPrecision" => true + case "weightedRecall" => true + } + + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala new file mode 100644 index 000000000000..d21c88ab9b10 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for regression, which expects two input columns: prediction and label. + */ +@Experimental +final class RegressionEvaluator(override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("regEval")) + + /** + * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * + * Because we will maximize evaluation value (ref: `CrossValidator`), + * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + * we take and output the negative of this metric. + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) + new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "rmse") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new RegressionMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "rmse" => metrics.rootMeanSquaredError + case "mse" => metrics.meanSquaredError + case "r2" => metrics.r2 + case "mae" => metrics.meanAbsoluteError + } + metric + } + + override def isLargerBetter: Boolean = $(metricName) match { + case "rmse" => false + case "mse" => false + case "r2" => true + case "mae" => false + } + + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala new file mode 100644 index 000000000000..46314854d5e3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.BinaryAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} + +/** + * :: Experimental :: + * Binarize a column of continuous features given a threshold. + */ +@Experimental +final class Binarizer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("binarizer")) + + /** + * Param for threshold used to binarize continuous features. + * The features greater than the threshold, will be binarized to 1.0. + * The features equal to or less than the threshold, will be binarized to 0.0. + * @group param + */ + val threshold: DoubleParam = + new DoubleParam(this, "threshold", "threshold used to binarize continuous features") + + /** @group getParam */ + def getThreshold: Double = $(threshold) + + /** @group setParam */ + def setThreshold(value: Double): this.type = set(threshold, value) + + setDefault(threshold -> 0.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val td = $(threshold) + val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } + val outputColName = $(outputCol) + val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata() + dataset.select(col("*"), + binarizer(col($(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + + val inputFields = schema.fields + val outputColName = $(outputCol) + + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + + val attr = BinaryAttribute.defaultAttr.withName(outputColName) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } + + override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala new file mode 100644 index 000000000000..6fdf25b015b0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import java.{util => ju} + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Model +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} + +/** + * :: Experimental :: + * `Bucketizer` maps a column of continuous features to a column of feature buckets. + */ +@Experimental +final class Bucketizer(override val uid: String) + extends Model[Bucketizer] with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("bucketizer")) + + /** + * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. + * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which + * also includes y. Splits should be strictly increasing. + * Values at -inf, inf must be explicitly provided to cover all Double values; + * otherwise, values outside the splits specified will be treated as errors. + * @group param + */ + val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, there are n " + + "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " + + "bucket, which also includes y. The splits should be strictly increasing. " + + "Values at -inf, inf must be explicitly provided to cover all Double values; " + + "otherwise, values outside the splits specified will be treated as errors.", + Bucketizer.checkSplits) + + /** @group getParam */ + def getSplits: Array[Double] = $(splits) + + /** @group setParam */ + def setSplits(value: Array[Double]): this.type = set(splits, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val bucketizer = udf { feature: Double => + Bucketizer.binarySearchForBuckets($(splits), feature) + } + val newCol = bucketizer(dataset($(inputCol))) + val newField = prepOutputField(dataset.schema) + dataset.withColumn($(outputCol), newCol, newField.metadata) + } + + private def prepOutputField(schema: StructType): StructField = { + val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray + val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), + values = Some(buckets)) + attr.toStructField() + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + SchemaUtils.appendColumn(schema, prepOutputField(schema)) + } + + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } +} + +private[feature] object Bucketizer { + /** We require splits to be of length >= 3 and to be in strictly increasing order. */ + def checkSplits(splits: Array[Double]): Boolean = { + if (splits.length < 3) { + false + } else { + var i = 0 + val n = splits.length - 1 + while (i < n) { + if (splits(i) >= splits(i + 1)) return false + i += 1 + } + true + } + } + + /** + * Binary searching in several buckets to place each data point. + * @throws SparkException if a feature is < splits.head or > splits.last + */ + def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + if (feature == splits.last) { + splits.length - 2 + } else { + val idx = ju.Arrays.binarySearch(splits, feature) + if (idx >= 0) { + idx + } else { + val insertPos = -idx - 1 + if (insertPos == 0 || insertPos == splits.length) { + throw new SparkException(s"Feature value $feature out of Bucketizer bounds" + + s" [${splits.head}, ${splits.last}]. Check your features, or loosen " + + s"the lower/upper bound constraints.") + } else { + insertPos - 1 + } + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala new file mode 100644 index 000000000000..49028e4b8506 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.collection.OpenHashMap + +/** + * Params for [[CountVectorizer]] and [[CountVectorizerModel]]. + */ +private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Max size of the vocabulary. + * CountVectorizer will build a vocabulary that only considers the top + * vocabSize terms ordered by term frequency across the corpus. + * + * Default: 2^18^ + * @group param + */ + val vocabSize: IntParam = + new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0)) + + /** @group getParam */ + def getVocabSize: Int = $(vocabSize) + + /** + * Specifies the minimum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer >= 1, this specifies the number of documents the term must appear in; + * if this is a double in [0,1), then this specifies the fraction of documents. + * + * Default: 1 + * @group param + */ + val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMinDF: Double = $(minDF) + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } + + /** + * Filter to ignore rare words in a document. For each document, terms with + * frequency/count less than the given threshold are ignored. + * If this is an integer >= 1, then this specifies a count (of times the term must appear + * in the document); + * if this is a double in [0,1), then this specifies a fraction (out of the document's token + * count). + * + * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not + * affect fitting. + * + * Default: 1 + * @group param + */ + val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given threshold are" + + " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" + + " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" + + " of the document's token count). Note that the parameter is only used in transform of" + + " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0)) + + setDefault(minTF -> 1) + + /** @group getParam */ + def getMinTF: Double = $(minTF) +} + +/** + * :: Experimental :: + * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]]. + */ +@Experimental +class CountVectorizer(override val uid: String) + extends Estimator[CountVectorizerModel] with CountVectorizerParams { + + def this() = this(Identifiable.randomUID("cntVec")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVocabSize(value: Int): this.type = set(vocabSize, value) + + /** @group setParam */ + def setMinDF(value: Double): this.type = set(minDF, value) + + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + setDefault(vocabSize -> (1 << 18), minDF -> 1) + + override def fit(dataset: DataFrame): CountVectorizerModel = { + transformSchema(dataset.schema, logging = true) + val vocSize = $(vocabSize) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val minDf = if ($(minDF) >= 1.0) { + $(minDF) + } else { + $(minDF) * input.cache().count() + } + val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val wc = new OpenHashMap[String, Long] + tokens.foreach { w => + wc.changeValue(w, 1L, _ + 1L) + } + wc.map { case (word, count) => (word, (count, 1)) } + }.reduceByKey { case ((wc1, df1), (wc2, df2)) => + (wc1 + wc2, df1 + df2) + }.filter { case (word, (wc, df)) => + df >= minDf + }.map { case (word, (count, dfCount)) => + (word, count) + }.cache() + val fullVocabSize = wordCounts.count() + val vocab: Array[String] = { + val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { + // Use all terms + wordCounts.collect().sortBy(-_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocSize) + } + tmpSortedWC.map(_._1) + } + + require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") + copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. + */ +@Experimental +class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) + extends Model[CountVectorizerModel] with CountVectorizerParams { + + def this(vocabulary: Array[String]) = { + this(Identifiable.randomUID("cntVecModel"), vocabulary) + set(vocabSize, vocabulary.length) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ + private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None + + override def transform(dataset: DataFrame): DataFrame = { + if (broadcastDict.isEmpty) { + val dict = vocabulary.zipWithIndex.toMap + broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) + } + val dictBr = broadcastDict.get + val minTf = $(minTF) + val vectorizer = udf { (document: Seq[String]) => + val termCounts = new OpenHashMap[Int, Double] + var tokenCount = 0L + document.foreach { term => + dictBr.value.get(term) match { + case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0) + case None => // ignore terms not in the vocabulary + } + tokenCount += 1 + } + val effectiveMinTF = if (minTf >= 1.0) { + minTf + } else { + tokenCount * minTf + } + Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq) + } + dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala new file mode 100644 index 000000000000..228347635c92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import edu.emory.mathcs.jtransforms.dct._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.BooleanParam +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero + * padding is performed on the input vector. + * It returns a real vector of the same length representing the DCT. The return vector is scaled + * such that the transform matrix is unitary (aka scaled DCT-II). + * + * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. + */ +@Experimental +class DCT(override val uid: String) + extends UnaryTransformer[Vector, Vector, DCT] { + + def this() = this(Identifiable.randomUID("dct")) + + /** + * Indicates whether to perform the inverse DCT (true) or forward DCT (false). + * Default: false + * @group param + */ + def inverse: BooleanParam = new BooleanParam( + this, "inverse", "Set transformer to perform inverse DCT") + + /** @group setParam */ + def setInverse(value: Boolean): this.type = set(inverse, value) + + /** @group getParam */ + def getInverse: Boolean = $(inverse) + + setDefault(inverse -> false) + + override protected def createTransformFunc: Vector => Vector = { vec => + val result = vec.toArray + val jTransformer = new DoubleDCT_1D(result.length) + if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true) + Vectors.dense(result) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala new file mode 100644 index 000000000000..a359cb8f37ec --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a + * provided "weight" vector. In other words, it scales each column of the dataset by a scalar + * multiplier. + */ +@Experimental +class ElementwiseProduct(override val uid: String) + extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { + + def this() = this(Identifiable.randomUID("elemProd")) + + /** + * the vector to multiply with input vectors + * @group param + */ + val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") + + /** @group setParam */ + def setScalingVec(value: Vector): this.type = set(scalingVec, value) + + /** @group getParam */ + def getScalingVec: Vector = getOrDefault(scalingVec) + + override protected def createTransformFunc: Vector => Vector = { + require(params.contains(scalingVec), s"transformation requires a weight vector") + val elemScaler = new feature.ElementwiseProduct($(scalingVec)) + elemScaler.transform + } + + override protected def outputDataType: DataType = new VectorUDT() +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0956062643f2..319d23e46cef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,29 +17,63 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vector} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. */ -@AlphaComponent -class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { +@Experimental +class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { - /** number of features */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) - def setNumFeatures(value: Int) = set(numFeatures, value) - def getNumFeatures: Int = get(numFeatures) + def this() = this(Identifiable.randomUID("hashingTF")) - override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { - val hashingTF = new feature.HashingTF(paramMap(numFeatures)) - hashingTF.transform + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Number of features. Should be > 0. + * (default = 2^18^) + * @group param + */ + val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)", + ParamValidators.gt(0)) + + setDefault(numFeatures -> (1 << 18)) + + /** @group getParam */ + def getNumFeatures: Int = $(numFeatures) + + /** @group setParam */ + def setNumFeatures(value: Int): this.type = set(numFeatures, value) + + override def transform(dataset: DataFrame): DataFrame = { + val outputSchema = transformSchema(dataset.schema) + val hashingTF = new feature.HashingTF($(numFeatures)) + val t = udf { terms: Seq[_] => hashingTF.transform(terms) } + val metadata = outputSchema($(outputCol)).metadata + dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[ArrayType], + s"The input column must be ArrayType, but got $inputType.") + val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) + SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } - override protected def outputDataType: DataType = new VectorUDT() + override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala new file mode 100644 index 000000000000..938447447a0a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * Params for [[IDF]] and [[IDFModel]]. + */ +private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol { + + /** + * The minimum of documents in which a term should appear. + * @group param + */ + final val minDocFreq = new IntParam( + this, "minDocFreq", "minimum of documents in which a term should appear for filtering") + + setDefault(minDocFreq -> 0) + + /** @group getParam */ + def getMinDocFreq: Int = $(minDocFreq) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } +} + +/** + * :: Experimental :: + * Compute the Inverse Document Frequency (IDF) given a collection of documents. + */ +@Experimental +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { + + def this() = this(Identifiable.randomUID("idf")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + + override def fit(dataset: DataFrame): IDFModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val idf = new feature.IDF($(minDocFreq)).fit(input) + copyValues(new IDFModel(uid, idf).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): IDF = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[IDF]]. + */ +@Experimental +class IDFModel private[ml] ( + override val uid: String, + idfModel: feature.IDFModel) + extends Model[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val idf = udf { vec: Vector => idfModel.transform(vec) } + dataset.withColumn($(outputCol), idf(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): IDFModel = { + val copied = new IDFModel(uid, idfModel) + copyValues(copied, extra).setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala new file mode 100644 index 000000000000..1b494ec8b172 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]]. + */ +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * lower bound after transformation, shared by all features + * Default: 0.0 + * @group param + */ + val min: DoubleParam = new DoubleParam(this, "min", + "lower bound of the output feature range") + + /** @group getParam */ + def getMin: Double = $(min) + + /** + * upper bound after transformation, shared by all features + * Default: 1.0 + * @group param + */ + val max: DoubleParam = new DoubleParam(this, "max", + "upper bound of the output feature range") + + /** @group getParam */ + def getMax: Double = $(max) + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def validateParams(): Unit = { + require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to a common range [min, max] linearly using column summary + * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + * feature E is calculated as, + * + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min + * + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min) + * Note that since zero values will probably be transformed to non-zero values, output of the + * transformer will be DenseVector even for sparse input. + */ +@Experimental +class MinMaxScaler(override val uid: String) + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + + def this() = this(Identifiable.randomUID("minMaxScal")) + + setDefault(min -> 0.0, max -> 1.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + override def fit(dataset: DataFrame): MinMaxScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[MinMaxScaler]]. + * + * @param originalMin min value for each original column during fitting + * @param originalMax max value for each original column during fitting + * + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). + */ +@Experimental +class MinMaxScalerModel private[ml] ( + override val uid: String, + val originalMin: Vector, + val originalMax: Vector) + extends Model[MinMaxScalerModel] with MinMaxScalerParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + override def transform(dataset: DataFrame): DataFrame = { + val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray + val minArray = originalMin.toArray + + val reScale = udf { (vector: Vector) => + val scale = $(max) - $(min) + + // 0 in sparse vector will probably be rescaled to non-zero + val values = vector.toArray + val size = values.size + var i = 0 + while (i < size) { + val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5 + values(i) = raw * scale + $(min) + i += 1 + } + Vectors.dense(values) + } + + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScalerModel = { + val copied = new MinMaxScalerModel(uid, originalMin, originalMax) + copyValues(copied, extra).setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala new file mode 100644 index 000000000000..8de10eb51f92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} + +/** + * :: Experimental :: + * A feature transformer that converts the input array of strings into an array of n-grams. Null + * values in the input array are ignored. + * It returns an array of n-grams where each n-gram is represented by a space-separated string of + * words. + * + * When the input is empty, an empty array is returned. + * When the input array length is less than n (number of elements per n-gram), no n-grams are + * returned. + */ +@Experimental +class NGram(override val uid: String) + extends UnaryTransformer[Seq[String], Seq[String], NGram] { + + def this() = this(Identifiable.randomUID("ngram")) + + /** + * Minimum n-gram length, >= 1. + * Default: 2, bigram features + * @group param + */ + val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)", + ParamValidators.gtEq(1)) + + /** @group setParam */ + def setN(value: Int): this.type = set(n, value) + + /** @group getParam */ + def getN: Int = $(n) + + setDefault(n -> 2) + + override protected def createTransformFunc: Seq[String] => Seq[String] = { + _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala new file mode 100644 index 000000000000..8282e5ffa17f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{DoubleParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * Normalize a vector to have unit norm using the given p-norm. + */ +@Experimental +class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { + + def this() = this(Identifiable.randomUID("normalizer")) + + /** + * Normalization in L^p^ space. Must be >= 1. + * (default: p = 2) + * @group param + */ + val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1)) + + setDefault(p -> 2.0) + + /** @group getParam */ + def getP: Double = $(p) + + /** @group setParam */ + def setP(value: Double): this.type = set(p, value) + + override protected def createTransformFunc: Vector => Vector = { + val normalizer = new feature.Normalizer($(p)) + normalizer.transform + } + + override protected def outputDataType: DataType = new VectorUDT() +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala new file mode 100644 index 000000000000..9c60d4084ec4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} + +/** + * :: Experimental :: + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see [[StringIndexer]] for converting categorical values into category indices + */ +@Experimental +class OneHotEncoder(override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("oneHot")) + + /** + * Whether to drop the last category in the encoded vector (default: true) + * @group param + */ + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) + + /** @group setParam */ + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transformSchema(schema: StructType): StructType = { + val inputColName = $(inputCol) + val outputColName = $(outputCol) + + SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + val inputFields = schema.fields + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + val inputAttr = Attribute.fromStructField(schema(inputColName)) + val outputAttrNames: Option[Array[String]] = inputAttr match { + case nominal: NominalAttribute => + if (nominal.values.isDefined) { + nominal.values + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values + } else { + Some(Array.tabulate(2)(_.toString)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column $inputColName cannot be numeric.") + case _ => + None // optimistic about unknown attributes + } + + val filteredOutputAttrNames = outputAttrNames.map { names => + if ($(dropLast)) { + require(names.length > 1, + s"The input column $inputColName should have at least two distinct values.") + names.dropRight(1) + } else { + names + } + } + + val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { + val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup($(outputCol), attrs) + } else { + new AttributeGroup($(outputCol)) + } + + val outputFields = inputFields :+ outputAttrGroup.toStructField() + StructType(outputFields) + } + + override def transform(dataset: DataFrame): DataFrame = { + // schema transformation + val inputColName = $(inputCol) + val outputColName = $(outputCol) + val shouldDropLast = $(dropLast) + var outputAttrGroup = AttributeGroup.fromStructField( + transformSchema(dataset.schema)(outputColName)) + if (outputAttrGroup.size < 0) { + // If the number of attributes is unknown, we check the values from the input column. + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + .aggregate(0.0)( + (m, x) => { + assert(x >=0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) + val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames + val outputAttrs: Array[Attribute] = + filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) + outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + } + val metadata = outputAttrGroup.toMetadata() + + // data transformation + val size = outputAttrGroup.size + val oneValue = Array(1.0) + val emptyValues = Array[Double]() + val emptyIndices = Array[Int]() + val encode = udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) + } else { + Vectors.sparse(size, emptyIndices, emptyValues) + } + } + + dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) + } + + override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala new file mode 100644 index 000000000000..539084704b65 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[PCA]] and [[PCAModel]]. + */ +private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol { + + /** + * The number of principal components. + * @group param + */ + final val k: IntParam = new IntParam(this, "k", "the number of principal components") + + /** @group getParam */ + def getK: Int = $(k) + +} + +/** + * :: Experimental :: + * PCA trains a model to project vectors to a low-dimensional space using PCA. + */ +@Experimental +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { + + def this() = this(Identifiable.randomUID("pca")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** + * Computes a [[PCAModel]] that contains the principal components of the input vectors. + */ + override def fit(dataset: DataFrame): PCAModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val pca = new feature.PCA(k = $(k)) + val pcaModel = pca.fit(input) + copyValues(new PCAModel(uid, pcaModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCA = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[PCA]]. + */ +@Experimental +class PCAModel private[ml] ( + override val uid: String, + pcaModel: feature.PCAModel) + extends Model[PCAModel] with PCAParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a vector by computed Principal Components. + * NOTE: Vectors to be transformed must be the same length + * as the source vectors given to [[PCA.fit()]]. + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val pcaOp = udf { pcaModel.transform _ } + dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCAModel = { + val copied = new PCAModel(uid, pcaModel) + copyValues(copied, extra).setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala new file mode 100644 index 000000000000..d85e468562d4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, + * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an + * expansion of a product of sums expresses it as a sum of products by using the fact that + * multiplication distributes over addition". Take a 2-variable feature vector as an example: + * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. + */ +@Experimental +class PolynomialExpansion(override val uid: String) + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + + def this() = this(Identifiable.randomUID("poly")) + + /** + * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion. + * Default: 2 + * @group param + */ + val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)", + ParamValidators.gt(1)) + + setDefault(degree -> 2) + + /** @group getParam */ + def getDegree: Int = $(degree) + + /** @group setParam */ + def setDegree(value: Int): this.type = set(degree, value) + + override protected def createTransformFunc: Vector => Vector = { v => + PolynomialExpansion.expand(v, $(degree)) + } + + override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) +} + +/** + * The expansion is done via recursion. Given n features and degree d, the size after expansion is + * (n + d choose d) (including 1 and first-order values). For example, let f([a, b, c], 3) be the + * function that expands [a, b, c] to their monomials of degree 3. We have the following recursion: + * + * {{{ + * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3] + * }}} + * + * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the + * current index and increment it properly for sparse input. + */ +private[feature] object PolynomialExpansion { + + private def choose(n: Int, k: Int): Int = { + Range(n, n - k, -1).product / Range(k, 1, -1).product + } + + private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) + + private def expandDense( + values: Array[Double], + lastIdx: Int, + degree: Int, + multiplier: Double, + polyValues: Array[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + if (curPolyIdx >= 0) { // skip the very first 1 + polyValues(curPolyIdx) = multiplier + } + } else { + val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + var alpha = multiplier + var i = 0 + var curStart = curPolyIdx + while (i <= degree && alpha != 0.0) { + curStart = expandDense(values, lastIdx1, degree - i, alpha, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastIdx + 1, degree) + } + + private def expandSparse( + indices: Array[Int], + values: Array[Double], + lastIdx: Int, + lastFeatureIdx: Int, + degree: Int, + multiplier: Double, + polyIndices: mutable.ArrayBuilder[Int], + polyValues: mutable.ArrayBuilder[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + if (curPolyIdx >= 0) { // skip the very first 1 + polyIndices += curPolyIdx + polyValues += multiplier + } + } else { + // Skip all zeros at the tail. + val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + val lastFeatureIdx1 = indices(lastIdx) - 1 + var alpha = multiplier + var curStart = curPolyIdx + var i = 0 + while (i <= degree && alpha != 0.0) { + curStart = expandSparse(indices, values, lastIdx1, lastFeatureIdx1, degree - i, alpha, + polyIndices, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastFeatureIdx + 1, degree) + } + + private def expand(dv: DenseVector, degree: Int): DenseVector = { + val n = dv.size + val polySize = getPolySize(n, degree) + val polyValues = new Array[Double](polySize - 1) + expandDense(dv.values, n - 1, degree, 1.0, polyValues, -1) + new DenseVector(polyValues) + } + + private def expand(sv: SparseVector, degree: Int): SparseVector = { + val polySize = getPolySize(sv.size, degree) + val nnz = sv.values.length + val nnzPolySize = getPolySize(nnz, degree) + val polyIndices = mutable.ArrayBuilder.make[Int] + polyIndices.sizeHint(nnzPolySize - 1) + val polyValues = mutable.ArrayBuilder.make[Double] + polyValues.sizeHint(nnzPolySize - 1) + expandSparse( + sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, -1) + new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) + } + + def expand(v: Vector, degree: Int): Vector = { + v match { + case dv: DenseVector => expand(dv, degree) + case sv: SparseVector => expand(sv, degree) + case _ => throw new IllegalArgumentException + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala new file mode 100644 index 000000000000..a7fa50444209 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types._ + +/** + * Base trait for [[RFormula]] and [[RFormulaModel]]. + */ +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { + + protected def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + +/** + * :: Experimental :: + * Implements the transforms required for fitting a dataset against an R model formula. Currently + * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the + * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +@Experimental +class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { + + def this() = this(Identifiable.randomUID("rFormula")) + + /** + * R formula parameter. The formula is provided in string form. + * @group param + */ + val formula: Param[String] = new Param(this, "formula", "R model formula") + + /** + * Sets the formula to use for this transformer. Must be called before use. + * @group setParam + * @param value an R formula in string form (e.g. "y ~ x + z") + */ + def setFormula(value: String): this.type = set(formula, value) + + /** @group getParam */ + def getFormula: String = $(formula) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** Whether the formula specifies fitting an intercept. */ + private[ml] def hasIntercept: Boolean = { + require(isDefined(formula), "Formula must be defined first.") + RFormulaParser.parse($(formula)).hasIntercept + } + + override def fit(dataset: DataFrame): RFormulaModel = { + require(isDefined(formula), "Formula must be defined first.") + val parsedFormula = RFormulaParser.parse($(formula)) + val resolvedFormula = parsedFormula.resolve(dataset.schema) + // StringType terms and terms representing interactions need to be encoded before assembly. + // TODO(ekl) add support for feature interactions + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) + val encodedTerms = resolvedFormula.terms.map { term => + dataset.schema(term) match { + case column if column.dataType == StringType => + val indexCol = term + "_idx_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) + encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) + encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) + tempColumns += indexCol + tempColumns += encodedCol + encodedCol + case _ => + term + } + } + encoderStages += new VectorAssembler(uid) + .setInputCols(encodedTerms.toArray) + .setOutputCol($(featuresCol)) + encoderStages += new ColumnPruner(tempColumns.toSet) + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) + copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) + } + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + if (hasLabelCol(schema)) { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+ + StructField($(labelCol), DoubleType, true)) + } + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" +} + +/** + * :: Experimental :: + * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. + * @param resolvedFormula the fitted R formula. + * @param pipelineModel the fitted feature model, including factor to index mappings. + */ +@Experimental +class RFormulaModel private[feature]( + override val uid: String, + resolvedFormula: ResolvedRFormula, + pipelineModel: PipelineModel) + extends Model[RFormulaModel] with RFormulaBase { + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(pipelineModel.transform(dataset)) + } + + override def transformSchema(schema: StructType): StructType = { + checkCanTransform(schema) + val withFeatures = pipelineModel.transformSchema(schema) + if (hasLabelCol(schema)) { + withFeatures + } else if (schema.exists(_.name == resolvedFormula.label)) { + val nullable = schema(resolvedFormula.label).dataType match { + case _: NumericType | BooleanType => false + case _ => true + } + StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + withFeatures + } + } + + override def copy(extra: ParamMap): RFormulaModel = copyValues( + new RFormulaModel(uid, resolvedFormula, pipelineModel)) + + override def toString: String = s"RFormulaModel(${resolvedFormula})" + + private def transformLabel(dataset: DataFrame): DataFrame = { + val labelName = resolvedFormula.label + if (hasLabelCol(dataset.schema)) { + dataset + } else if (dataset.schema.exists(_.name == labelName)) { + dataset.schema(labelName).dataType match { + case _: NumericType | BooleanType => + dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) + case other => + throw new IllegalArgumentException("Unsupported type for label: " + other) + } + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + dataset + } + } + + private def checkCanTransform(schema: StructType) { + val columnNames = schema.map(_.name) + require(!columnNames.contains($(featuresCol)), "Features column already exists.") + require( + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, + "Label column already exists and is not of type DoubleType.") + } +} + +/** + * Utility transformer for removing temporary columns from a DataFrame. + * TODO(ekl) make this a public transformer + */ +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { + override val uid = Identifiable.randomUID("columnPruner") + + override def transform(dataset: DataFrame): DataFrame = { + val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) + dataset.select(columnsToKeep.map(dataset.col) : _*) + } + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) + } + + override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala new file mode 100644 index 000000000000..1ca3b92a7d92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types._ + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { + /** + * Resolves formula terms into column names. A schema is necessary for inferring the meaning + * of the special '.' term. Duplicate terms will be removed during resolution. + */ + def resolve(schema: StructType): ResolvedRFormula = { + var includedTerms = Seq[String]() + terms.foreach { + case Dot => + includedTerms ++= simpleTypes(schema).filter(_ != label.value) + case ColumnRef(value) => + includedTerms :+= value + case Deletion(term: Term) => + term match { + case ColumnRef(value) => + includedTerms = includedTerms.filter(_ != value) + case Dot => + // e.g. "- .", which removes all first-order terms + val fromSchema = simpleTypes(schema) + includedTerms = includedTerms.filter(fromSchema.contains(_)) + case _: Deletion => + assert(false, "Deletion terms cannot be nested") + case _: Intercept => + } + case _: Intercept => + } + ResolvedRFormula(label.value, includedTerms.distinct) + } + + /** Whether this formula specifies fitting with an intercept term. */ + def hasIntercept: Boolean = { + var intercept = true + terms.foreach { + case Intercept(enabled) => + intercept = enabled + case Deletion(Intercept(enabled)) => + intercept = !enabled + case _ => + } + intercept + } + + // the dot operator excludes complex column types + private def simpleTypes(schema: StructType): Seq[String] = { + schema.fields.filter(_.dataType match { + case _: NumericType | StringType | BooleanType | _: VectorUDT => true + case _ => false + }).map(_.name) + } +} + +/** + * Represents a fully evaluated and simplified R formula. + */ +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) + +/** + * R formula terms. See the R formula docs here for more information: + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +private[ml] sealed trait Term + +/* R formula reference to all available columns, e.g. "." in a formula */ +private[ml] case object Dot extends Term + +/* R formula reference to a column, e.g. "+ Species" in a formula */ +private[ml] case class ColumnRef(value: String) extends Term + +/* R formula intercept toggle, e.g. "+ 0" in a formula */ +private[ml] case class Intercept(enabled: Boolean) extends Term + +/* R formula deletion of a variable, e.g. "- Species" in a formula */ +private[ml] case class Deletion(term: Term) extends Term + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. + */ +private[ml] object RFormulaParser extends RegexParsers { + def intercept: Parser[Intercept] = + "([01])".r ^^ { case a => Intercept(a == "1") } + + def columnRef: Parser[ColumnRef] = + "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } + + def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + + def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + case op ~ list => list.foldLeft(List(op)) { + case (left, "+" ~ right) => left ++ Seq(right) + case (left, "-" ~ right) => left ++ Seq(Deletion(right)) + } + } + + def formula: Parser[ParsedRFormula] = + (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 000000000000..95e430563873 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transforms which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * where '__THIS__' represents the underlying table of the input dataset. + */ +@Experimental +class SQLTransformer (override val uid: String) extends Transformer { + + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. + * @group param + */ + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 01a4f5eb205e..f6d0b0c0e9e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,84 +17,125 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. */ -private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Centers the data with mean before scaling. + * It will build a dense output, so this does not work on sparse input + * and will raise an exception. + * Default: false + * @group param + */ + val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + + /** + * Scales the data to unit standard deviation. + * Default: true + * @group param + */ + val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") +} /** - * :: AlphaComponent :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. */ -@AlphaComponent -class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { +@Experimental +class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] + with StandardScalerParams { + + def this() = this(Identifiable.randomUID("stdScal")) + setDefault(withMean -> false, withStd -> true) + + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val scaler = new feature.StandardScaler().fit(input) - val model = new StandardScalerModel(this, map, scaler) - Params.inheritValues(map, this, model) - model + /** @group setParam */ + def setWithMean(value: Boolean): this.type = set(withMean, value) + + /** @group setParam */ + def setWithStd(value: Boolean): this.type = set(withStd, value) + + override def fit(dataset: DataFrame): StandardScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) + val scalerModel = scaler.fit(input) + copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - val inputType = schema(map(inputCol)).dataType + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], - s"Input column ${map(inputCol)} must be a vector column") - require(!schema.fieldNames.contains(map(outputCol)), - s"Output column ${map(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StandardScaler]]. */ -@AlphaComponent +@Experimental class StandardScalerModel private[ml] ( - override val parent: StandardScaler, - override val fittingParamMap: ParamMap, + override val uid: String, scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { + /** Standard deviation of the StandardScalerModel */ + val std: Vector = scaler.std + + /** Mean of the StandardScalerModel */ + val mean: Vector = scaler.mean + + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val scale: (Vector) => Vector = (v) => { - scaler.transform(v) - } - dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol))) + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val scale = udf { scaler.transform _ } + dataset.withColumn($(outputCol), scale(col($(inputCol)))) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - val inputType = schema(map(inputCol)).dataType + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], - s"Input column ${map(inputCol)} must be a vector column") - require(!schema.fieldNames.contains(map(outputCol)), - s"Output column ${map(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScalerModel = { + val copied = new StandardScalerModel(uid, scaler) + copyValues(copied, extra).setParent(parent) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala new file mode 100644 index 000000000000..5d77ea08db65 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} + +/** + * stop words list + */ +private object StopWords { + + /** + * Use the same default stopwords list as scikit-learn. + * The original list can be found from "Glasgow Information Retrieval Group" + * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] + */ + val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again", + "against", "all", "almost", "alone", "along", "already", "also", "although", "always", + "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", + "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", + "around", "as", "at", "back", "be", "became", "because", "become", + "becomes", "becoming", "been", "before", "beforehand", "behind", "being", + "below", "beside", "besides", "between", "beyond", "bill", "both", + "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con", + "could", "couldnt", "cry", "de", "describe", "detail", "do", "done", + "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else", + "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone", + "everything", "everywhere", "except", "few", "fifteen", "fify", "fill", + "find", "fire", "first", "five", "for", "former", "formerly", "forty", + "found", "four", "from", "front", "full", "further", "get", "give", "go", + "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter", + "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", + "how", "however", "hundred", "i", "ie", "if", "in", "inc", "indeed", + "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter", + "latterly", "least", "less", "ltd", "made", "many", "may", "me", + "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly", + "move", "much", "must", "my", "myself", "name", "namely", "neither", + "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone", + "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on", + "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", + "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps", + "please", "put", "rather", "re", "same", "see", "seem", "seemed", + "seeming", "seems", "serious", "several", "she", "should", "show", "side", + "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone", + "something", "sometime", "sometimes", "somewhere", "still", "such", + "system", "take", "ten", "than", "that", "the", "their", "them", + "themselves", "then", "thence", "there", "thereafter", "thereby", + "therefore", "therein", "thereupon", "these", "they", "thick", "thin", + "third", "this", "those", "though", "three", "through", "throughout", + "thru", "thus", "to", "together", "too", "top", "toward", "towards", + "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us", + "very", "via", "was", "we", "well", "were", "what", "whatever", "when", + "whence", "whenever", "where", "whereafter", "whereas", "whereby", + "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", + "who", "whoever", "whole", "whom", "whose", "why", "will", "with", + "within", "without", "would", "yet", "you", "your", "yours", "yourself", "yourselves") +} + +/** + * :: Experimental :: + * A feature transformer that filters out stop words from input. + * Note: null values from input array are preserved unless adding null to stopWords explicitly. + * @see [[http://en.wikipedia.org/wiki/Stop_words]] + */ +@Experimental +class StopWordsRemover(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("stopWords")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * the stop words set to be filtered out + * @group param + */ + val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") + + /** @group setParam */ + def setStopWords(value: Array[String]): this.type = set(stopWords, value) + + /** @group getParam */ + def getStopWords: Array[String] = $(stopWords) + + /** + * whether to do a case sensitive comparison over the stop words + * @group param + */ + val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive", + "whether to do case-sensitive comparison during filtering") + + /** @group setParam */ + def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value) + + /** @group getParam */ + def getCaseSensitive: Boolean = $(caseSensitive) + + setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) + + override def transform(dataset: DataFrame): DataFrame = { + val outputSchema = transformSchema(dataset.schema) + val t = if ($(caseSensitive)) { + val stopWordsSet = $(stopWords).toSet + udf { terms: Seq[String] => + terms.filter(s => !stopWordsSet.contains(s)) + } + } else { + val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lowerStopWords = $(stopWords).map(toLower(_)).toSet + udf { terms: Seq[String] => + terms.filter(s => !lowerStopWords.contains(toLower(s))) + } + } + + val metadata = outputSchema($(outputCol)).metadata + dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + val outputFields = schema.fields :+ + StructField($(outputCol), inputType, schema($(inputCol)).nullable) + StructType(outputFields) + } + + override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala new file mode 100644 index 000000000000..24250e4c4cf9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} +import org.apache.spark.util.collection.OpenHashMap + +/** + * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. + */ +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputColName = $(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") + val inputFields = schema.fields + val outputColName = $(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName($(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} + +/** + * :: Experimental :: + * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation + */ +@Experimental +class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] + with StringIndexerBase { + + def this() = this(Identifiable.randomUID("strIdx")) + + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + + override def fit(dataset: DataFrame): StringIndexerModel = { + val counts = dataset.select(col($(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() + val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + copyValues(new StringIndexerModel(uid, labels).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[StringIndexer]]. + * + * NOTE: During transformation, if the input column does not exist, + * [[StringIndexerModel.transform]] would return the input dataset unmodified. + * This is a temporary fix for the case when target labels do not exist during prediction. + * + * @param labels Ordered list of labels, corresponding to indices to be assigned + */ +@Experimental +class StringIndexerModel ( + override val uid: String, + val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) + + private val labelToIndex: OpenHashMap[String, Double] = { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map + } + + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + if (!dataset.schema.fieldNames.contains($(inputCol))) { + logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + + "Skip StringIndexerModel.") + return dataset + } + + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else { + throw new SparkException(s"Unseen label: $label.") + } + } + + val outputColName = $(outputCol) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(labels).toMetadata() + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), + indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + if (schema.fieldNames.contains($(inputCol))) { + validateAndTransformSchema(schema) + } else { + // If the input column does not exist during transformation, we skip StringIndexerModel. + schema + } + } + + override def copy(extra: ParamMap): StringIndexerModel = { + val copied = new StringIndexerModel(uid, labels) + copyValues(copied, extra).setParent(parent) + } +} + +/** + * :: Experimental :: + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. + * + * @see [[StringIndexer]] for converting strings into indices + */ +@Experimental +class IndexToString private[ml] ( + override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { + + def this() = + this(Identifiable.randomUID("idxToStr")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Optional labels to be provided by the user, if not supplied column + * metadata is read for labels. The default value is an empty array, + * but the empty array is ignored and column metadata used instead. + * @group setParam + */ + def setLabels(value: Array[String]): this.type = set(labels, value) + + /** + * Param for array of labels. + * Optional labels to be provided by the user, if not supplied column + * metadata is read for labels. + * @group param + */ + final val labels: StringArrayParam = new StringArrayParam(this, "labels", + "array of labels, if not provided metadata from inputCol is used instead.") + setDefault(labels, Array.empty[String]) + + /** + * Optional labels to be provided by the user, if not supplied column + * metadata is read for labels. + * @group getParam + */ + final def getLabels: Array[String] = $(labels) + + /** Transform the schema for the inverse transformation */ + override def transformSchema(schema: StructType): StructType = { + val inputColName = $(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be a numeric type, " + + s"but got $inputDataType.") + val inputFields = schema.fields + val outputColName = $(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName($(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } + + override def transform(dataset: DataFrame): DataFrame = { + val inputColSchema = dataset.schema($(inputCol)) + // If the labels array is empty use column metadata + val values = if ($(labels).isEmpty) { + Attribute.fromStructField(inputColSchema) + .asInstanceOf[NominalAttribute].values.get + } else { + $(labels) + } + val indexer = udf { index: Double => + val idx = index.toInt + if (0 <= idx && idx < values.length) { + values(idx) + } else { + throw new SparkException(s"Unseen index: $index ??") + } + } + val outputColName = $(outputCol) + dataset.select(col("*"), + indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) + } + + override def copy(extra: ParamMap): IndexToString = { + defaultCopy(extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index e622a5cf9e6f..248288ca73e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,25 +17,103 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.types.{DataType, StringType, ArrayType} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: AlphaComponent :: + * :: Experimental :: * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + * + * @see [[RegexTokenizer]] */ -@AlphaComponent -class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { +@Experimental +class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { + + def this() = this(Identifiable.randomUID("tok")) - protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + override protected def createTransformFunc: String => Seq[String] = { _.toLowerCase.split("\\s") } - protected override def validateInputType(inputType: DataType): Unit = { + override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } - override protected def outputDataType: DataType = new ArrayType(StringType, false) + override protected def outputDataType: DataType = new ArrayType(StringType, true) + + override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) +} + +/** + * :: Experimental :: + * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split + * the text (default) or repeatedly matching the regex (if `gaps` is false). + * Optional parameters also allow filtering tokens using a minimal length. + * It returns an array of strings that can be empty. + */ +@Experimental +class RegexTokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + + def this() = this(Identifiable.randomUID("regexTok")) + + /** + * Minimum token length, >= 0. + * Default: 1, to avoid returning empty strings + * @group param + */ + val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)", + ParamValidators.gtEq(0)) + + /** @group setParam */ + def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) + + /** @group getParam */ + def getMinTokenLength: Int = $(minTokenLength) + + /** + * Indicates whether regex splits on gaps (true) or matches tokens (false). + * Default: true + * @group param + */ + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") + + /** @group setParam */ + def setGaps(value: Boolean): this.type = set(gaps, value) + + /** @group getParam */ + def getGaps: Boolean = $(gaps) + + /** + * Regex pattern used to match delimiters if [[gaps]] is true or tokens if [[gaps]] is false. + * Default: `"\\s+"` + * @group param + */ + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") + + /** @group setParam */ + def setPattern(value: String): this.type = set(pattern, value) + + /** @group getParam */ + def getPattern: String = $(pattern) + + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") + + override protected def createTransformFunc: String => Seq[String] = { str => + val re = $(pattern).r + val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq + val minLength = $(minTokenLength) + tokens.filter(_.length >= minLength) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, true) + + override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala new file mode 100644 index 000000000000..086917fa680f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * A feature transformer that merges multiple columns into a vector column. + */ +@Experimental +class VectorAssembler(override val uid: String) + extends Transformer with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("vecAssembler")) + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + // Schema transformation. + val schema = dataset.schema + lazy val first = dataset.first() + val attrs = $(inputCols).flatMap { c => + val field = schema(c) + val index = schema.fieldIndex(c) + field.dataType match { + case DoubleType => + val attr = Attribute.fromStructField(field) + // If the input column doesn't have ML attribute, assume numeric. + if (attr == UnresolvedAttribute) { + Some(NumericAttribute.defaultAttr.withName(c)) + } else { + Some(attr.withName(c)) + } + case _: NumericType | BooleanType => + // If the input column type is a compatible scalar type, assume numeric. + Some(NumericAttribute.defaultAttr.withName(c)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + if (group.attributes.isDefined) { + // If attributes are defined, copy them with updated names. + group.attributes.get.map { attr => + if (attr.name.isDefined) { + // TODO: Define a rigorous naming scheme. + attr.withName(c + "_" + attr.name.get) + } else { + attr + } + } + } else { + // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes + // from metadata, check the first row. + val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) + Array.fill(numAttrs)(NumericAttribute.defaultAttr) + } + } + } + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + + // Data transformation. + val assembleFunc = udf { r: Row => + VectorAssembler.assemble(r.toSeq: _*) + } + val args = $(inputCols).map { c => + schema(c).dataType match { + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") + } + } + + dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputColNames = $(inputCols) + val outputColName = $(outputCol) + val inputDataTypes = inputColNames.map(name => schema(name).dataType) + inputDataTypes.foreach { + case _: NumericType | BooleanType => + case t if t.isInstanceOf[VectorUDT] => + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true)) + } + + override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) +} + +private object VectorAssembler { + + private[feature] def assemble(vv: Any*): Vector = { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + var cur = 0 + vv.foreach { + case v: Double => + if (v != 0.0) { + indices += cur + values += v + } + cur += 1 + case vec: Vector => + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += cur + i + values += v + } + } + cur += vec.size + case null => + // TODO: output Double.NaN? + throw new SparkException("Values to assemble cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + Vectors.sparse(cur, indices.result(), values.result()).compressed + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala new file mode 100644 index 000000000000..61b925c0fdc0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import java.lang.{Double => JDouble, Integer => JInt} +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.collection.OpenHashSet + +/** Private trait for params for VectorIndexer and VectorIndexerModel */ +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Threshold for the number of values a categorical feature can take. + * If a feature is found to have > maxCategories values, then it is declared continuous. + * Must be >= 2. + * + * (default = 20) + */ + val maxCategories = new IntParam(this, "maxCategories", + "Threshold for the number of values a categorical feature can take (>= 2)." + + " If a feature is found to have > maxCategories values, then it is declared continuous.", + ParamValidators.gtEq(2)) + + setDefault(maxCategories -> 20) + + /** @group getParam */ + def getMaxCategories: Int = $(maxCategories) +} + +/** + * :: Experimental :: + * Class for indexing categorical feature columns in a dataset of [[Vector]]. + * + * This has 2 usage modes: + * - Automatically identify categorical features (default behavior) + * - This helps process a dataset of unknown vectors into a dataset with some continuous + * features and some categorical features. The choice between continuous and categorical + * is based upon a maxCategories parameter. + * - Set maxCategories to the maximum number of categorical any categorical feature should have. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + * and feature 1 will be declared continuous. + * - Index all features, if all features are categorical + * - If maxCategories is set to be very large, then this will build an index of unique + * values for all features. + * - Warning: This can cause problems if features are continuous since this will collect ALL + * unique values to the driver. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories >= 3, then both features will be declared categorical. + * + * This returns a model which can transform categorical features to use 0-based indices. + * + * Index stability: + * - This is not guaranteed to choose the same category index across multiple runs. + * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. + * This maintains vector sparsity. + * - More stability may be added in the future. + * + * TODO: Future extensions: The following functionality is planned for the future: + * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. + * - Specify certain features to not index, either via a parameter or via existing metadata. + * - Add warning if a categorical feature has only 1 category. + * - Add option for allowing unknown categories. + */ +@Experimental +class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] + with VectorIndexerParams { + + def this() = this(Identifiable.randomUID("vecIdx")) + + /** @group setParam */ + def setMaxCategories(value: Int): this.type = set(maxCategories, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame): VectorIndexerModel = { + transformSchema(dataset.schema, logging = true) + val firstRow = dataset.select($(inputCol)).take(1) + require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") + val numFeatures = firstRow(0).getAs[Vector](0).size + val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val maxCats = $(maxCategories) + val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => + val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) + iter.foreach(localCatStats.addVector) + Iterator(localCatStats) + }.reduce((stats1, stats2) => stats1.merge(stats2)) + val model = new VectorIndexerModel(uid, numFeatures, categoryStats.getCategoryMaps) + .setParent(this) + copyValues(model) + } + + override def transformSchema(schema: StructType): StructType = { + // We do not transfer feature metadata since we do not know what types of features we will + // produce in transform(). + val dataType = new VectorUDT + require(isDefined(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(isDefined(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, $(inputCol), dataType) + SchemaUtils.appendColumn(schema, $(outputCol), dataType) + } + + override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) +} + +private object VectorIndexer { + + /** + * Helper class for tracking unique values for each feature. + * + * TODO: Track which features are known to be continuous already; do not update counts for them. + * + * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. + * @param maxCategories This class caps the number of unique values collected at maxCategories. + */ + class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + extends Serializable { + + /** featureValueSets[feature index] = set of unique values */ + private val featureValueSets = + Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]()) + + /** Merge with another instance, modifying this instance. */ + def merge(other: CategoryStats): CategoryStats = { + featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) => + otherValSet.iterator.foreach { x => + // Once we have found > maxCategories values, we know the feature is continuous + // and do not need to collect more values for it. + if (thisValSet.size <= maxCategories) thisValSet.add(x) + } + } + this + } + + /** Add a new vector to this index, updating sets of unique feature values */ + def addVector(v: Vector): Unit = { + require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" + + s" found vector of size ${v.size}.") + v match { + case dv: DenseVector => addDenseVector(dv) + case sv: SparseVector => addSparseVector(sv) + } + } + + /** + * Based on stats collected, decide which features are categorical, + * and choose indices for categories. + * + * Sparsity: This tries to maintain sparsity by treating value 0.0 specially. + * If a categorical feature takes value 0.0, then value 0.0 is given index 0. + * + * @return Feature value index. Keys are categorical feature indices (column indices). + * Values are mappings from original features values to 0-based category indices. + */ + def getCategoryMaps: Map[Int, Map[Double, Int]] = { + // Filter out features which are declared continuous. + featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map { + case (featureValues: OpenHashSet[Double], featureIndex: Int) => + var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted + val zeroExists = sortedFeatureValues.length + 1 == featureValues.size + if (zeroExists) { + sortedFeatureValues = 0.0 +: sortedFeatureValues + } + val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap + (featureIndex, categoryMap) + }.toMap + } + + private def addDenseVector(dv: DenseVector): Unit = { + var i = 0 + val size = dv.size + while (i < size) { + if (featureValueSets(i).size <= maxCategories) { + featureValueSets(i).add(dv(i)) + } + i += 1 + } + } + + private def addSparseVector(sv: SparseVector): Unit = { + // TODO: This might be able to handle 0's more efficiently. + var vecIndex = 0 // index into vector + var k = 0 // index into non-zero elements + val size = sv.size + while (vecIndex < size) { + val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { + k += 1 + sv.values(k - 1) + } else { + 0.0 + } + if (featureValueSets(vecIndex).size <= maxCategories) { + featureValueSets(vecIndex).add(featureValue) + } + vecIndex += 1 + } + } + } +} + +/** + * :: Experimental :: + * Transform categorical features to use 0-based indices instead of their original values. + * - Categorical features are mapped to indices. + * - Continuous features (columns) are left unchanged. + * This also appends metadata to the output column, marking features as Numeric (continuous), + * Nominal (categorical), or Binary (either continuous or categorical). + * Non-ML metadata is not carried over from the input to the output column. + * + * This maintains vector sparsity. + * + * @param numFeatures Number of features, i.e., length of Vectors which this transforms + * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices). + * Values are maps from original features values to 0-based category indices. + * If a feature is not in this map, it is treated as continuous. + */ +@Experimental +class VectorIndexerModel private[ml] ( + override val uid: String, + val numFeatures: Int, + val categoryMaps: Map[Int, Map[Double, Int]]) + extends Model[VectorIndexerModel] with VectorIndexerParams { + + /** Java-friendly version of [[categoryMaps]] */ + def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { + categoryMaps.mapValues(_.asJava).asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]] + } + + /** + * Pre-computed feature attributes, with some missing info. + * In transform(), set attribute name and other info, if available. + */ + private val partialFeatureAttributes: Array[Attribute] = { + val attrs = new Array[Attribute](numFeatures) + var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (categoryMaps.contains(featureIndex)) { + // categorical feature + val featureValues: Array[String] = + categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) + if (featureValues.length == 2) { + attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), + values = Some(featureValues)) + } else { + attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex), + isOrdinal = Some(false), values = Some(featureValues)) + } + categoricalFeatureCount += 1 + } else { + // continuous feature + attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex)) + } + featureIndex += 1 + } + require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" + + s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures") + attrs + } + + // TODO: Check more carefully about whether this whole class will be included in a closure. + + /** Per-vector transform function */ + private val transformFunc: Vector => Vector = { + val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted + val localVectorMap = categoryMaps + val localNumFeatures = numFeatures + val f: Vector => Vector = { (v: Vector) => + assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length ${v.size}") + v match { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCatFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCatFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } + } + f + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val newField = prepOutputField(dataset.schema) + val transformUDF = udf { (vector: Vector) => transformFunc(vector) } + val newCol = transformUDF(dataset($(inputCol))) + dataset.withColumn($(outputCol), newCol, newField.metadata) + } + + override def transformSchema(schema: StructType): StructType = { + val dataType = new VectorUDT + require(isDefined(inputCol), + s"VectorIndexerModel requires input column parameter: $inputCol") + require(isDefined(outputCol), + s"VectorIndexerModel requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, $(inputCol), dataType) + + // If the input metadata specifies numFeatures, compare with expected numFeatures. + val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) + val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { + Some(origAttrGroup.attributes.get.length) + } else { + origAttrGroup.numAttributes + } + require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + + s" $numFeatures features, but input column ${$(inputCol)} had metadata specifying" + + s" ${origAttrGroup.numAttributes.get} features.") + + val newField = prepOutputField(schema) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + * @param schema Input schema + * @return Output column field. This field does not contain non-ML metadata. + */ + private def prepOutputField(schema: StructType): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + // Convert original attributes to modified attributes + val origAttrs: Array[Attribute] = origAttrGroup.attributes.get + origAttrs.zip(partialFeatureAttributes).map { + case (origAttr: Attribute, featAttr: BinaryAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NominalAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NumericAttribute) => + origAttr.withIndex(featAttr.index.get) + case (origAttr: Attribute, _) => + origAttr + } + } else { + partialFeatureAttributes + } + val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) + newAttributeGroup.toStructField() + } + + override def copy(extra: ParamMap): VectorIndexerModel = { + val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) + copyValues(copied, extra).setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala new file mode 100644 index 000000000000..c5c227227079 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * This class takes a feature vector and outputs a new feature vector with a subarray of the + * original features. + * + * The subset of features can be specified with either indices ([[setIndices()]]) + * or names ([[setNames()]]). At least one feature must be selected. Duplicate features + * are not allowed, so there can be no overlap between selected indices and names. + * + * The output vector will order features with the selected indices first (in the order given), + * followed by the selected names (in the order given). + */ +@Experimental +final class VectorSlicer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("vectorSlicer")) + + /** + * An array of indices to select features from a vector column. + * There can be no overlap with [[names]]. + * @group param + */ + val indices = new IntArrayParam(this, "indices", + "An array of indices to select features from a vector column." + + " There can be no overlap with names.", VectorSlicer.validIndices) + + setDefault(indices -> Array.empty[Int]) + + /** @group getParam */ + def getIndices: Array[Int] = $(indices) + + /** @group setParam */ + def setIndices(value: Array[Int]): this.type = set(indices, value) + + /** + * An array of feature names to select features from a vector column. + * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s. + * There can be no overlap with [[indices]]. + * @group param + */ + val names = new StringArrayParam(this, "names", + "An array of feature names to select features from a vector column." + + " There can be no overlap with indices.", VectorSlicer.validNames) + + setDefault(names -> Array.empty[String]) + + /** @group getParam */ + def getNames: Array[String] = $(names) + + /** @group setParam */ + def setNames(value: Array[String]): this.type = set(names, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def validateParams(): Unit = { + require($(indices).length > 0 || $(names).length > 0, + s"VectorSlicer requires that at least one feature be selected.") + } + + override def transform(dataset: DataFrame): DataFrame = { + // Validity checks + transformSchema(dataset.schema) + val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) + inputAttr.numAttributes.foreach { numFeatures => + val maxIndex = $(indices).max + require(maxIndex < numFeatures, + s"Selected feature index $maxIndex invalid for only $numFeatures input features.") + } + + // Prepare output attributes + val inds = getSelectedFeatureIndices(dataset.schema) + val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs => + inds.map(index => attrs(index)) + } + val outputAttr = selectedAttrs match { + case Some(attrs) => new AttributeGroup($(outputCol), attrs) + case None => new AttributeGroup($(outputCol), inds.length) + } + + // Select features + val slicer = udf { vec: Vector => + vec match { + case features: DenseVector => Vectors.dense(inds.map(features.apply)) + case features: SparseVector => features.slice(inds) + } + } + dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata()) + } + + /** Get the feature indices in order: indices, names */ + private def getSelectedFeatureIndices(schema: StructType): Array[Int] = { + val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names)) + val indFeatures = $(indices) + val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length + lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" + + s" sets of features, but they overlap." + + s" indices: ${indFeatures.mkString("[", ",", "]")}." + + s" names: " + + nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]") + require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg) + indFeatures ++ nameFeatures + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") + } + val numFeaturesSelected = $(indices).length + $(names).length + val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected) + val outputFields = schema.fields :+ outputAttr.toStructField() + StructType(outputFields) + } + + override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) +} + +private[feature] object VectorSlicer { + + /** Return true if given feature indices are valid */ + def validIndices(indices: Array[Int]): Boolean = { + if (indices.isEmpty) { + true + } else { + indices.length == indices.distinct.length && indices.forall(_ >= 0) + } + } + + /** Return true if given feature names are valid */ + def validNames(names: Array[String]): Boolean = { + names.forall(_.nonEmpty) && names.length == names.distinct.length + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala new file mode 100644 index 000000000000..5af775a4159a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types._ + +/** + * Params for [[Word2Vec]] and [[Word2VecModel]]. + */ +private[feature] trait Word2VecBase extends Params + with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { + + /** + * The dimension of the code that you want to transform from words. + * @group param + */ + final val vectorSize = new IntParam( + this, "vectorSize", "the dimension of codes after transforming from words") + setDefault(vectorSize -> 100) + + /** @group getParam */ + def getVectorSize: Int = $(vectorSize) + + /** + * Number of partitions for sentences of words. + * @group param + */ + final val numPartitions = new IntParam( + this, "numPartitions", "number of partitions for sentences of words") + setDefault(numPartitions -> 1) + + /** @group getParam */ + def getNumPartitions: Int = $(numPartitions) + + /** + * The minimum number of times a token must appear to be included in the word2vec model's + * vocabulary. + * @group param + */ + final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + + "appear to be included in the word2vec model's vocabulary") + setDefault(minCount -> 5) + + /** @group getParam */ + def getMinCount: Int = $(minCount) + + setDefault(stepSize -> 0.025) + setDefault(maxIter -> 1) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } +} + +/** + * :: Experimental :: + * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further + * natural language processing or machine learning process. + */ +@Experimental +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { + + def this() = this(Identifiable.randomUID("w2v")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVectorSize(value: Int): this.type = set(vectorSize, value) + + /** @group setParam */ + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group setParam */ + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + def setMinCount(value: Int): this.type = set(minCount, value) + + override def fit(dataset: DataFrame): Word2VecModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val wordVectors = new feature.Word2Vec() + .setLearningRate($(stepSize)) + .setMinCount($(minCount)) + .setNumIterations($(maxIter)) + .setNumPartitions($(numPartitions)) + .setSeed($(seed)) + .setVectorSize($(vectorSize)) + .fit(input) + copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[Word2Vec]]. + */ +@Experimental +class Word2VecModel private[ml] ( + override val uid: String, + wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase { + + + /** + * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and + * and the vector the DenseVector that it is mapped to. + */ + @transient lazy val getVectors: DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) + sc.parallelize(wordVec.toSeq).toDF("word", "vector") + } + + /** + * Find "num" number of words closest in similarity to the given word. + * Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word. + */ + def findSynonyms(word: String, num: Int): DataFrame = { + findSynonyms(wordVectors.transform(word), num) + } + + /** + * Find "num" number of words closest to similarity to the given vector representation + * of the word. Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word vector. + */ + def findSynonyms(word: Vector, num: Int): DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a sentence column to a vector column to represent the whole sentence. The transform + * is performed by averaging all word vectors it contains. + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val word2Vec = udf { sentence: Seq[String] => + if (sentence.size == 0) { + Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) + } else { + val cum = Vectors.zeros($(vectorSize)) + val model = bWordVectors.value.getVectors + for (word <- sentence) { + if (model.contains(word)) { + axpy(1.0, bWordVectors.value.transform(word), cum) + } else { + // pass words which not belong to model + } + } + scal(1.0 / sentence.size, cum) + cum + } + } + dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): Word2VecModel = { + val copied = new Word2VecModel(uid, wordVectors) + copyValues(copied, extra).setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala new file mode 100644 index 000000000000..4571ab26800c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, VectorAssembler} +import org.apache.spark.sql.DataFrame + +/** + * == Feature transformers == + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as [[Transformer]]s, which transform one [[DataFrame]] + * into another, e.g., [[HashingTF]]. + * Some feature transformers are implemented as [[Estimator]]s, because the transformation requires + * some aggregated information of the dataset, e.g., document frequencies in [[IDF]]. + * For those feature transformers, calling [[Estimator!.fit]] is required to obtain the model first, + * e.g., [[IDFModel]], in order to apply transformation. + * The transformation is usually done by appending new columns to the input [[DataFrame]], so all + * input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * [[Pipeline]] can be used to chain feature transformers, and [[VectorAssembler]] can be used to + * combine multiple feature transformations, for example: + * + * {{{ + * import org.apache.spark.ml.feature._ + * import org.apache.spark.ml.Pipeline + * + * // a DataFrame with three columns: id (integer), text (string), and rating (double). + * val df = sqlContext.createDataFrame(Seq( + * (0, "Hi I heard about Spark", 3.0), + * (1, "I wish Java could use case classes", 4.0), + * (2, "Logistic regression models are neat", 4.0) + * )).toDF("id", "text", "rating") + * + * // define feature transformers + * val tok = new RegexTokenizer() + * .setInputCol("text") + * .setOutputCol("words") + * val sw = new StopWordsRemover() + * .setInputCol("words") + * .setOutputCol("filtered_words") + * val tf = new HashingTF() + * .setInputCol("filtered_words") + * .setOutputCol("tf") + * .setNumFeatures(10000) + * val idf = new IDF() + * .setInputCol("tf") + * .setOutputCol("tf_idf") + * val assembler = new VectorAssembler() + * .setInputCols(Array("tf_idf", "rating")) + * .setOutputCol("features") + * + * // assemble and fit the feature transformation pipeline + * val pipeline = new Pipeline() + * .setStages(Array(tok, sw, tf, idf, assembler)) + * val model = pipeline.fit(df) + * + * // save transformed features with raw data + * model.transform(df) + * .select("id", "text", "rating", "features") + * .write.format("parquet").save("/output/path") + * }}} + * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see [[http://scikit-learn.org/stable/modules/preprocessing.html scikit-learn.preprocessing]] + */ +package object feature diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java index 00d9c802e930..87f4223964ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -16,10 +16,10 @@ */ /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. */ -@AlphaComponent +@Experimental package org.apache.spark.ml; -import org.apache.spark.annotation.AlphaComponent; +import org.apache.spark.annotation.Experimental; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index 51cd48c90432..c589d06d9f7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -18,7 +18,33 @@ package org.apache.spark /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. + * + * @groupname param Parameters + * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get + * the parameter values through setters and getters, respectively. + * @groupprio param -5 + * + * @groupname setParam Parameter setters + * @groupprio setParam 5 + * + * @groupname getParam Parameter getters + * @groupprio getParam 6 + * + * @groupname expertParam (expert-only) Parameters + * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can + * take. Users can set and get the parameter values through setters and getters, + * respectively. + * @groupprio expertParam 7 + * + * @groupname expertSetParam (expert-only) Parameter setters + * @groupprio expertSetParam 8 + * + * @groupname expertGetParam (expert-only) Parameter getters + * @groupprio expertGetParam 9 + * + * @groupname Ungrouped Members + * @groupprio Ungrouped 0 */ package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5fb4379e23c2..91c0a5631319 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,103 +17,332 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable +import scala.collection.JavaConverters._ -import java.lang.reflect.Modifier - -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.Identifiable +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.ml.util.Identifiable /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A param with self-contained documentation and optionally default value. Primitive-typed param * should use the specialized versions, which are more friendly to Java users. * * @param parent parent object * @param name param name * @param doc documentation + * @param isValid optional validation method which indicates if a value is valid. + * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ -@AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) +@DeveloperApi +class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) extends Serializable { + def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: String, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue[T]) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** - * Creates a param pair with the given value (for Java). + * Assert that the given value is valid for this parameter. + * + * Note: Parameter checks involving interactions between multiple parameters should be + * implemented in [[Params.validateParams()]]. Checks for input/output columns should be + * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]]. + * + * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters + * should be specified via [[ParamPair]]. + * + * @throws IllegalArgumentException if the value is invalid */ + private[param] def validate(value: T): Unit = { + if (!isValid(value)) { + throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + } + } + + /** Creates a param pair with the given value (for Java). */ def w(value: T): ParamPair[T] = this -> value - /** - * Creates a param pair with the given value (for Scala). - */ + /** Creates a param pair with the given value (for Scala). */ def ->(value: T): ParamPair[T] = ParamPair(this, value) - override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" - } else { - s"$name: $doc" + override final def toString: String = s"${parent}__$name" + + override final def hashCode: Int = toString.## + + override final def equals(obj: Any): Boolean = { + obj match { + case p: Param[_] => (p.parent == parent) && (p.name == name) + case _ => false } } } +/** + * :: DeveloperApi :: + * Factory methods for common validation functions for [[Param.isValid]]. + * The numerical methods only support Int, Long, Float, and Double. + */ +@DeveloperApi +object ParamValidators { + + /** (private[param]) Default validation always return true */ + private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true + + /** + * Private method for checking numerical types and converting to Double. + * This is mainly for the sake of compilation; type checks are really handled + * by [[Params]] setters and the [[ParamPair]] constructor. + */ + private def getDouble[T](value: T): Double = value match { + case x: Int => x.toDouble + case x: Long => x.toDouble + case x: Float => x.toDouble + case x: Double => x.toDouble + case _ => + // The type should be checked before this is ever called. + throw new IllegalArgumentException("Numerical Param validation failed because" + + s" of unexpected input type: ${value.getClass}") + } + + /** Check if value > lowerBound */ + def gt[T](lowerBound: Double): T => Boolean = { (value: T) => + getDouble(value) > lowerBound + } + + /** Check if value >= lowerBound */ + def gtEq[T](lowerBound: Double): T => Boolean = { (value: T) => + getDouble(value) >= lowerBound + } + + /** Check if value < upperBound */ + def lt[T](upperBound: Double): T => Boolean = { (value: T) => + getDouble(value) < upperBound + } + + /** Check if value <= upperBound */ + def ltEq[T](upperBound: Double): T => Boolean = { (value: T) => + getDouble(value) <= upperBound + } + + /** + * Check for value in range lowerBound to upperBound. + * @param lowerInclusive If true, check for value >= lowerBound. + * If false, check for value > lowerBound. + * @param upperInclusive If true, check for value <= upperBound. + * If false, check for value < upperBound. + */ + def inRange[T]( + lowerBound: Double, + upperBound: Double, + lowerInclusive: Boolean, + upperInclusive: Boolean): T => Boolean = { (value: T) => + val x: Double = getDouble(value) + val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound + val upperValid = if (upperInclusive) x <= upperBound else x < upperBound + lowerValid && upperValid + } + + /** Version of [[inRange()]] which uses inclusive be default: [lowerBound, upperBound] */ + def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = { + inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true) + } + + /** Check for value in an allowed set of values. */ + def inArray[T](allowed: Array[T]): T => Boolean = { (value: T) => + allowed.contains(value) + } + + /** Check for value in an allowed set of values. */ + def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => + allowed.contains(value) + } + + /** Check that the array length is greater than lowerBound. */ + def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => + value.length > lowerBound + } +} + // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... -/** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) - extends Param[Double](parent, name, doc, defaultValue) { +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Double]]] for Java. + */ +@DeveloperApi +class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean) + extends Param[Double](parent, name, doc, isValid) { + + def this(parent: String, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + def this(parent: Identifiable, name: String, doc: String, isValid: Double => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) } -/** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) - extends Param[Int](parent, name, doc, defaultValue) { +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Int]]] for Java. + */ +@DeveloperApi +class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean) + extends Param[Int](parent, name, doc, isValid) { + + def this(parent: String, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) } -/** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) - extends Param[Float](parent, name, doc, defaultValue) { +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Float]]] for Java. + */ +@DeveloperApi +class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean) + extends Param[Float](parent, name, doc, isValid) { + + def this(parent: String, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Float => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + + /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) } -/** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) - extends Param[Long](parent, name, doc, defaultValue) { +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Long]]] for Java. + */ +@DeveloperApi +class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean) + extends Param[Long](parent, name, doc, isValid) { + + def this(parent: String, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + def this(parent: Identifiable, name: String, doc: String, isValid: Long => Boolean) = + this(parent.uid, name, doc, isValid) + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + + /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) } -/** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) - extends Param[Boolean](parent, name, doc, defaultValue) { +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Boolean]]] for Java. + */ +@DeveloperApi +class BooleanParam(parent: String, name: String, doc: String) // No need for isValid + extends Param[Boolean](parent, name, doc) { + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } /** - * A param amd its value. + * :: DeveloperApi :: + * Specialized version of [[Param[Array[String]]]] for Java. + */ +@DeveloperApi +class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) + extends Param[Array[String]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) +} + +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Double]]]] for Java. + */ +@DeveloperApi +class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean) + extends Param[Array[Double]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = + w(value.asScala.map(_.asInstanceOf[Double]).toArray) +} + +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Int]]]] for Java. */ -case class ParamPair[T](param: Param[T], value: T) +@DeveloperApi +class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean) + extends Param[Array[Int]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = + w(value.asScala.map(_.asInstanceOf[Int]).toArray) +} /** - * :: AlphaComponent :: + * :: Experimental :: + * A param and its value. + */ +@Experimental +case class ParamPair[T](param: Param[T], value: T) { + // This is *the* place Param.validate is called. Whenever a parameter is specified, we should + // always construct a ParamPair so that validate is called. + param.validate(value) +} + +/** + * :: DeveloperApi :: * Trait for components that take parameters. This also provides an internal param map to store * parameter values attached to the instance. */ -@AlphaComponent +@DeveloperApi trait Params extends Identifiable with Serializable { - /** Returns all params. */ - def params: Array[Param[_]] = { + /** + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. + * + * Note: Developer should not use this method in constructor because we cannot guarantee that + * this variable gets initialized before other params. + */ + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -124,106 +353,273 @@ trait Params extends Identifiable with Serializable { } /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. + * Validates parameter values stored internally. + * Raise an exception if any parameter value is invalid. + * + * This only needs to check for interactions between parameters. + * Parameter value checks which do not depend on other parameters are handled by + * [[Param.validate()]]. This method does not handle input/output column parameters; + * those are checked during schema validation. */ - def validate(paramMap: ParamMap): Unit = {} + def validateParams(): Unit = { + // Do nothing by default. Override to handle Param interactions. + } /** - * Validates parameter values stored internally. - * Raise an exception if any parameter value is invalid. + * Explains a param. + * @param param input param, must belong to this instance. + * @return a string that contains the input param name, doc, and optionally its default value and + * the user-supplied value */ - def validate(): Unit = validate(ParamMap.empty) + def explainParam(param: Param[_]): String = { + shouldOwn(param) + val valueStr = if (isDefined(param)) { + val defaultValueStr = getDefault(param).map("default: " + _) + val currentValueStr = get(param).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") + } else { + "(undefined)" + } + s"${param.name}: ${param.doc} $valueStr" + } /** - * Returns the documentation of all params. + * Explains all params of this instance. + * @see [[explainParam()]] */ - def explainParams(): String = params.mkString("\n") + def explainParams(): String = { + params.map(explainParam).mkString("\n") + } /** Checks whether a param is explicitly set. */ - def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) + final def isSet(param: Param[_]): Boolean = { + shouldOwn(param) paramMap.contains(param) } + /** Checks whether a param is explicitly set or has a default value. */ + final def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) + } + + /** Tests whether this instance contains a param with a given name. */ + def hasParam(paramName: String): Boolean = { + params.exists(_.name == paramName) + } + /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + def getParam(paramName: String): Param[Any] = { + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ - private[ml] def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) - paramMap.put(param.asInstanceOf[Param[Any]], value) - this + protected final def set[T](param: Param[T], value: T): this.type = { + set(param -> value) } /** * Sets a parameter (by name) in the embedded param map. */ - private[ml] def set(param: String, value: Any): this.type = { + protected final def set(param: String, value: Any): this.type = { set(getParam(param), value) } /** - * Gets the value of a parameter in the embedded param map. + * Sets a parameter in the embedded param map. */ - private[ml] def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - paramMap(param) + protected final def set(paramPair: ParamPair[_]): this.type = { + shouldOwn(paramPair.param) + paramMap.put(paramPair) + this } /** - * Internal param map. + * Optionally returns the user-supplied value of a param. */ - protected val paramMap: ParamMap = ParamMap.empty -} + final def get[T](param: Param[T]): Option[T] = { + shouldOwn(param) + paramMap.get(param) + } + + /** + * Clears the user-supplied value for the input param. + */ + protected final def clear(param: Param[_]): this.type = { + shouldOwn(param) + paramMap.remove(param) + this + } + + /** + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. + */ + final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get + } + + /** An alias for [[getOrDefault()]]. */ + protected final def $[T](param: Param[T]): T = getOrDefault(param) + + /** + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. + * @param value the default value + */ + protected final def setDefault[T](param: Param[T], value: T): this.type = { + defaultParamMap.put(param -> value) + this + } + + /** + * Sets default values for a list of params. + * + * Note: Java developers should use the single-parameter [[setDefault()]]. + * Annotating this with varargs can cause compilation failures due to a Scala compiler bug. + * See SPARK-9268. + * + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. + */ + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Gets the default value of a parameter. + */ + final def getDefault[T](param: Param[T]): Option[T] = { + shouldOwn(param) + defaultParamMap.get(param) + } + + /** + * Tests whether the input param has a default value set. + */ + final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) + } + + /** + * Creates a copy of this instance with the same UID and some extra params. + * Subclasses should implement this method and set the return type properly. + * + * @see [[defaultCopy()]] + */ + def copy(extra: ParamMap): Params + + /** + * Default implementation of copy with extra params. + * It tries to create a new instance with the same UID. + * Then it copies the embedded and extra parameters over and returns the new instance. + */ + protected final def defaultCopy[T <: Params](extra: ParamMap): T = { + val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) + copyValues(that, extra).asInstanceOf[T] + } + + /** + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts, i.e., with ordering: default param values < user-supplied values < extra. + */ + final def extractParamMap(extra: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extra + } + + /** + * [[extractParamMap]] with no extra values. + */ + final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty -private[ml] object Params { + /** Validates that the input param belongs to this instance. */ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent == uid && hasParam(param.name), s"Param $param does not belong to $this.") + } /** - * Copies parameter values from the parent estimator to the child model it produced. - * @param paramMap the param map that holds parameters of the parent - * @param parent the parent estimator - * @param child the child model + * Copies param values from this instance to another instance for params shared by them. + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. + * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] + * @return the target instance with param values copied */ - def inheritValues[E <: Params, M <: E]( - paramMap: ParamMap, - parent: E, - child: M): Unit = { - parent.params.foreach { param => - if (paramMap.contains(param)) { - child.set(child.getParam(param.name), paramMap(param)) + protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { + val map = paramMap ++ extra + params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params + if (map.contains(param) && to.hasParam(param.name)) { + to.set(param.name, map(param)) } } + to } } /** - * :: AlphaComponent :: + * :: DeveloperApi :: + * Java-friendly wrapper for [[Params]]. + * Java developers who need to extend [[Params]] should use this class instead. + * If you need to extend a abstract class which already extends [[Params]], then that abstract + * class should be Java-friendly as well. + */ +@DeveloperApi +abstract class JavaParams extends Params + +/** + * :: Experimental :: * A param to value map. */ -@AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +@Experimental +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { + + /* DEVELOPERS: About validating parameter values + * This and ParamPair are the only two collections of parameters. + * This class should always create ParamPairs when + * specifying new parameter values. ParamPair will then call Param.validate(). + */ /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). */ - def put[T](param: Param[T], value: T): this.type = { - map(param.asInstanceOf[Param[Any]]) = value - this - } + def put[T](param: Param[T], value: T): this.type = put(param -> value) /** * Puts a list of param pairs (overwrites if the input params exists). @@ -231,18 +627,23 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten @varargs def put(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => - put(p.param.asInstanceOf[Param[Any]], p.value) + map(p.param.asInstanceOf[Param[Any]]) = p.value } this } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) } /** @@ -250,10 +651,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Raises a NoSuchElementException if there is no value associated with the input param. */ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -265,6 +663,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns its value associated previously as an option. + */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ @@ -274,19 +679,19 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) override def toString: String = { - map.map { case (param, value) => - s"\t${param.parent.uid}-${param.name}: $value" + map.toSeq.sortBy(_._1.name).map { case (param, value) => + s"\t${param.parent}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. @@ -310,8 +715,14 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten ParamPair(param, value) } } + + /** + * Number of param pairs in this map. + */ + def size: Int = map.size } +@Experimental object ParamMap { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala new file mode 100644 index 000000000000..8c16c6149b40 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" + * }}} + */ +private[shared] object SharedParamsCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter (>= 0)", + isValid = "ParamValidators.gtEq(0)"), + ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)", + isValid = "ParamValidators.gtEq(0)"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + + " probabilities. Note: Not all models output well-calibrated probability estimates!" + + " These probabilities should be treated as confidences, not precise probabilities.", + Some("\"probability\"")), + ParamDesc[Double]("threshold", + "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), + isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), + ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + + " to adjust the probability of predicting each class." + + " Array must have length equal to the number of classes, with values >= 0." + + " The class with largest value p/t is predicted, where p is the original probability" + + " of that class and t is the class' threshold.", + isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("inputCols", "input column names"), + ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), + ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", + isValid = "ParamValidators.gtEq(1)"), + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), + ParamDesc[Boolean]("standardization", "whether to standardize the training features" + + " before fitting the model.", Some("true")), + ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + isValid = "ParamValidators.inRange(0, 1)"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0.")) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. */ + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None, + isValid: String = "", + finalMethods: Boolean = true) { + + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam" + case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + |""".stripMargin + }.getOrElse("") + val isValid = if (param.isValid != "") { + ", " + param.isValid + } else { + "" + } + val methodStr = if (param.finalMethods) { + "final def" + } else { + "def" + } + + s""" + |/** + | * Trait for shared param $name$defaultValueDoc. + | */ + |private[ml] trait Has$Name extends Params { + | + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) + |$setDefault + | /** @group getParam */ + | $methodStr get$Name: $T = $$($name) + |} + |""".stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + | + |// scalastyle:off + |""".stripMargin + + val footer = "// scalastyle:on\n" + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 000000000000..c26768953e3d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + +// scalastyle:off + +/** + * Trait for shared param regParam. + */ +private[ml] trait HasRegParam extends Params { + + /** + * Param for regularization parameter (>= 0). + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0)) + + /** @group getParam */ + final def getRegParam: Double = $(regParam) +} + +/** + * Trait for shared param maxIter. + */ +private[ml] trait HasMaxIter extends Params { + + /** + * Param for maximum number of iterations (>= 0). + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0)) + + /** @group getParam */ + final def getMaxIter: Int = $(maxIter) +} + +/** + * Trait for shared param featuresCol (default: "features"). + */ +private[ml] trait HasFeaturesCol extends Params { + + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + setDefault(featuresCol, "features") + + /** @group getParam */ + final def getFeaturesCol: String = $(featuresCol) +} + +/** + * Trait for shared param labelCol (default: "label"). + */ +private[ml] trait HasLabelCol extends Params { + + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + setDefault(labelCol, "label") + + /** @group getParam */ + final def getLabelCol: String = $(labelCol) +} + +/** + * Trait for shared param predictionCol (default: "prediction"). + */ +private[ml] trait HasPredictionCol extends Params { + + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + setDefault(predictionCol, "prediction") + + /** @group getParam */ + final def getPredictionCol: String = $(predictionCol) +} + +/** + * Trait for shared param rawPredictionCol (default: "rawPrediction"). + */ +private[ml] trait HasRawPredictionCol extends Params { + + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + setDefault(rawPredictionCol, "rawPrediction") + + /** @group getParam */ + final def getRawPredictionCol: String = $(rawPredictionCol) +} + +/** + * Trait for shared param probabilityCol (default: "probability"). + */ +private[ml] trait HasProbabilityCol extends Params { + + /** + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + + setDefault(probabilityCol, "probability") + + /** @group getParam */ + final def getProbabilityCol: String = $(probabilityCol) +} + +/** + * Trait for shared param threshold (default: 0.5). + */ +private[ml] trait HasThreshold extends Params { + + /** + * Param for threshold in binary classification prediction, in range [0, 1]. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) + + setDefault(threshold, 0.5) + + /** @group getParam */ + def getThreshold: Double = $(threshold) +} + +/** + * Trait for shared param thresholds. + */ +private[ml] trait HasThresholds extends Params { + + /** + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + * @group param + */ + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) + + /** @group getParam */ + def getThresholds: Array[Double] = $(thresholds) +} + +/** + * Trait for shared param inputCol. + */ +private[ml] trait HasInputCol extends Params { + + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = $(inputCol) +} + +/** + * Trait for shared param inputCols. + */ +private[ml] trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = $(inputCols) +} + +/** + * Trait for shared param outputCol (default: uid + "__output"). + */ +private[ml] trait HasOutputCol extends Params { + + /** + * Param for output column name. + * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + setDefault(outputCol, uid + "__output") + + /** @group getParam */ + final def getOutputCol: String = $(outputCol) +} + +/** + * Trait for shared param checkpointInterval. + */ +private[ml] trait HasCheckpointInterval extends Params { + + /** + * Param for checkpoint interval (>= 1). + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) + + /** @group getParam */ + final def getCheckpointInterval: Int = $(checkpointInterval) +} + +/** + * Trait for shared param fitIntercept (default: true). + */ +private[ml] trait HasFitIntercept extends Params { + + /** + * Param for whether to fit an intercept term. + * @group param + */ + final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term") + + setDefault(fitIntercept, true) + + /** @group getParam */ + final def getFitIntercept: Boolean = $(fitIntercept) +} + +/** + * Trait for shared param handleInvalid. + */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + +/** + * Trait for shared param standardization (default: true). + */ +private[ml] trait HasStandardization extends Params { + + /** + * Param for whether to standardize the training features before fitting the model.. + * @group param + */ + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.") + + setDefault(standardization, true) + + /** @group getParam */ + final def getStandardization: Boolean = $(standardization) +} + +/** + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). + */ +private[ml] trait HasSeed extends Params { + + /** + * Param for random seed. + * @group param + */ + final val seed: LongParam = new LongParam(this, "seed", "random seed") + + setDefault(seed, this.getClass.getName.hashCode.toLong) + + /** @group getParam */ + final def getSeed: Long = $(seed) +} + +/** + * Trait for shared param elasticNetParam. + */ +private[ml] trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + + /** @group getParam */ + final def getElasticNetParam: Double = $(elasticNetParam) +} + +/** + * Trait for shared param tol. + */ +private[ml] trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. + * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = $(tol) +} + +/** + * Trait for shared param stepSize. + */ +private[ml] trait HasStepSize extends Params { + + /** + * Param for Step size to be used for each iteration of optimization.. + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.") + + /** @group getParam */ + final def getStepSize: Double = $(stepSize) +} + +/** + * Trait for shared param weightCol. + */ +private[ml] trait HasWeightCol extends Params { + + /** + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * @group param + */ + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) +} +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index ef141d3eb2b0..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param - -private[ml] trait HasRegParam extends Params { - /** param for regularization parameter */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** param for max number of iterations */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** param for features column name */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** param for label column name */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasScoreCol extends Params { - /** param for score column name */ - val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) - def getScoreCol: String = get(scoreCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** param for prediction column name */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasThreshold extends Params { - /** param for threshold in (binary) prediction */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** param for input column name */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasOutputCol extends Params { - /** param for output column name */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - def getOutputCol: String = get(outputCol) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala new file mode 100644 index 000000000000..f5a022c31ed9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.api.r + +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.sql.DataFrame + +private[r] object SparkRWrappers { + def fitRModelFormula( + value: String, + df: DataFrame, + family: String, + lambda: Double, + alpha: Double): PipelineModel = { + val formula = new RFormula().setFormula(value) + val estimator = family match { + case "gaussian" => new LinearRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) + case "binomial" => new LogisticRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) + } + val pipeline = new Pipeline().setStages(Array(formula, estimator)) + pipeline.fit(df) + } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 511cb2fe4005..7db8ad8d2791 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} +import java.io.IOException import scala.collection.mutable import scala.reflect.ClassTag @@ -26,131 +27,205 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.github.fommil.netlib.LAPACK.{getInstance => lapack} -import org.jblas.DoubleMatrix +import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW -import org.apache.spark.{HashPartitioner, Logging, Partitioner} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.{Logging, Partitioner} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom +/** + * Common params for ALS and ALSModel. + */ +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { + /** + * Param for the column name for user ids. + * Default: "user" + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = $(userCol) + + /** + * Param for the column name for item ids. + * Default: "item" + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = $(itemCol) +} + /** * Common params for ALS. */ -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam - with HasPredictionCol { +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam + with HasPredictionCol with HasCheckpointInterval with HasSeed { - /** Param for rank of the matrix factorization. */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) - def getRank: Int = get(rank) + /** + * Param for rank of the matrix factorization (>= 1). + * Default: 10 + * @group param + */ + val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1)) + + /** @group getParam */ + def getRank: Int = $(rank) + + /** + * Param for number of user blocks (>= 1). + * Default: 10 + * @group param + */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", + ParamValidators.gtEq(1)) - /** Param for number of user blocks. */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) - def getNumUserBlocks: Int = get(numUserBlocks) + /** @group getParam */ + def getNumUserBlocks: Int = $(numUserBlocks) - /** Param for number of item blocks. */ - val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) - def getNumItemBlocks: Int = get(numItemBlocks) + /** + * Param for number of item blocks (>= 1). + * Default: 10 + * @group param + */ + val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks", + ParamValidators.gtEq(1)) + + /** @group getParam */ + def getNumItemBlocks: Int = $(numItemBlocks) + + /** + * Param to decide whether to use implicit preference. + * Default: false + * @group param + */ + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") - /** Param to decide whether to use implicit preference. */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) - def getImplicitPrefs: Boolean = get(implicitPrefs) + /** @group getParam */ + def getImplicitPrefs: Boolean = $(implicitPrefs) - /** Param for the alpha parameter in the implicit preference formulation. */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) - def getAlpha: Double = get(alpha) + /** + * Param for the alpha parameter in the implicit preference formulation (>= 0). + * Default: 1.0 + * @group param + */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", + ParamValidators.gtEq(0)) - /** Param for the column name for user ids. */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) - def getUserCol: String = get(userCol) + /** @group getParam */ + def getAlpha: Double = $(alpha) - /** Param for the column name for item ids. */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) - def getItemCol: String = get(itemCol) + /** + * Param for the column name for ratings. + * Default: "rating" + * @group param + */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") - /** Param for the column name for ratings. */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) - def getRatingCol: String = get(ratingCol) + /** @group getParam */ + def getRatingCol: String = $(ratingCol) + /** + * Param for whether to apply nonnegativity constraints. + * Default: false + * @group param + */ val nonnegative = new BooleanParam( - this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) - val getNonnegative: Boolean = get(nonnegative) + this, "nonnegative", "whether to use nonnegative constraint for least squares") + + /** @group getParam */ + def getNonnegative: Boolean = $(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10) /** * Validates and transforms the input schema. * @param schema input schema - * @param paramMap extra params * @return output schema */ - protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - assert(schema(map(userCol)).dataType == IntegerType) - assert(schema(map(itemCol)).dataType== IntegerType) - val ratingType = schema(map(ratingCol)).dataType - assert(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = map(predictionCol) - assert(!schema.fieldNames.contains(predictionColName), - s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false) - StructType(newFields) + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + val ratingType = schema($(ratingCol)).dataType + require(ratingType == FloatType || ratingType == DoubleType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** + * :: Experimental :: * Model fitted by ALS. + * + * @param rank rank of the matrix factorization model + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ +@Experimental class ALSModel private[ml] ( - override val parent: ALS, - override val fittingParamMap: ParamMap, - k: Int, - userFactors: RDD[(Int, Array[Float])], - itemFactors: RDD[(Int, Array[Float])]) - extends Model[ALSModel] with ALSParams { + override val uid: String, + val rank: Int, + @transient val userFactors: DataFrame, + @transient val itemFactors: DataFrame) + extends Model[ALSModel] with ALSModelParams { + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - import dataset.sqlContext.createDataFrame - val map = this.paramMap ++ paramMap - val users = userFactors.toDataFrame("id", "features") - val items = itemFactors.toDataFrame("id", "features") - val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + override def transform(dataset: DataFrame): DataFrame = { + // Register a UDF for DataFrame, and then + // create a new column named map(predictionCol) by running the predict UDF. + val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { - blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } - val inputColumns = dataset.schema.fieldNames - val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol)) - val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction dataset - .join(users, dataset(map(userCol)) === users("id"), "left") - .join(items, dataset(map(itemCol)) === items("id"), "left") - .select(outputColumns: _*) - // TODO: Just use a dataset("*") - // .select(dataset("*"), prediction) + .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") + .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .select(dataset("*"), + predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } - override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) + } + + override def copy(extra: ParamMap): ALSModel = { + val copied = new ALSModel(uid, rank, userFactors, itemFactors) + copyValues(copied, extra).setParent(parent) } } /** + * :: Experimental :: * Alternating Least Squares (ALS) matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, @@ -179,52 +254,89 @@ class ALSModel private[ml] ( * indicated user * preferences rather than explicit ratings given to items. */ -class ALS extends Estimator[ALSModel] with ALSParams { +@Experimental +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating + def this() = this(Identifiable.randomUID("als")) + + /** @group setParam */ def setRank(value: Int): this.type = set(rank, value) + + /** @group setParam */ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + + /** @group setParam */ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + + /** @group setParam */ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + + /** @group setParam */ def setAlpha(value: Double): this.type = set(alpha, value) + + /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ def setItemCol(value: String): this.type = set(itemCol, value) + + /** @group setParam */ def setRatingCol(value: String): this.type = set(ratingCol, value) + + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setNonnegative(value: Boolean): this.type = set(nonnegative, value) - /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + /** @group setParam */ + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Sets both numUserBlocks and numItemBlocks to the specific value. + * @group setParam + */ def setNumBlocks(value: Int): this.type = { setNumUserBlocks(value) setNumItemBlocks(value) this } - setMaxIter(20) - setRegParam(1.0) - - override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = this.paramMap ++ paramMap + override def fit(dataset: DataFrame): ALSModel = { + import dataset.sqlContext.implicits._ val ratings = dataset - .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) + .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), + col($(ratingCol)).cast(FloatType)) .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } - val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), - numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), - maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), - alpha = map(alpha), nonnegative = map(nonnegative)) - val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) - Params.inheritValues(map, this, model) - model + val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank), + numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks), + maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), + alpha = $(alpha), nonnegative = $(nonnegative), + checkpointInterval = $(checkpointInterval), seed = $(seed)) + val userDF = userFactors.toDF("id", "features") + val itemDF = itemFactors.toDF("id", "features") + val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) + copyValues(model) } - override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap) + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): ALS = defaultCopy(extra) } /** @@ -238,12 +350,16 @@ class ALS extends Estimator[ALSModel] with ALSParams { @DeveloperApi object ALS extends Logging { - /** Rating class for better code readability. */ + /** + * :: DeveloperApi :: + * Rating class for better code readability. + */ + @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { - /** Solves a least squares problem (possibly with other constraints). */ + /** Solves a least squares problem with regularization (possibly with other constraints). */ def solve(ne: NormalEquation, lambda: Double): Array[Float] } @@ -255,20 +371,19 @@ object ALS extends Logging { /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } @@ -291,14 +406,14 @@ object ALS extends Logging { private[recommendation] class NNLSSolver extends LeastSquaresNESolver { private var rank: Int = -1 private var workspace: NNLS.Workspace = _ - private var ata: DoubleMatrix = _ + private var ata: Array[Double] = _ private var initialized: Boolean = false private def initialize(rank: Int): Unit = { if (!initialized) { this.rank = rank workspace = NNLS.createWorkspace(rank) - ata = new DoubleMatrix(rank, rank) + ata = new Array[Double](rank * rank) initialized = true } else { require(this.rank == rank) @@ -314,8 +429,8 @@ object ALS extends Logging { override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val rank = ne.k initialize(rank) - fillAtA(ne.ata, lambda * ne.n) - val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace) + fillAtA(ne.ata, lambda) + val x = NNLS.solve(ata, ne.atb, workspace) ne.reset() x.map(x => x.toFloat) } @@ -328,23 +443,30 @@ object ALS extends Logging { var i = 0 var pos = 0 var a = 0.0 - val data = ata.data while (i < rank) { var j = 0 while (j <= i) { a = triAtA(pos) - data(i * rank + j) = a - data(j * rank + i) = a + ata(i * rank + j) = a + ata(j * rank + i) = a pos += 1 j += 1 } - data(i * rank + i) += lambda + ata(i * rank + i) += lambda i += 1 } } } - /** Representing a normal equation (ALS' subproblem). */ + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -353,8 +475,6 @@ object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -368,28 +488,13 @@ object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) require(a.length == k) copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. - */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { - require(a.length == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) - copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -399,7 +504,6 @@ object ALS extends Logging { require(other.k == k) blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) - n += other.n this } @@ -407,13 +511,14 @@ object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } /** + * :: DeveloperApi :: * Implementation of the ALS algorithm. */ + @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore ratings: RDD[Rating[ID]], rank: Int = 10, @@ -426,13 +531,14 @@ object ALS extends Logging { nonnegative: Boolean = false, intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { require(intermediateRDDStorageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") val sc = ratings.sparkContext - val userPart = new HashPartitioner(numUserBlocks) - val itemPart = new HashPartitioner(numItemBlocks) + val userPart = new ALSPartitioner(numUserBlocks) + val itemPart = new ALSPartitioner(numItemBlocks) val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) val solver = if (nonnegative) new NNLSSolver else new CholeskySolver @@ -453,6 +559,18 @@ object ALS extends Logging { val seedGen = new XORShiftRandom(seed) var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + var previousCheckpointFile: Option[String] = None + val shouldCheckpoint: Int => Boolean = (iter) => + sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + val deletePreviousCheckpointFile: () => Unit = () => + previousCheckpointFile.foreach { file => + try { + FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true) + } catch { + case e: IOException => + logWarning(s"Cannot delete checkpoint file $file:", e) + } + } if (implicitPrefs) { for (iter <- 1 to maxIter) { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) @@ -460,19 +578,30 @@ object ALS extends Logging { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, implicitPrefs, alpha, solver) previousItemFactors.unpersist() - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - itemFactors.checkpoint() - } itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) + // TODO: Generalize PeriodicGraphCheckpointer and use it here. + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() // itemFactors gets materialized in computeFactors. + } val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, implicitPrefs, alpha, solver) + if (shouldCheckpoint(iter)) { + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } previousUserFactors.unpersist() } } else { for (iter <- 0 until maxIter) { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, solver = solver) + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() + itemFactors.count() // checkpoint item factors and cut lineage + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, solver = solver) } @@ -480,13 +609,23 @@ object ALS extends Logging { val userIdAndFactors = userInBlocks .mapValues(_.srcIds) .join(userFactors) - .values + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks + // and userFactors. + }, preservesPartitioning = true) .setName("userFactors") .persist(finalRDDStorageLevel) val itemIdAndFactors = itemInBlocks .mapValues(_.srcIds) .join(itemFactors) - .values + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + }, preservesPartitioning = true) .setName("itemFactors") .persist(finalRDDStorageLevel) if (finalRDDStorageLevel != StorageLevel.NONE) { @@ -499,13 +638,7 @@ object ALS extends Logging { itemOutBlocks.unpersist() blockRatings.unpersist() } - val userOutput = userIdAndFactors.flatMap { case (ids, factors) => - ids.view.zip(factors) - } - val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => - ids.view.zip(factors) - } - (userOutput, itemOutput) + (userIdAndFactors, itemIdAndFactors) } /** @@ -925,15 +1058,15 @@ object ALS extends Logging { "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) - }.groupByKey(new HashPartitioner(srcPart.numPartitions)) - .mapValues { iter => - val builder = - new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) - iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => - builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) - } - builder.build().compress() - }.setName(prefix + "InBlocks") + }.groupByKey(new ALSPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks") .persist(storageLevel) val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => val encoder = new LocalIndexEncoder(dstPart.numPartitions) @@ -994,7 +1127,7 @@ object ALS extends Logging { (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) } } - val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length)) + val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length)) dstInBlocks.join(merged).mapValues { case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) @@ -1010,6 +1143,7 @@ object ALS extends Logging { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1017,13 +1151,23 @@ object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. + dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -1037,7 +1181,7 @@ object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) @@ -1079,4 +1223,11 @@ object ALS extends Logging { encoded & localIndexMask } } + + /** + * Partitioner used by ALS. We requires that getPartition is a projection. That is, for any key k, + * we have getPartition(getPartition(k)) = getPartition(k). Since the the default HashPartitioner + * satisfies this requirement, we simply use a type alias here. + */ + private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala new file mode 100644 index 000000000000..a2bcd67401d0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.tree.impl.RandomForest +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for regression. + * It supports both continuous and categorical features. + */ +@Experimental +final class DecisionTreeRegressor(override val uid: String) + extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] + with DecisionTreeParams with TreeRegressorParams { + + def this() = this(Identifiable.randomUID("dtr")) + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = getOldStrategy(categoricalFeatures) + val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeRegressionModel] + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) + } + + override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) +} + +@Experimental +object DecisionTreeRegressor { + /** Accessor for supported impurities: variance */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities +} + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. + * It supports both continuous and categorical features. + * @param rootNode Root of the decision tree + */ +@Experimental +final class DecisionTreeRegressionModel private[ml] ( + override val uid: String, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeRegressionModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + /** + * Construct a decision tree regression model. + * @param rootNode Root node of tree, with other nodes attached. + */ + private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) + + override protected def predict(features: Vector): Double = { + rootNode.predictImpl(features).prediction + } + + override def copy(extra: ParamMap): DecisionTreeRegressionModel = { + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) + } + + override def toString: String = { + s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + } + + /** Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) + } +} + +private[ml] object DecisionTreeRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, + s"Cannot convert non-regression DecisionTreeModel (old API) to" + + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") + new DecisionTreeRegressionModel(uid, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala new file mode 100644 index 000000000000..b66e61f37dd5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@Experimental +final class GBTRegressor(override val uid: String) + extends Predictor[Vector, GBTRegressor, GBTRegressionModel] + with GBTParams with TreeRegressorParams with Logging { + + def this() = this(Identifiable.randomUID("gbtr")) + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." + */ + override def setImpurity(value: String): this.type = { + logWarning("GBTRegressor.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTRegressor: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group setParam */ + def setLossType(value: String): this.type = set(lossType, value) + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } + + override protected def train(dataset: DataFrame): GBTRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) + } + + override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) +} + +@Experimental +object GBTRegressor { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +/** + * :: Experimental :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ +@Experimental +final class GBTRegressionModel( + override val uid: String, + private val _trees: Array[DecisionTreeRegressionModel], + private val _treeWeights: Array[Double]) + extends PredictionModel[Vector, GBTRegressionModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.") + require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + override def treeWeights: Array[Double] = _treeWeights + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override protected def predict(features: Vector): Double = { + // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 + // Classifies by thresholding sum of weighted tree predictions + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + } + + override def copy(extra: ParamMap): GBTRegressionModel = { + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) + } + + override def toString: String = { + s"GBTRegressionModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldGBTModel = { + new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) + } +} + +private[ml] object GBTRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldGBTModel, + parent: GBTRegressor, + categoricalFeatures: Map[Int, Int]): GBTRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) + } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") + new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala new file mode 100644 index 000000000000..0f33bae30e62 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Params for isotonic regression. + */ +private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol + with HasLabelCol with HasPredictionCol with HasWeightCol with Logging { + + /** + * Param for whether the output sequence should be isotonic/increasing (true) or + * antitonic/decreasing (false). + * @group param + */ + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", + "whether the output sequence should be isotonic/increasing (true) or" + + "antitonic/decreasing (false)") + + /** @group getParam */ + final def getIsotonic: Boolean = $(isotonic) + + /** + * Param for the index of the feature if [[featuresCol]] is a vector column (default: `0`), no + * effect otherwise. + * @group param + */ + final val featureIndex: IntParam = new IntParam(this, "featureIndex", + "The index of the feature if featuresCol is a vector column, no effect otherwise.") + + /** @group getParam */ + final def getFeatureIndex: Int = $(featureIndex) + + setDefault(isotonic -> true, featureIndex -> 0) + + /** Checks whether the input has weight column. */ + protected[ml] def hasWeightCol: Boolean = { + isDefined(weightCol) && $(weightCol) != "" + } + + /** + * Extracts (label, feature, weight) from input dataset. + */ + protected[ml] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { + val idx = $(featureIndex) + val extract = udf { v: Vector => v(idx) } + extract(col($(featuresCol))) + } else { + col($(featuresCol)) + } + val w = if (hasWeightCol) { + col($(weightCol)) + } else { + lit(1.0) + } + dataset.select(col($(labelCol)), f, w) + .map { case Row(label: Double, feature: Double, weights: Double) => + (label, feature, weights) + } + } + + /** + * Validates and transforms input schema. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected[ml] def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + if (fitting) { + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + if (hasWeightCol) { + SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) + } else { + logInfo("The weight column is not defined. Treat all instance weights as 1.0.") + } + } + val featuresType = schema($(featuresCol)).dataType + require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT]) + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Isotonic regression. + * + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegression(override val uid: String) extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase { + + def this() = this(Identifiable.randomUID("isoReg")) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setIsotonic(value: Boolean): this.type = set(isotonic, value) + + /** @group setParam */ + def setWeightCol(value: String): this.type = set(weightCol, value) + + /** @group setParam */ + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + override def fit(dataset: DataFrame): IsotonicRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractWeightedLabeledPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) + val oldModel = isotonicRegression.run(instances) + + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } +} + +/** + * :: Experimental :: + * Model fitted by IsotonicRegression. + * Predicts using a piecewise linear function. + * + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. + * + * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegressionModel private[ml] ( + override val uid: String, + private val oldModel: MLlibIsotonicRegressionModel) + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + + /** Boundaries in increasing order for which predictions are known. */ + def boundaries: Vector = Vectors.dense(oldModel.boundaries) + + /** + * Predictions associated with the boundaries at the same index, monotone because of isotonic + * regression. + */ + def predictions: Vector = Vectors.dense(oldModel.predictions) + + override def copy(extra: ParamMap): IsotonicRegressionModel = { + copyValues(new IsotonicRegressionModel(uid, oldModel), extra) + } + + override def transform(dataset: DataFrame): DataFrame = { + val predict = dataset.schema($(featuresCol)).dataType match { + case DoubleType => + udf { feature: Double => oldModel.predict(feature) } + case _: VectorUDT => + val idx = $(featureIndex) + udf { features: Vector => oldModel.predict(features(idx)) } + } + dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala new file mode 100644 index 000000000000..884003eb3852 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -0,0 +1,647 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.StatCounter + +/** + * Params for linear regression. + */ +private[regression] trait LinearRegressionParams extends PredictorParams + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasFitIntercept with HasStandardization + +/** + * :: Experimental :: + * Linear regression. + * + * The learning objective is to minimize the squared error, with regularization. + * The specific squared error loss function used is: + * L = 1/2n ||A weights - y||^2^ + * + * This support multiple types of regularization: + * - none (a.k.a. ordinary least squares) + * - L2 (ridge regression) + * - L1 (Lasso) + * - L2 + L1 (elastic net) + */ +@Experimental +class LinearRegression(override val uid: String) + extends Regressor[Vector, LinearRegression, LinearRegressionModel] + with LinearRegressionParams with Logging { + + def this() = this(Identifiable.randomUID("linReg")) + + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ + def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Whether to standardize the training features before fitting the model. + * The coefficients of models will be always returned on the original scale, + * so it will be transparent for users. Note that with/without standardization, + * the models should be always converged to the same solution when no regularization + * is applied. In R's GLMNET package, the default behavior is true as well. + * Default is true. + * @group setParam + */ + def setStandardization(value: Boolean): this.type = set(standardization, value) + setDefault(standardization -> true) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + override protected def train(dataset: DataFrame): LinearRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist instances. + val instances = extractLabeledPoints(dataset).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( + seqOp = (c, v) => (c, v) match { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, + combOp = (c1, c2) => (c1, c2) match { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + // If the yStd is zero, then the intercept is yMean with zero weights; + // as a result, training is not needed. + if (yStd == 0.0) { + logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + + s"and the intercept will be the mean of the label; as a result, training is not needed.") + if (handlePersistence) instances.unpersist() + val weights = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, weights, intercept) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset), + $(predictionCol), + $(labelCol), + $(featuresCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) + } + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = $(regParam) / yStd + val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), + $(standardization), featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + } else { + def effectiveL1RegFun = (index: Int) => { + if ($(standardization)) { + effectiveL1RegParam + } else { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0 + } + } + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol)) + } + + val initialWeights = Vectors.zeros(numFeatures) + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeights.toBreeze.toDenseVector) + + val (weights, objectiveHistory) = { + /* + Note that in Linear Regression, the objective history (loss + regularization) returned + from optimizer is computed in the scaled space given by the following formula. + {{{ + L = 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms + }}} + */ + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + logError(msg) + throw new SparkException(msg) + } + + /* + The weights are trained in the scaled space; we're converting them back to + the original space. + */ + val rawWeights = state.x.toArray.clone() + var i = 0 + val len = rawWeights.length + while (i < len) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + + (Vectors.dense(rawWeights).compressed, arrayBuilder.result()) + } + + /* + The intercept in R's GLMNET is computed using closed form after the coefficients are + converged. See the following discussion for detail. + http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + */ + val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0 + + if (handlePersistence) instances.unpersist() + + val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset), + $(predictionCol), + $(labelCol), + $(featuresCol), + objectiveHistory) + model.setSummary(trainingSummary) + } + + override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model produced by [[LinearRegression]]. + */ +@Experimental +class LinearRegressionModel private[ml] ( + override val uid: String, + val weights: Vector, + val intercept: Double) + extends RegressionModel[Vector, LinearRegressionModel] + with LinearRegressionParams { + + private var trainingSummary: Option[LinearRegressionTrainingSummary] = None + + /** + * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LinearRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LinearRegressionModel", + new NullPointerException()) + } + + private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + val t = udf { features: Vector => predict(features) } + val predictionAndObservations = dataset + .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + + new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + } + + override protected def predict(features: Vector): Double = { + dot(features, weights) + intercept + } + + override def copy(extra: ParamMap): LinearRegressionModel = { + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel.setParent(parent) + } +} + +/** + * :: Experimental :: + * Linear regression training results. + * @param predictions predictions outputted by the model's `transform` method. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +@Experimental +class LinearRegressionTrainingSummary private[regression] ( + predictions: DataFrame, + predictionCol: String, + labelCol: String, + val featuresCol: String, + val objectiveHistory: Array[Double]) + extends LinearRegressionSummary(predictions, predictionCol, labelCol) { + + /** Number of training iterations until termination */ + val totalIterations = objectiveHistory.length + +} + +/** + * :: Experimental :: + * Linear regression results evaluated on a dataset. + * @param predictions predictions outputted by the model's `transform` method. + */ +@Experimental +class LinearRegressionSummary private[regression] ( + @transient val predictions: DataFrame, + val predictionCol: String, + val labelCol: String) extends Serializable { + + @transient private val metrics = new RegressionMetrics( + predictions + .select(predictionCol, labelCol) + .map { case Row(pred: Double, label: Double) => (pred, label) } ) + + /** + * Returns the explained variance regression score. + * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + val explainedVariance: Double = metrics.explainedVariance + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + val meanAbsoluteError: Double = metrics.meanAbsoluteError + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + val meanSquaredError: Double = metrics.meanSquaredError + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + val rootMeanSquaredError: Double = metrics.rootMeanSquaredError + + /** + * Returns R^2^, the coefficient of determination. + * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + val r2: Double = metrics.r2 + + /** Residuals (label - predicted value) */ + @transient lazy val residuals: DataFrame = { + val t = udf { (pred: Double, label: Double) => label - pred } + predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) + } + +} + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * For improving the convergence rate during the optimization process, and also preventing against + * features with very large variances exerting an overly large influence during model training, + * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce + * the condition number, and then trains the model in scaled space but returns the weights in + * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache + * the standardized dataset since it will create a lot of overhead. As a result, we perform the + * scaling implicitly when we compute the objective function. The following is the mathematical + * derivation. + * + * Note that we don't deal with intercept by adding bias here, because the intercept + * can be computed using closed form after the coefficients are converged. + * See this discussion for detail. + * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + * + * When training with intercept enabled, + * The objective function in the scaled space is given by + * {{{ + * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, + * }}} + * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i, + * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label. + * + * If we fitting the intercept disabled (that is forced through 0.0), + * we can use the same equation except we set \bar{y} and \bar{x_i} to 0 instead + * of the respective means. + * + * This can be rewritten as + * {{{ + * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} + * + \bar{y} / \hat{y}||^2 + * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 + * }}} + * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is + * {{{ + * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. + * }}}, and diff is + * {{{ + * \sum_i w_i^\prime x_i - y / \hat{y} + offset + * }}} + * + * + * Note that the effective weights and offset don't depend on training dataset, + * so they can be precomputed. + * + * Now, the first derivative of the objective function in scaled space is + * {{{ + * \frac{\partial L}{\partial\w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} + * }}} + * However, ($x_i - \bar{x_i}$) will densify the computation, so it's not + * an ideal formula when the training dataset is sparse format. + * + * This can be addressed by adding the dense \bar{x_i} / \har{x_i} terms + * in the end by keeping the sum of diff. The first derivative of total + * objective function from all the samples is + * {{{ + * \frac{\partial L}{\partial\w_i} = + * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} + * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i}) / \hat{x_i}) + * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) + * }}}, + * where correction_i = - diffSum \bar{x_i}) / \hat{x_i} + * + * A simple math can show that diffSum is actually zero, so we don't even + * need to add the correction terms in the end. From the definition of diff, + * {{{ + * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) + * = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y_j} - \bar{y}) / \hat{y}) + * = 0 + * }}} + * + * As a result, the first derivative of the total objective function only depends on + * the training dataset, which can be easily computed in distributed fashion, and is + * sparse format friendly. + * {{{ + * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + * }}}, + * + * @param weights The weights/coefficients corresponding to the features. + * @param labelStd The standard deviation value of the label. + * @param labelMean The mean value of the label. + * @param fitIntercept Whether to fit an intercept term. + * @param featuresStd The standard deviation values of the features. + * @param featuresMean The mean values of the features. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + fitIntercept: Boolean, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + val len = weightsArray.length + while (i < len) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + } + + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. + */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + fitIntercept: Boolean, + standardization: Boolean, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, fitIntercept, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + val totalGradientArray = leastSquaresAggregator.gradient.toArray + + val regVal = if (effectiveL2regParam == 0.0) { + 0.0 + } else { + var sum = 0.0 + w.foreachActive { (index, value) => + // The following code will compute the loss of the regularization; also + // the gradient of the regularization, and add back to totalGradientArray. + sum += { + if (standardization) { + totalGradientArray(index) += effectiveL2regParam * value + value * value + } else { + if (featuresStd(index) != 0.0) { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + val temp = value / (featuresStd(index) * featuresStd(index)) + totalGradientArray(index) += effectiveL2regParam * temp + value * temp + } else { + 0.0 + } + } + } + } + 0.5 * effectiveL2regParam * sum + } + + (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala new file mode 100644 index 000000000000..2f36da371f57 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree.impl.RandomForest +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ + + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@Experimental +final class RandomForestRegressor(override val uid: String) + extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] + with RandomForestParams with TreeRegressorParams { + + def this() = this(Identifiable.randomUID("rfr")) + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) + val trees = + RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) + .map(_.asInstanceOf[DecisionTreeRegressionModel]) + val numFeatures = oldDataset.first().features.size + new RandomForestRegressionModel(trees, numFeatures) + } + + override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) +} + +@Experimental +object RandomForestRegressor { + /** Accessor for supported impurity settings: variance */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: Experimental :: + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + * @param numFeatures Number of features used by this model + */ +@Experimental +final class RandomForestRegressionModel private[ml] ( + override val uid: String, + private val _trees: Array[DecisionTreeRegressionModel], + val numFeatures: Int) + extends PredictionModel[Vector, RandomForestRegressionModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + + /** + * Construct a random forest regression model, with all trees weighted equally. + * @param trees Component trees + */ + private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = + this(Identifiable.randomUID("rfr"), trees, numFeatures) + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override protected def predict(features: Vector): Double = { + // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 + // Predict average of tree predictions. + // Ignore the weights since all are 1.0 for now. + _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees + } + + override def copy(extra: ParamMap): RandomForestRegressionModel = { + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) + } + + override def toString: String = { + s"RandomForestRegressionModel with $numTrees trees" + } + + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestRegressor, + categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) + } + new RandomForestRegressionModel(parent.uid, newTrees, -1) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala new file mode 100644 index 000000000000..c72ef2968032 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} + + +/** + * :: DeveloperApi :: + * + * Single-label regression + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam Learner Concrete Estimator type + * @tparam M Concrete Model type + */ +@DeveloperApi +private[spark] abstract class Regressor[ + FeaturesType, + Learner <: Regressor[FeaturesType, Learner, M], + M <: RegressionModel[FeaturesType, M]] + extends Predictor[FeaturesType, Learner, M] with PredictorParams { + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: DeveloperApi :: + * + * Model produced by a [[Regressor]]. + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam M Concrete Model type. + */ +@DeveloperApi +abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with PredictorParams { + + // TODO: defaultEvaluator (follow-up PR) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala new file mode 100644 index 000000000000..cd2493129390 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict, ImpurityStats} + +/** + * :: DeveloperApi :: + * Decision tree node interface. + */ +@DeveloperApi +sealed abstract class Node extends Serializable { + + // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree + // code into the new API and deprecate the old API. SPARK-3727 + + /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ + def prediction: Double + + /** Impurity measure at this node (for training data) */ + def impurity: Double + + /** + * Statistics aggregated from training data at this node, used to compute prediction, impurity, + * and probabilities. + * For classification, the array of class counts must be normalized to a probability distribution. + */ + private[ml] def impurityStats: ImpurityCalculator + + /** Recursive prediction helper method */ + private[ml] def predictImpl(features: Vector): LeafNode + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes. + */ + private[tree] def subtreeDepth: Int + + /** + * Create a copy of this node in the old Node format, recursively creating child nodes as needed. + * @param id Node ID using old format IDs + */ + private[ml] def toOld(id: Int): OldNode + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int +} + +private[ml] object Node { + + /** + * Create a new Node from the old Node format, recursively creating child nodes as needed. + */ + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + if (oldNode.isLeaf) { + // TODO: Once the implementation has been moved to this API, then include sufficient + // statistics here. + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } else { + val gain = if (oldNode.stats.nonEmpty) { + oldNode.stats.get.gain + } else { + 0.0 + } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } + } +} + +/** + * :: DeveloperApi :: + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +@DeveloperApi +final class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) extends Node { + + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" + + override private[ml] def predictImpl(features: Vector): LeafNode = this + + override private[tree] def numDescendants: Int = 0 + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"Predict: $prediction\n" + } + + override private[tree] def subtreeDepth: Int = 0 + + override private[ml] def toOld(id: Int): OldNode = { + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), + impurity, isLeaf = true, None, None, None, None) + } + + override private[ml] def maxSplitFeatureIndex(): Int = -1 +} + +/** + * :: DeveloperApi :: + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. + * Values < 0 indicate missing values; this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. + */ +@DeveloperApi +final class InternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) extends Node { + + override def toString: String = { + s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" + } + + override private[ml] def predictImpl(features: Vector): LeafNode = { + if (split.shouldGoLeft(features)) { + leftChild.predictImpl(features) + } else { + rightChild.predictImpl(features) + } + } + + override private[tree] def numDescendants: Int = { + 2 + leftChild.numDescendants + rightChild.numDescendants + } + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" + + leftChild.subtreeToString(indentFactor + 1) + + prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" + + rightChild.subtreeToString(indentFactor + 1) + } + + override private[tree] def subtreeDepth: Int = { + 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth) + } + + override private[ml] def toOld(id: Int): OldNode = { + assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + + " since the old API does not support deep trees.") + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + Some(rightChild.toOld(OldNode.rightChildIndex(id))), + Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, + new OldPredict(leftChild.prediction, prob = 0.0), + new OldPredict(rightChild.prediction, prob = 0.0)))) + } + + override private[ml] def maxSplitFeatureIndex(): Int = { + math.max(split.featureIndex, + math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) + } +} + +private object InternalNode { + + /** + * Helper method for [[Node.subtreeToString()]]. + * @param split Split to print + * @param left Indicates whether this is the part of the split going to the left, + * or that going to the right. + */ + private def splitToString(split: Split, left: Boolean): String = { + val featureStr = s"feature ${split.featureIndex}" + split match { + case contSplit: ContinuousSplit => + if (left) { + s"$featureStr <= ${contSplit.threshold}" + } else { + s"$featureStr > ${contSplit.threshold}" + } + case catSplit: CategoricalSplit => + val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}") + if (left) { + s"$featureStr in $categoriesStr" + } else { + s"$featureStr not in $categoriesStr" + } + } + } +} + +/** + * Version of a node used in learning. This uses vars so that we can modify nodes as we split the + * tree by adding children, etc. + * + * For now, we use node IDs. These will be kept internal since we hope to remove node IDs + * in the future, or at least change the indexing (so that we can support much deeper trees). + * + * This node can either be: + * - a leaf node, with leftChild, rightChild, split set to null, or + * - an internal node, with all values set + * + * @param id We currently use the same indexing as the old implementation in + * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. + * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, + * so that we do not need to consider splitting it further. + * @param stats Impurity statistics for this node. + */ +private[tree] class LearningNode( + var id: Int, + var leftChild: Option[LearningNode], + var rightChild: Option[LearningNode], + var split: Option[Split], + var isLeaf: Boolean, + var stats: ImpurityStats) extends Serializable { + + /** + * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. + */ + def toNode: Node = { + if (leftChild.nonEmpty) { + assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + } else { + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + } + + } + } + +} + +private[tree] object LearningNode { + + /** Create a node with some of its fields set. */ + def apply( + id: Int, + isLeaf: Boolean, + stats: ImpurityStats): LearningNode = { + new LearningNode(id, None, None, None, false, stats) + } + + /** Create an empty node with the given node index. Values must be set later on. */ + def emptyNode(nodeIndex: Int): LearningNode = { + new LearningNode(nodeIndex, None, None, None, false, null) + } + + // The below indexing methods were copied from spark.mllib.tree.model.Node + + /** + * Return the index of the left child of this node. + */ + def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 + + /** + * Return the index of the right child of this node. + */ + def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 + + /** + * Get the parent index of the given node, or 0 if it is the root. + */ + def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 + + /** + * Return the level of a tree which the given node is in. + */ + def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { + throw new IllegalArgumentException(s"0 is not a valid node index.") + } else { + java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) + } + + /** + * Returns true if this is a left child. + * Note: Returns false for the root. + */ + def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 + + /** + * Return the maximum number of nodes which can be in the given level of the tree. + * @param level Level of tree (0 = root). + */ + def maxNodesInLevel(level: Int): Int = 1 << level + + /** + * Return the index of the first node in the given level. + * @param level Level of tree (0 = root). + */ + def startIndexInLevel(level: Int): Int = 1 << level + + /** + * Traces down from a root node to get the node with the given node index. + * This assumes the node exists. + */ + def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = { + var tmpNode: LearningNode = rootNode + var levelsToGo = indexToLevel(nodeIndex) + while (levelsToGo > 0) { + if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { + tmpNode = tmpNode.leftChild.asInstanceOf[LearningNode] + } else { + tmpNode = tmpNode.rightChild.asInstanceOf[LearningNode] + } + levelsToGo -= 1 + } + tmpNode + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala new file mode 100644 index 000000000000..78199cc2df58 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} +import org.apache.spark.mllib.tree.model.{Split => OldSplit} + + +/** + * :: DeveloperApi :: + * Interface for a "Split," which specifies a test made at a decision tree node + * to choose the left or right path. + */ +@DeveloperApi +sealed trait Split extends Serializable { + + /** Index of feature which this split tests */ + def featureIndex: Int + + /** + * Return true (split to left) or false (split to right). + * @param features Vector of features (original values, not binned). + */ + private[ml] def shouldGoLeft(features: Vector): Boolean + + /** + * Return true (split to left) or false (split to right). + * @param binnedFeature Binned feature value. + * @param splits All splits for the given feature. + */ + private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean + + /** Convert to old Split format */ + private[tree] def toOld: OldSplit +} + +private[tree] object Split { + + def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { + oldSplit.featureType match { + case OldFeatureType.Categorical => + new CategoricalSplit(featureIndex = oldSplit.feature, + _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + case OldFeatureType.Continuous => + new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) + } + } +} + +/** + * :: DeveloperApi :: + * Split which tests a categorical feature. + * @param featureIndex Index of the feature to test + * @param _leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. + * @param numCategories Number of categories for this feature. + */ +@DeveloperApi +final class CategoricalSplit private[ml] ( + override val featureIndex: Int, + _leftCategories: Array[Double], + private val numCategories: Int) + extends Split { + + require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}") + + /** + * If true, then "categories" is the set of categories for splitting to the left, and vice versa. + */ + private val isLeft: Boolean = _leftCategories.length <= numCategories / 2 + + /** Set of categories determining the splitting rule, along with [[isLeft]]. */ + private val categories: Set[Double] = { + if (isLeft) { + _leftCategories.toSet + } else { + setComplement(_leftCategories.toSet) + } + } + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + if (isLeft) { + categories.contains(features(featureIndex)) + } else { + !categories.contains(features(featureIndex)) + } + } + + override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = { + if (isLeft) { + categories.contains(binnedFeature.toDouble) + } else { + !categories.contains(binnedFeature.toDouble) + } + } + + override def equals(o: Any): Boolean = { + o match { + case other: CategoricalSplit => featureIndex == other.featureIndex && + isLeft == other.isLeft && categories == other.categories + case _ => false + } + } + + override private[tree] def toOld: OldSplit = { + val oldCats = if (isLeft) { + categories + } else { + setComplement(categories) + } + OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList) + } + + /** Get sorted categories which split to the left */ + def leftCategories: Array[Double] = { + val cats = if (isLeft) categories else setComplement(categories) + cats.toArray.sorted + } + + /** Get sorted categories which split to the right */ + def rightCategories: Array[Double] = { + val cats = if (isLeft) setComplement(categories) else categories + cats.toArray.sorted + } + + /** [0, numCategories) \ cats */ + private def setComplement(cats: Set[Double]): Set[Double] = { + Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet + } +} + +/** + * :: DeveloperApi :: + * Split which tests a continuous feature. + * @param featureIndex Index of the feature to test + * @param threshold If the feature value is <= this threshold, then the split goes left. + * Otherwise, it goes right. + */ +@DeveloperApi +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + features(featureIndex) <= threshold + } + + override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = { + if (binnedFeature == splits.length) { + // > last split, so split right + false + } else { + val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold + featureValueUpperBound <= threshold + } + } + + override def equals(o: Any): Boolean = { + o match { + case other: ContinuousSplit => + featureIndex == other.featureIndex && threshold == other.threshold + case _ => + false + } + } + + override private[tree] def toOld: OldSplit = { + OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala new file mode 100644 index 000000000000..488e8e4fb5dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.tree.{LearningNode, Split} +import org.apache.spark.mllib.tree.impl.BaggedPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) { + + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeature Binned feature value. + * @param splits Split information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = { + if (split.shouldGoLeft(binnedFeature, splits)) { + LearningNode.leftChildIndex(nodeIndex) + } else { + LearningNode.rightChildIndex(nodeIndex) + } + } +} + +/** + * Each TreePoint belongs to a particular node per tree. + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index + * in each tree. Initially, values should all be 1 for root node. + * The nodeIdsForInstances RDD needs to be updated at each iteration. + * @param nodeIdsForInstances The initial values in the cache + * (should be an Array of all 1's (meaning the root nodes)). + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + */ +private[spark] class NodeIdCache( + var nodeIdsForInstances: RDD[Array[Int]], + val checkpointInterval: Int) extends Logging { + + // Keep a reference to a previous node Ids for instances. + // Because we will keep on re-persisting updated node Ids, + // we want to unpersist the previous RDD. + private var prevNodeIdsForInstances: RDD[Array[Int]] = null + + // To keep track of the past checkpointed RDDs. + private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() + private var rddUpdateCount = 0 + + // Indicates whether we can checkpoint + private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty + + // FileSystem instance for deleting checkpoints as needed + private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration) + + /** + * Update the node index values in the cache. + * This updates the RDD and its lineage. + * TODO: Passing bin information to executors seems unnecessary and costly. + * @param data The RDD of training rows. + * @param nodeIdUpdaters A map of node index updaters. + * The key is the indices of nodes that we want to update. + * @param splits Split information needed to find child node indices. + */ + def updateNodeIndices( + data: RDD[BaggedPoint[TreePoint]], + nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], + splits: Array[Array[Split]]): Unit = { + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } + + prevNodeIdsForInstances = nodeIdsForInstances + nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) => + var treeId = 0 + while (treeId < nodeIdUpdaters.length) { + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null) + if (nodeIdUpdater != null) { + val featureIndex = nodeIdUpdater.split.featureIndex + val newNodeIndex = nodeIdUpdater.updateNodeIndex( + binnedFeature = point.datum.binnedFeatures(featureIndex), + splits = splits(featureIndex)) + ids(treeId) = newNodeIndex + } + treeId += 1 + } + ids + } + + // Keep on persisting new ones. + nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) + rddUpdateCount += 1 + + // Handle checkpointing if the directory is not None. + if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) { + // Let's see if we can delete previous checkpoints. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // We can delete the oldest checkpoint iff + // the next checkpoint actually exists in the file system. + if (checkpointQueue(1).getCheckpointFile.isDefined) { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we'll manually delete it here. + try { + fs.delete(new Path(old.getCheckpointFile.get), true) + } catch { + case e: IOException => + logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + + s" file: ${old.getCheckpointFile.get}") + } + } else { + canDelete = false + } + } + + nodeIdsForInstances.checkpoint() + checkpointQueue.enqueue(nodeIdsForInstances) + } + } + + /** + * Call this after training is finished to delete any remaining checkpoints. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + val old = checkpointQueue.dequeue() + if (old.getCheckpointFile.isDefined) { + try { + fs.delete(new Path(old.getCheckpointFile.get), true) + } catch { + case e: IOException => + logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + + s" file: ${old.getCheckpointFile.get}") + } + } + } + } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } +} + +@DeveloperApi +private[spark] object NodeIdCache { + /** + * Initialize the node Id cache with initial node Id values. + * @param data The RDD of training rows. + * @param numTrees The number of trees that we want to create cache for. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + * @param initVal The initial values in the cache. + * @return A node Id cache containing an RDD of initial root node Indices. + */ + def init( + data: RDD[BaggedPoint[TreePoint]], + numTrees: Int, + checkpointInterval: Int, + initVal: Int = 1): NodeIdCache = { + new NodeIdCache( + data.map(_ => Array.fill[Int](numTrees)(initVal)), + checkpointInterval) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala new file mode 100644 index 000000000000..4ac51a475474 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -0,0 +1,1208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.Logging +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, + TimeTracker} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} + + +private[ml] object RandomForest extends Logging { + + /** + * Train a random forest. + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return an unweighted set of trees + */ + def run( + input: RDD[LabeledPoint], + strategy: OldStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Long, + parentUID: Option[String] = None): Array[DecisionTreeModel] = { + + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + logDebug("algo = " + strategy.algo) + logDebug("numTrees = " + numTrees) + logDebug("seed = " + seed) + logDebug("maxBins = " + metadata.maxBins) + logDebug("featureSubsetStrategy = " + featureSubsetStrategy) + logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + timer.start("findSplitsBins") + val splits = findSplits(retaggedInput, metadata) + timer.stop("findSplitsBins") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) + + val withReplacement = numTrees > 1 + + val baggedInput = BaggedPoint + .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) + .persist(StorageLevel.MEMORY_AND_DISK) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val maxMemoryPerNode = { + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. + Some(metadata.numBins.zipWithIndex.sortBy(- _._1) + .take(metadata.numFeaturesPerNode).map(_._2)) + } else { + None + } + RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + } + require(maxMemoryPerNode <= maxMemoryUsage, + s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + + " which is too small for the given features." + + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") + + timer.stop("init") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + + // FIFO queue of nodes to train: (treeIndex, node) + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + + val rng = new Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. + val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + + while (nodeQueue.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + val (nodesForGroup, treeToNodeToIndexInfo) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.nonEmpty, + s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + + // Choose node splits, and enqueue new nodes as needed. + timer.start("findBestSplits") + RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, + treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) + timer.stop("findBestSplits") + } + + baggedInput.unpersist() + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e: IOException => + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + } + } + + parentUID match { + case Some(uid) => + if (strategy.algo == OldAlgo.Classification) { + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + } + } else { + topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) + } + case None => + if (strategy.algo == OldAlgo.Classification) { + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + } + } else { + topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) + } + } + } + + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. + * + * @param node Node in tree from which to classify the given data point. + * @param binnedFeatures Binned feature vector for data point. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * group of nodes on one call to [[findBestSplits()]]. + */ + private def predictNodeIndex( + node: LearningNode, + binnedFeatures: Array[Int], + splits: Array[Array[Split]]): Int = { + if (node.isLeaf || node.split.isEmpty) { + node.id + } else { + val split = node.split.get + val featureIndex = split.featureIndex + val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) + if (node.leftChild.isEmpty) { + // Not yet split. Return index from next layer of nodes to train + if (splitLeft) { + LearningNode.leftChildIndex(node.id) + } else { + LearningNode.rightChildIndex(node.id) + } + } else { + if (splitLeft) { + predictNodeIndex(node.leftChild.get, binnedFeatures, splits) + } else { + predictNodeIndex(node.rightChild.get, binnedFeatures, splits) + } + } + } + } + + /** + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits possible splits indexed (numFeatures)(numSplits) + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def mixedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + splits: Array[Array[Split]], + unorderedFeatures: Set[Int], + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.length + } else { + // Use all features + agg.metadata.numFeatures + } + // Iterate over features. + var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } + if (unorderedFeatures.contains(featureIndex)) { + // Unordered feature + val featureValue = treePoint.binnedFeatures(featureIndex) + val (leftNodeFeatureOffset, rightNodeFeatureOffset) = + agg.getLeftRightFeatureOffsets(featureIndexIdx) + // Update the left or right bin for each split. + val numSplits = agg.metadata.numSplits(featureIndex) + val featureSplits = splits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + } else { + agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + } + splitIndex += 1 + } + } else { + // Ordered feature + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + } + featureIndexIdx += 1 + } + } + + /** + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def orderedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val label = treePoint.label + + // Iterate over features. + if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.length) { + val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.update(featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } + } + } + + /** + * Given a group of nodes, this finds the best split for each node. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param metadata Learning and dataset metadata + * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. + */ + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePoint]], + metadata: DecisionTreeMetadata, + topNodes: Array[LearningNode], + nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], + splits: Array[Array[Split]], + nodeQueue: mutable.Queue[(Int, LearningNode)], + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None): Unit = { + + /* + * The high-level descriptions of the best split optimizations are noted here. + * + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // numNodes: Number of nodes in this group + val numNodes = nodesForGroup.values.map(_.length).sum + logDebug("numNodes = " + numNodes) + logDebug("numFeatures = " + metadata.numFeatures) + logDebug("numClasses = " + metadata.numClasses) + logDebug("isMulticlass = " + metadata.isMulticlass) + logDebug("isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) + logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, instanceWeight, featuresForNode) + } + } + } + + /** + * Performs a sequential aggregation over a partition. + * + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return agg + */ + def binSeqOp( + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = + predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group. + */ + def getNodeToFeatures( + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + if (!metadata.subsamplingFeatures) { + None + } else { + val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() + treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => + nodeIdToNodeInfo.values.foreach { nodeIndexInfo => + assert(nodeIndexInfo.featureSubset.isDefined) + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + } + } + Some(mutableNodeToFeatures.toMap) + } + } + + // array of nodes to train indexed by node index in group + val nodes = new Array[LearningNode](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + + // Calculate best splits for all nodes in the group + timer.start("chooseSplits") + + // In each partition, iterate all instances and compute aggregate stats for each node, + // yield an (nodeIndex, nodeAggregateStats) pair for each node. + // After a `reduceByKey` operation, + // stats of a node will be shuffled to a particular partition and be combined together, + // then best splits for nodes are found there. + // Finally, only best Splits for nodes are collected to driver to construct decision tree. + val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) + val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) + + val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { + input.mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } + + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + + // find best split for each node + val (split: Split, stats: ImpurityStats) = + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + }.collectAsMap() + + timer.stop("chooseSplits") + + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. + val isLeaf = + (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) + + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightChild.get)) + } + + logDebug("leftChildIndex = " + node.leftChild.get.id + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + node.rightChild.get.id + + ", impurity = " + stats.rightImpurity) + } + } + } + + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits) + } + } + + /** + * Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates. + * @param stats the recycle impurity statistics for this feature's all splits, + * only 'impurity' and 'impurityCalculator' are valid between each iteration + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) + */ + private def calculateImpurityStats( + stats: ImpurityStats, + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata): ImpurityStats = { + + val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { + leftImpurityCalculator.copy.add(rightImpurityCalculator) + } else { + stats.impurityCalculator + } + + val impurity: Double = if (stats == null) { + parentImpurityCalculator.calculate() + } else { + stats.impurity + } + + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + val totalCount = leftCount + rightCount + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() + + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + // if information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) + } + + /** + * Find the best split for a node. + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, information gain, prediction at node) + */ + private def binsToBestSplit( + binAggregates: DTStatsAggregator, + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]], + node: LearningNode): (Split, ImpurityStats) = { + + // Calculate InformationGain and ImpurityStats if current node is top node + val level = LearningNode.indexToLevel(node.id) + var gainAndImpurityStats: ImpurityStats = if (level ==0) { + null + } else { + node.stats + } + + // For each (feature, split), calculate the gain, and select the best (feature, split). + val (bestSplit, bestSplitStats) = + Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = + binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numCategories = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = if (binAggregates.metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // the bins are ordered by the centroid of their corresponding labels. + Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.predict + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } + + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) + } + }.maxBy(_._2.gain) + + (bestSplit, bestSplitStats) + } + + /** + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) "unordered features" + * For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * (b) "ordered features" + * For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one bin per category. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param metadata Learning and dataset metadata + * @return A tuple of (splits, bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numSplits). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). + */ + protected[tree] def findSplits( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata): Array[Array[Split]] = { + + logDebug("isMulticlass = " + metadata.isMulticlass) + + val numFeatures = metadata.numFeatures + + // Sample the input only if there are continuous features. + val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) + val sampledInput = if (hasContinuousFeatures) { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + logDebug("fraction of data used for calculating quantiles = " + fraction) + input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect() + } else { + new Array[LabeledPoint](0) + } + + val splits = new Array[Array[Split]](numFeatures) + + // Find all splits. + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isContinuous(featureIndex)) { + val featureSamples = sampledInput.map(_.features(featureIndex)) + val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex) + + val numSplits = featureSplits.length + logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits") + splits(featureIndex) = new Array[Split](numSplits) + + var splitIndex = 0 + while (splitIndex < numSplits) { + val threshold = featureSplits(splitIndex) + splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold) + splitIndex += 1 + } + } else { + // Categorical feature + if (metadata.isUnordered(featureIndex)) { + val numSplits = metadata.numSplits(featureIndex) + val featureArity = metadata.featureArity(featureIndex) + // TODO: Use an implicit representation mapping each category to a subset of indices. + // I.e., track indices such that we can calculate the set of bins for which + // feature value x splits to the left. + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + splits(featureIndex) = new Array[Split](numSplits) + var splitIndex = 0 + while (splitIndex < numSplits) { + val categories: List[Double] = + extractMultiClassCategories(splitIndex + 1, featureArity) + splits(featureIndex)(splitIndex) = + new CategoricalSplit(featureIndex, categories.toArray, featureArity) + splitIndex += 1 + } + } else { + // Ordered features + // Bins correspond to feature values, so we do not need to compute splits or bins + // beforehand. Splits are constructed as needed during training. + splits(featureIndex) = new Array[Split](0) + } + } + featureIndex += 1 + } + splits + } + + /** + * Nested method to extract list of eligible categories given an index. It extracts the + * position of ones in a binary representation of the input. If binary + * representation of an number is 01101 (13), the output list should (3.0, 2.0, + * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. + */ + private[tree] def extractMultiClassCategories( + input: Int, + maxFeatureValue: Int): List[Double] = { + var categories = List[Double]() + var j = 0 + var bitShiftedInput = input + while (j < maxFeatureValue) { + if (bitShiftedInput % 2 != 0) { + // updating the list of categories. + categories = j.toDouble :: categories + } + // Right shift by one + bitShiftedInput = bitShiftedInput >> 1 + j += 1 + } + categories + } + + /** + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. + * @param featureSamples feature values of each sample + * @param metadata decision tree metadata + * NOTE: `metadata.numbins` will be changed accordingly + * if there are not enough splits to be found + * @param featureIndex feature index to find splits + * @return array of splits + */ + private[tree] def findSplitsForContinuousFeature( + featureSamples: Array[Double], + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Double] = { + require(metadata.isContinuous(featureIndex), + "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") + + val splits = { + val numSplits = metadata.numSplits(featureIndex) + + // get count for each distinct value + val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) => + m + ((x, m.getOrElse(x, 0) + 1)) + } + // sort distinct values + val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray + + // if possible splits is not enough or just enough, just return all possible splits + val possibleSplits = valueCounts.length + if (possibleSplits <= numSplits) { + valueCounts.map(_._1) + } else { + // stride between splits + val stride: Double = featureSamples.length.toDouble / (numSplits + 1) + logDebug("stride = " + stride) + + // iterate `valueCount` to find splits + val splitsBuilder = mutable.ArrayBuilder.make[Double] + var index = 1 + // currentCount: sum of counts of values that have been visited + var currentCount = valueCounts(0)._2 + // targetCount: target value for `currentCount`. + // If `currentCount` is closest value to `targetCount`, + // then current value is a split threshold. + // After finding a split threshold, `targetCount` is added by stride. + var targetCount = stride + while (index < valueCounts.length) { + val previousCount = currentCount + currentCount += valueCounts(index)._2 + val previousGap = math.abs(previousCount - targetCount) + val currentGap = math.abs(currentCount - targetCount) + // If adding count of current value to currentCount + // makes the gap between currentCount and targetCount smaller, + // previous value is a split threshold. + if (previousGap < currentGap) { + splitsBuilder += valueCounts(index - 1)._1 + targetCount += stride + } + index += 1 + } + + splitsBuilder.result() + } + } + + // TODO: Do not fail; just ignore the useless feature. + assert(splits.length > 0, + s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + + " Please remove this feature and then try again.") + // set number of splits accordingly + metadata.setNumSplits(featureIndex, splits.length) + + splits + } + + private[tree] class NodeIndexInfo( + val nodeIndexInGroup: Int, + val featureSubset: Option[Array[Int]]) extends Serializable + + /** + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeQueue Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). + * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. + * The feature indices are None if not subsampling features. + */ + private[tree] def selectNodesToSplit( + nodeQueue: mutable.Queue[(Int, LearningNode)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { + val (treeIndex, node) = nodeQueue.head + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage) { + nodeQueue.dequeue() + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += + node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + } + numNodesInGroup += 1 + memUsage += nodeMemUsage + } + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[LearningNode]] = + mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } + + /** + * Given a Random Forest model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + * + * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for + * independently trained trees. + * @param trees Unweighted forest of trees + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + val totalImportances = new OpenHashMap[Int, Double]() + trees.foreach { tree => + // Aggregate feature importance vector for this tree + val importances = new OpenHashMap[Int, Double]() + computeFeatureImportance(tree.rootNode, importances) + // Normalize importance vector for this tree, and add it to total. + // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? + val treeNorm = importances.map(_._2).sum + if (treeNorm != 0) { + importances.foreach { case (idx, impt) => + val normImpt = impt / treeNorm + totalImportances.changeValue(idx, normImpt, _ + normImpt) + } + } + } + // Normalize importances + normalizeMapValues(totalImportances) + // Construct vector + val d = if (numFeatures != -1) { + numFeatures + } else { + // Find max feature index used in trees + val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max + maxFeatureIndex + 1 + } + if (d == 0) { + assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" + + s" importance: No splits in forest, but some non-zero importances.") + } + val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip + Vectors.sparse(d, indices.toArray, values.toArray) + } + + /** + * Recursive method for computing feature importances for one tree. + * This walks down the tree, adding to the importance of 1 feature at each node. + * @param node Current node in recursion + * @param importances Aggregate feature importances, modified by this method + */ + private[impl] def computeFeatureImportance( + node: Node, + importances: OpenHashMap[Int, Double]): Unit = { + node match { + case n: InternalNode => + val feature = n.split.featureIndex + val scaledGain = n.gain * n.impurityStats.count + importances.changeValue(feature, scaledGain, _ + scaledGain) + computeFeatureImportance(n.leftChild, importances) + computeFeatureImportance(n.rightChild, importances) + case n: LeafNode => + // do nothing + } + } + + /** + * Normalize the values of this map to sum to 1, in place. + * If all values are 0, this method does nothing. + * @param map Map with non-negative values. + */ + private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { + val total = map.map(_._2).sum + if (total != 0) { + val keys = map.iterator.map(_._1).toArray + keys.foreach { key => map.changeValue(key, 0.0, _ / total) } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala new file mode 100644 index 000000000000..9fa27e5e1f72 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.{ContinuousSplit, Split} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata +import org.apache.spark.rdd.RDD + + +/** + * Internal representation of LabeledPoint for DecisionTree. + * This bins feature values based on a subsampled of data as follows: + * (a) Continuous features are binned into ranges. + * (b) Unordered categorical features are binned based on subsets of feature values. + * "Unordered categorical features" are categorical features with low arity used in + * multiclass classification. + * (c) Ordered categorical features are binned based on feature values. + * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param binnedFeatures Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. + */ +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) + extends Serializable { +} + +private[spark] object TreePoint { + + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param splits Splits for features, of size (numFeatures, numSplits). + * @param metadata Learning and dataset metadata + * @return TreePoint dataset representation + */ + def convertToTreeRDD( + input: RDD[LabeledPoint], + splits: Array[Array[Split]], + metadata: DecisionTreeMetadata): RDD[TreePoint] = { + // Construct arrays for featureArity for efficiency in the inner loop. + val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { + featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) + featureIndex += 1 + } + val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) => + if (arity == 0) { + splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold) + } else { + Array.empty[Double] + } + } + input.map { x => + TreePoint.labeledPointToTreePoint(x, thresholds, featureArity) + } + } + + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param thresholds For each feature, split thresholds for continuous features, + * empty for categorical features. + * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories + * for categorical features. + */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + thresholds: Array[Array[Double]], + featureArity: Array[Int]): TreePoint = { + val numFeatures = labeledPoint.features.size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + arr(featureIndex) = + findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) + featureIndex += 1 + } + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find discretized value for one (labeledPoint, feature). + * + * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old + * (mllib) tree API. We want to maintain the same behavior as the old tree API. + * + * @param featureArity 0 for continuous features; number of categories for categorical features. + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + featureArity: Int, + thresholds: Array[Double]): Int = { + val featureValue = labeledPoint.features(featureIndex) + + if (featureArity == 0) { + val idx = java.util.Arrays.binarySearch(thresholds, featureValue) + if (idx >= 0) { + idx + } else { + -idx - 1 + } + } else { + // Categorical feature bins are indexed by feature values. + if (featureValue < 0 || featureValue >= featureArity) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + featureValue.toInt + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala new file mode 100644 index 000000000000..b77191156f68 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.{Vectors, Vector} + +/** + * Abstraction for Decision Tree models. + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + */ +private[ml] trait DecisionTreeModel { + + /** Root of the decision tree */ + def rootNode: Node + + /** Number of nodes in tree, including leaf nodes. */ + def numNodes: Int = { + 1 + rootNode.numDescendants + } + + /** + * Depth of the tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + lazy val depth: Int = { + rootNode.subtreeDepth + } + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"DecisionTreeModel of depth $depth with $numNodes nodes" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + rootNode.subtreeToString(2) + } + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() +} + +/** + * Abstraction for models which are ensembles of decision trees + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + */ +private[ml] trait TreeEnsembleModel { + + // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of + // DecisionTreeModel. + + /** Trees in this ensemble. Warning: These have null parent Estimators. */ + def trees: Array[DecisionTreeModel] + + /** Weights for each tree, zippable with [[trees]] */ + def treeWeights: Array[Double] + + /** Weights used by the python wrappers. */ + // Note: An array cannot be returned directly due to serialization problems. + private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights) + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"TreeEnsembleModel with $numTrees trees" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) => + s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + + /** Total number of nodes, summed over all trees in the ensemble. */ + lazy val totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala new file mode 100644 index 000000000000..dbd8d31571d2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.ml.classification.ClassifierParams +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} + +/** + * Parameters for Decision Tree-based algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +private[ml] trait DecisionTreeParams extends PredictorParams { + + /** + * Maximum depth of the tree (>= 0). + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default = 5) + * @group param + */ + final val maxDepth: IntParam = + new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", + ParamValidators.gtEq(0)) + + /** + * Maximum number of bins used for discretizing continuous features and for choosing how to split + * on features at each node. More bins give higher granularity. + * Must be >= 2 and >= number of categories in any categorical feature. + * (default = 32) + * @group param + */ + final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature.", ParamValidators.gtEq(2)) + + /** + * Minimum number of instances each child must have after split. + * If a split causes the left or right child to have fewer than minInstancesPerNode, + * the split will be discarded as invalid. + * Should be >= 1. + * (default = 1) + * @group param + */ + final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + + " number of instances each child must have after split. If a split causes the left or right" + + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + + " Should be >= 1.", ParamValidators.gtEq(1)) + + /** + * Minimum information gain for a split to be considered at a tree node. + * (default = 0.0) + * @group param + */ + final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", + "Minimum information gain for a split to be considered at a tree node.") + + /** + * Maximum memory in MB allocated to histogram aggregation. + * (default = 256 MB) + * @group expertParam + */ + final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", + "Maximum memory in MB allocated to histogram aggregation.", + ParamValidators.gtEq(0)) + + /** + * If false, the algorithm will pass trees to executors to match instances with nodes. + * If true, the algorithm will cache node IDs for each instance. + * Caching can speed up training of deeper trees. + * (default = false) + * @group expertParam + */ + final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + + " algorithm will pass trees to executors to match instances with nodes. If true, the" + + " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + + " trees.") + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertParam + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + + " checkpoint directory is set in the SparkContext. Must be >= 1.", + ParamValidators.gtEq(1)) + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + + /** @group setParam */ + def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group getParam */ + final def getMaxDepth: Int = $(maxDepth) + + /** @group setParam */ + def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group getParam */ + final def getMaxBins: Int = $(maxBins) + + /** @group setParam */ + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group getParam */ + final def getMinInstancesPerNode: Int = $(minInstancesPerNode) + + /** @group setParam */ + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group getParam */ + final def getMinInfoGain: Double = $(minInfoGain) + + /** @group expertSetParam */ + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertGetParam */ + final def getMaxMemoryInMB: Int = $(maxMemoryInMB) + + /** @group expertSetParam */ + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** @group expertGetParam */ + final def getCacheNodeIds: Boolean = $(cacheNodeIds) + + /** @group expertSetParam */ + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group expertGetParam */ + final def getCheckpointInterval: Int = $(checkpointInterval) + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OldStrategy = { + val strategy = OldStrategy.defaultStrategy(oldAlgo) + strategy.impurity = oldImpurity + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = subsamplingRate + strategy + } +} + +/** + * Parameters for Decision Tree-based classification algorithms. + */ +private[ml] trait TreeClassifierParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "entropy" and "gini". + * (default = gini) + * @group param + */ + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + + setDefault(impurity -> "gini") + + /** @group setParam */ + def setImpurity(value: String): this.type = set(impurity, value) + + /** @group getParam */ + final def getImpurity: String = $(impurity).toLowerCase + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "entropy" => OldEntropy + case "gini" => OldGini + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeClassifierParams was given unrecognized impurity: $impurity.") + } + } +} + +private[ml] object TreeClassifierParams { + // These options should be lowercase. + final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) +} + +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) + + setDefault(impurity -> "variance") + + /** @group setParam */ + def setImpurity(value: String): this.type = set(impurity, value) + + /** @group getParam */ + final def getImpurity: String = $(impurity).toLowerCase + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeRegressorParams was given unrecognized impurity: $impurity") + } + } +} + +private[ml] object TreeRegressorParams { + // These options should be lowercase. + final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} + +/** + * Parameters for Decision Tree-based ensemble algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { + + /** + * Fraction of the training data used for learning each decision tree, in range (0, 1]. + * (default = 1.0) + * @group param + */ + final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", + "Fraction of the training data used for learning each decision tree, in range (0, 1].", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + setDefault(subsamplingRate -> 1.0) + + /** @group setParam */ + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group getParam */ + final def getSubsamplingRate: Double = $(subsamplingRate) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and seed. + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + } +} + +/** + * Parameters for Random Forest algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + /** + * The number of features to consider for splits at each tree node. + * Supported options: + * - "auto": Choose automatically for task: + * If numTrees == 1, set to "all." + * If numTrees > 1 (forest), set to "sqrt" for classification and + * to "onethird" for regression. + * - "all": use all features + * - "onethird": use 1/3 of the features + * - "sqrt": use sqrt(number of features) + * - "log2": use log2(number of features) + * (default = "auto") + * + * These various settings are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @group param + */ + final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node." + + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + (value: String) => + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) + + setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + + /** @group setParam */ + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) + + /** @group setParam */ + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + + /** @group getParam */ + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase +} + +private[ml] object RandomForestParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) +} + +/** + * Parameters for Gradient-Boosted Tree algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { + + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + /* TODO: Add this doc when we add this param. SPARK-7132 + * Threshold for stopping early when runWithValidation is used. + * If the error rate on the validation input changes by less than the validationTol, + * then learning will stop early (before [[numIterations]]). + * This parameter is ignored when run is used. + * (default = 1e-5) + * @group param + */ + // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") + // validationTol -> 1e-5 + + setDefault(maxIter -> 20, stepSize -> 0.1) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group getParam */ + final def getStepSize: Double = $(stepSize) + + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ + private[ml] def getOldBoostingStrategy( + categoricalFeatures: Map[Int, Int], + oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + // NOTE: The old API does not support "seed" so we ignore it. + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + } + + /** Get old Gradient Boosting Loss type */ + private[ml] def getOldLossType: OldLoss +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5d51c5134666..0679bfd0f3ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ -import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -30,98 +32,136 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends Params { - /** param for the estimator to be cross-validated */ - val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") - def getEstimator: Estimator[_] = get(estimator) - - /** param for estimator param maps */ - val estimatorParamMaps: Param[Array[ParamMap]] = - new Param(this, "estimatorParamMaps", "param maps for the estimator") - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) - - /** param for the evaluator for selection */ - val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") - def getEvaluator: Evaluator = get(evaluator) - - /** param for number of folds for cross validation */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) - def getNumFolds: Int = get(numFolds) +private[ml] trait CrossValidatorParams extends ValidatorParams { + /** + * Param for number of folds for cross validation. Must be >= 2. + * Default: 3 + * @group param + */ + val numFolds: IntParam = new IntParam(this, "numFolds", + "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2)) + + /** @group getParam */ + def getNumFolds: Int = $(numFolds) + + setDefault(numFolds -> 3) } /** - * :: AlphaComponent :: + * :: Experimental :: * K-fold cross validation. */ -@AlphaComponent -class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { +@Experimental +class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] + with CrossValidatorParams with Logging { + + def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS + /** @group setParam */ def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema - transformSchema(dataset.schema, paramMap, logging = true) + transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext - val est = map(estimator) - val eval = map(evaluator) - val epm = map(estimatorParamMaps) - val numModels = epm.size - val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) + val est = $(estimator) + val eval = $(evaluator) + val epm = $(estimatorParamMaps) + val numModels = epm.length + val metrics = new Array[Double](epm.length) + val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => - val trainingDataset = sqlCtx.applySchema(training, schema).cache() - val validationDataset = sqlCtx.applySchema(validation, schema).cache() + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { - val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) + // TODO: duplicate evaluator to take extra params from input + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric i += 1 } + validationDataset.unpersist() } - f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) + f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - val cvModel = new CrossValidatorModel(this, map, bestModel) - Params.inheritValues(map, this, cvModel) - cvModel + copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + $(estimator).transformSchema(schema) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - map(estimator).transformSchema(schema, paramMap) + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } + + override def copy(extra: ParamMap): CrossValidator = { + val copied = defaultCopy(extra).asInstanceOf[CrossValidator] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied } } /** - * :: AlphaComponent :: + * :: Experimental :: * Model from k-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidatorModel private[ml] ( - override val parent: CrossValidator, - override val fittingParamMap: ParamMap, - val bestModel: Model[_]) + override val uid: String, + val bestModel: Model[_], + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - bestModel.transform(dataset, paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + bestModel.transform(dataset) + } + + override def transformSchema(schema: StructType): StructType = { + bestModel.transformSchema(schema) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - bestModel.transformSchema(schema, paramMap) + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + avgMetrics.clone()) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index dafe73d82c00..98a8f0330ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ /** - * :: AlphaComponent :: + * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ -@AlphaComponent +@Experimental class ParamGridBuilder { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala new file mode 100644 index 000000000000..73a14b831015 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType + +/** + * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. + */ +private[ml] trait TrainValidationSplitParams extends ValidatorParams { + /** + * Param for ratio between train and validation data. Must be between 0 and 1. + * Default: 0.75 + * @group param + */ + val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio", + "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1)) + + /** @group getParam */ + def getTrainRatio: Double = $(trainRatio) + + setDefault(trainRatio -> 0.75) +} + +/** + * :: Experimental :: + * Validation for hyper-parameter tuning. + * Randomly splits the input dataset into train and validation sets, + * and uses evaluation metric on the validation set to select the best model. + * Similar to [[CrossValidator]], but only splits the set once. + */ +@Experimental +class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] + with TrainValidationSplitParams with Logging { + + def this() = this(Identifiable.randomUID("tvs")) + + /** @group setParam */ + def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ + def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ + def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ + def setTrainRatio(value: Double): this.type = set(trainRatio, value) + + override def fit(dataset: DataFrame): TrainValidationSplitModel = { + val schema = dataset.schema + transformSchema(schema, logging = true) + val sqlCtx = dataset.sqlContext + val est = $(estimator) + val eval = $(evaluator) + val epm = $(estimatorParamMaps) + val numModels = epm.length + val metrics = new Array[Double](epm.length) + + val Array(training, validation) = + dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() + + // multi-model training + logDebug(s"Train split with multiple sets of parameters.") + val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() + var i = 0 + while (i < numModels) { + // TODO: duplicate evaluator to take extra params from input + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) + logDebug(s"Got metric $metric for model trained with ${epm(i)}.") + metrics(i) += metric + i += 1 + } + validationDataset.unpersist() + + logInfo(s"Train validation split metrics: ${metrics.toSeq}") + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) + logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + logInfo(s"Best train validation split metric: $bestMetric.") + val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + $(estimator).transformSchema(schema) + } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } + + override def copy(extra: ParamMap): TrainValidationSplit = { + val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied + } +} + +/** + * :: Experimental :: + * Model from train validation split. + * + * @param uid Id. + * @param bestModel Estimator determined best model. + * @param validationMetrics Evaluated validation metrics. + */ +@Experimental +class TrainValidationSplitModel private[ml] ( + override val uid: String, + val bestModel: Model[_], + val validationMetrics: Array[Double]) + extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + + override def validateParams(): Unit = { + bestModel.validateParams() + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + bestModel.transform(dataset) + } + + override def transformSchema(schema: StructType): StructType = { + bestModel.transformSchema(schema) + } + + override def copy(extra: ParamMap): TrainValidationSplitModel = { + val copied = new TrainValidationSplitModel ( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + validationMetrics.clone()) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala new file mode 100644 index 000000000000..8897ab0825ac --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param.{ParamMap, Param, Params} + +/** + * :: DeveloperApi :: + * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. + */ +@DeveloperApi +private[ml] trait ValidatorParams extends Params { + + /** + * param for the estimator to be validated + * @group param + */ + val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + + /** @group getParam */ + def getEstimator: Estimator[_] = $(estimator) + + /** + * param for estimator param maps + * @group param + */ + val estimatorParamMaps: Param[Array[ParamMap]] = + new Param(this, "estimatorParamMaps", "param maps for the estimator") + + /** @group getParam */ + def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) + + /** + * param for the evaluator used to select hyper-parameters that maximize the validated metric + * @group param + */ + val evaluator: Param[Evaluator] = new Param(this, "evaluator", + "evaluator used to select hyper-parameters that maximize the validated metric") + + /** @group getParam */ + def getEvaluator: Evaluator = $(evaluator) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala new file mode 100644 index 000000000000..bd213e7362e9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.util.UUID + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * + * Trait for an object with an immutable unique ID that identifies itself and its derivatives. + * + * WARNING: There have not yet been final discussions on this API, so it may be broken in future + * releases. + */ +@DeveloperApi +trait Identifiable { + + /** + * An immutable unique ID for the object and its derivatives. + */ + val uid: String + + override def toString: String = uid +} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Identifiable { + + /** + * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. + */ + def randomUID(prefix: String): String = { + prefix + "_" + UUID.randomUUID().toString.takeRight(12) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala new file mode 100644 index 000000000000..fcb517b5f735 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.immutable.HashMap + +import org.apache.spark.ml.attribute._ +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types.StructField + + +/** + * Helper utilities for tree-based algorithms + */ +private[spark] object MetadataUtils { + + /** + * Examine a schema to identify the number of classes in a label column. + * Returns None if the number of labels is not specified, or if the label column is continuous. + */ + def getNumClasses(labelSchema: StructField): Option[Int] = { + Attribute.fromStructField(labelSchema) match { + case binAttr: BinaryAttribute => Some(2) + case nomAttr: NominalAttribute => nomAttr.getNumValues + case _: NumericAttribute | UnresolvedAttribute => None + } + } + + /** + * Examine a schema to identify categorical (Binary and Nominal) features. + * + * @param featuresSchema Schema of the features column. + * If a feature does not have metadata, it is assumed to be continuous. + * If a feature is Nominal, then it must have the number of values + * specified. + * @return Map: feature index --> number of categories. + * The map's set of keys will be the set of categorical feature indices. + */ + def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = { + val metadata = AttributeGroup.fromStructField(featuresSchema) + if (metadata.attributes.isEmpty) { + HashMap.empty[Int, Int] + } else { + metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) => + if (attr == null) { + Iterator() + } else { + attr match { + case _: NumericAttribute | UnresolvedAttribute => Iterator() + case binAttr: BinaryAttribute => Iterator(idx -> 2) + case nomAttr: NominalAttribute => + nomAttr.getNumValues match { + case Some(numValues: Int) => Iterator(idx -> numValues) + case None => throw new IllegalArgumentException(s"Feature $idx is marked as" + + " Nominal (categorical), but it does not have the number of values specified.") + } + } + } + }.toMap + } + } + + /** + * Takes a Vector column and a list of feature names, and returns the corresponding list of + * feature indices in the column, in order. + * @param col Vector column which must have feature names specified via attributes + * @param names List of feature names + */ + def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = { + require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col" + + s" to be Vector type, but it was type ${col.dataType} instead.") + val inputAttr = AttributeGroup.fromStructField(col) + names.map { name => + require(inputAttr.hasAttr(name), + s"getFeatureIndicesFromNames found no feature with name $name in column $col.") + inputAttr.getAttr(name).index.get + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 000000000000..76f651488aef --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.sql.types.{DataType, StructField, StructType} + + +/** + * Utils for handling schemas. + */ +private[spark] object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + def checkColumnType( + schema: StructType, + colName: String, + dataType: DataType, + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.$message") + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param colName new column name. If this column name is an empty string "", this method returns + * the input schema unchanged. This allows users to disable output columns. + * @param dataType new column data type + * @return new schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.isEmpty) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param col New column schema + * @return new schema with the input column appended + */ + def appendColumn(schema: StructType, col: StructField): StructType = { + require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") + StructType(schema.fields :+ col) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala new file mode 100644 index 000000000000..8d4174124b5c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.{Accumulator, SparkContext} + +/** + * Abstract class for stopwatches. + */ +private[spark] abstract class Stopwatch extends Serializable { + + @transient private var running: Boolean = false + private var startTime: Long = _ + + /** + * Name of the stopwatch. + */ + val name: String + + /** + * Starts the stopwatch. + * Throws an exception if the stopwatch is already running. + */ + def start(): Unit = { + assume(!running, "start() called but the stopwatch is already running.") + running = true + startTime = now + } + + /** + * Stops the stopwatch and returns the duration of the last session in milliseconds. + * Throws an exception if the stopwatch is not running. + */ + def stop(): Long = { + assume(running, "stop() called but the stopwatch is not running.") + val duration = now - startTime + add(duration) + running = false + duration + } + + /** + * Checks whether the stopwatch is running. + */ + def isRunning: Boolean = running + + /** + * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch + * is running. + */ + def elapsed(): Long + + override def toString: String = s"$name: ${elapsed()}ms" + + /** + * Gets the current time in milliseconds. + */ + protected def now: Long = System.currentTimeMillis() + + /** + * Adds input duration to total elapsed time. + */ + protected def add(duration: Long): Unit +} + +/** + * A local [[Stopwatch]]. + */ +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch { + + private var elapsedTime: Long = 0L + + override def elapsed(): Long = elapsedTime + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A distributed [[Stopwatch]] using Spark accumulator. + * @param sc SparkContext + */ +private[spark] class DistributedStopwatch( + sc: SparkContext, + override val name: String) extends Stopwatch { + + private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + + override def elapsed(): Long = elapsedTime.value + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A multiple stopwatch that contains local and distributed stopwatches. + * @param sc SparkContext + */ +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable { + + private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty + + /** + * Adds a local stopwatch. + * @param name stopwatch name + */ + def addLocal(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new LocalStopwatch(name) + this + } + + /** + * Adds a distributed stopwatch. + * @param name stopwatch name + */ + def addDistributed(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new DistributedStopwatch(sc, name) + this + } + + /** + * Gets a stopwatch. + * @param name stopwatch name + */ + def apply(name: String): Stopwatch = stopwatches(name) + + override def toString: String = { + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" $c") + .mkString("{\n", ",\n", "\n}") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala new file mode 100644 index 000000000000..ee933f4cfcaf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of FPGrowthModel to provide helper method for Python + */ +private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { + + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala new file mode 100644 index 000000000000..0ec88ef77d69 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} +import org.apache.spark.mllib.clustering.GaussianMixtureModel + +/** + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { + val weights: Vector = Vectors.dense(model.weights) + val k: Int = weights.size + + /** + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ + val gaussians: JList[Object] = { + val modelGaussians = model.gaussians + var i = 0 + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + while (i < k) { + mu += modelGaussians(i).mu + sigma += modelGaussians(i).sigma + i += 1 + } + List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala new file mode 100644 index 000000000000..534edac56bc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of MatrixFactorizationModel to provide helper method for Python. + */ +private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.map { + case (user, feature) => (user, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) + } + + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.map { + case (product, feature) => (product, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala new file mode 100644 index 000000000000..bc6041b22173 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel + +/** + * A Wrapper of PowerIterationClusteringModel to provide helper method for Python + */ +private[python] class PowerIterationClusteringModelWrapper(model: PowerIterationClusteringModel) + extends PowerIterationClusteringModel(model.k, model.assignments) { + + def getAssignments: RDD[Array[Any]] = { + model.assignments.map(x => Array(x.id, x.cluster)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 980980593d19..f585aacd452e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,38 +28,41 @@ import scala.reflect.ClassTag import net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.evaluation.RankingMetrics import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} +import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult} +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** - * :: DeveloperApi :: - * The Java stubs necessary for the Python mllib bindings. + * The Java stubs necessary for the Python mllib bindings. It is called by Py4J on the Python side. */ -@DeveloperApi -class PythonMLLibAPI extends Serializable { - +private[python] class PythonMLLibAPI extends Serializable { /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. @@ -74,13 +77,28 @@ class PythonMLLibAPI extends Serializable { minPartitions: Int): JavaRDD[LabeledPoint] = MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) + /** + * Loads and serializes vectors saved with `RDD#saveAsTextFile`. + * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @return serialized vectors in a RDD + */ + def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] = + MLUtils.loadVectors(jsc.sc, path) + private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeights: Vector): JList[Object] = { try { val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + if (model.isInstanceOf[LogisticRegressionModel]) { + val lrModel = model.asInstanceOf[LogisticRegressionModel] + List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses) + .map(_.asInstanceOf[Object]).asJava + } else { + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } } finally { data.rdd.unpersist(blocking = false) } @@ -113,9 +131,11 @@ class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) + .setValidateData(validateData) lrAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -137,8 +157,12 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lassoAlg = new LassoWithSGD() + lassoAlg.setIntercept(intercept) + .setValidateData(validateData) lassoAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -159,8 +183,12 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.setIntercept(intercept) + .setValidateData(validateData) ridgeAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -183,9 +211,11 @@ class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) + .setValidateData(validateData) SVMAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -209,9 +239,11 @@ class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -235,9 +267,13 @@ class PythonMLLibAPI extends Serializable { regType: String, intercept: Boolean, corrections: Int, - tolerance: Double): JList[Object] = { + tolerance: Double, + validateData: Boolean, + numClasses: Int): JList[Object] = { val LogRegAlg = new LogisticRegressionWithLBFGS() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) + .setNumClasses(numClasses) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -253,14 +289,32 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes( + def trainNaiveBayesModel( data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)). map(_.asInstanceOf[Object]).asJava } + /** + * Java stub for Python mllib IsotonicRegression.run() + */ + def trainIsotonicRegressionModel( + data: JavaRDD[Vector], + isotonic: Boolean): JList[Object] = { + val isotonicRegressionAlg = new IsotonicRegression().setIsotonic(isotonic) + val input = data.rdd.map { x => + (x(0), x(1), x(2)) + }.persist(StorageLevel.MEMORY_AND_DISK) + try { + val model = isotonicRegressionAlg.run(input) + List[AnyRef](model.boundaryVector, model.predictionVector).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + /** * Java stub for Python mllib KMeans.run() */ @@ -270,12 +324,16 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String, - seed: java.lang.Long): KMeansModel = { + seed: java.lang.Long, + initializationSteps: Int, + epsilon: Double): KMeansModel = { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) + .setInitializationSteps(initializationSteps) + .setEpsilon(epsilon) if (seed != null) kMeansAlg.setSeed(seed) @@ -286,34 +344,46 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib KMeansModel.computeCost() + */ + def computeCostKmeansModel( + data: JavaRDD[Vector], + centers: java.util.ArrayList[Vector]): Double = { + new KMeansModel(centers).computeCost(data) + } + /** * Java stub for Python mllib GaussianMixture.run() * Returns a list containing weights, mean and covariance of each mixture component. */ - def trainGaussianMixture( - data: JavaRDD[Vector], - k: Int, - convergenceTol: Double, + def trainGaussianMixtureModel( + data: JavaRDD[Vector], + k: Int, + convergenceTol: Double, maxIterations: Int, - seed: Long): JList[Object] = { + seed: java.lang.Long, + initialModelWeights: java.util.ArrayList[Double], + initialModelMu: java.util.ArrayList[Vector], + initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) + if (initialModelWeights != null && initialModelMu != null && initialModelSigma != null) { + val gaussians = initialModelMu.asScala.toSeq.zip(initialModelSigma.asScala.toSeq).map { + case (x, y) => new MultivariateGaussian(x.asInstanceOf[Vector], y.asInstanceOf[Matrix]) + } + val initialModel = new GaussianMixtureModel( + initialModelWeights.asScala.toArray, gaussians.toArray) + gmmAlg.setInitialModel(initialModel) + } + if (seed != null) gmmAlg.setSeed(seed) try { - val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) - var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - for (i <- 0 until model.k) { - wt += model.weights(i) - mu += model.gaussians(i).mu - sigma += model.gaussians(i).sigma - } - List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))) } finally { data.rdd.unpersist(blocking = false) } @@ -324,33 +394,45 @@ class PythonMLLibAPI extends Serializable { */ def predictSoftGMM( data: JavaRDD[Vector], - wt: Object, + wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Array[Double]] = { + si: Array[Object]): RDD[Vector] = { - val weight = wt.asInstanceOf[Array[Double]] + val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) - } + } val model = new GaussianMixtureModel(weight, gaussians) - model.predictSoft(data) + model.predictSoft(data).map(Vectors.dense) } /** - * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + * Java stub for Python mllib PowerIterationClustering.run(). This stub returns a + * handle to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see the + * Py4J documentation. + * @param data an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix. + * @param k number of clusters. + * @param maxIterations maximum number of iterations of the power iteration loop. + * @param initMode the initialization mode. This can be either "random" to use + * a random vector as vertex properties, or "degree" to use + * normalized sum similarities. Default: random. */ - private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) - extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { - - def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = - predict(SerDe.asTupleRDD(userAndProducts.rdd)) - - def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + def trainPowerIterationClusteringModel( + data: JavaRDD[Vector], + k: Int, + maxIterations: Int, + initMode: String): PowerIterationClusteringModel = { - def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + val pic = new PowerIterationClustering() + .setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initMode) + val model = pic.run(data.rdd.map(v => (v(0).toLong, v(1).toLong, v(2)))) + new PowerIterationClusteringModelWrapper(model) } /** @@ -377,7 +459,7 @@ class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -408,10 +490,61 @@ class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib LDA.run() + */ + def trainLDAModel( + data: JavaRDD[java.util.List[Any]], + k: Int, + maxIterations: Int, + docConcentration: Double, + topicConcentration: Double, + seed: java.lang.Long, + checkpointInterval: Int, + optimizer: String): LDAModel = { + val algo = new LDA() + .setK(k) + .setMaxIterations(maxIterations) + .setDocConcentration(docConcentration) + .setTopicConcentration(topicConcentration) + .setCheckpointInterval(checkpointInterval) + .setOptimizer(optimizer) + + if (seed != null) algo.setSeed(seed) + + val documents = data.rdd.map(_.asScala.toArray).map { r => + r(0) match { + case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector]) + case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector]) + case _ => throw new IllegalArgumentException("input values contains invalid type value.") + } + } + algo.run(documents) + } + + + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainFPGrowthModel( + data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartitions: Int): FPGrowthModel[Any] = { + val fpg = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartitions) + + val model = fpg.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ @@ -427,7 +560,7 @@ class PythonMLLibAPI extends Serializable { } /** - * Java stub for IDF.fit(). This stub returns a + * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. @@ -439,6 +572,26 @@ class PythonMLLibAPI extends Serializable { new StandardScaler(withMean, withStd).fit(data.rdd) } + /** + * Java stub for ChiSqSelector.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { + new ChiSqSelector(numTopFeatures).fit(data.rdd) + } + + /** + * Java stub for PCA.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitPCA(k: Int, data: JavaRDD[Vector]): PCAModel = { + new PCA(k).fit(data.rdd) + } + /** * Java stub for IDF.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. @@ -462,19 +615,21 @@ class PythonMLLibAPI extends Serializable { * @param seed initial seed for random generator * @return A handle to java Word2VecModelWrapper instance at python side */ - def trainWord2Vec( + def trainWord2VecModel( dataJRDD: JavaRDD[java.util.ArrayList[String]], vectorSize: Int, learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long): Word2VecModelWrapper = { + seed: Long, + minCount: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) + .setMinCount(minCount) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -508,6 +663,12 @@ class PythonMLLibAPI extends Serializable { val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) } /** @@ -600,12 +761,14 @@ class PythonMLLibAPI extends Serializable { lossStr: String, numIterations: Int, learningRate: Double, - maxDepth: Int): GradientBoostedTreesModel = { + maxDepth: Int, + maxBins: Int): GradientBoostedTreesModel = { val boostingStrategy = BoostingStrategy.defaultParams(algoStr) boostingStrategy.setLoss(Losses.fromString(lossStr)) boostingStrategy.setNumIterations(numIterations) boostingStrategy.setLearningRate(learningRate) boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.setMaxBins(maxBins) boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) @@ -616,6 +779,14 @@ class PythonMLLibAPI extends Serializable { } } + def elementwiseProductVector(scalingVector: Vector, vector: Vector): Vector = { + new ElementwiseProduct(scalingVector).transform(vector) + } + + def elementwiseProductVector(scalingVector: Vector, vector: JavaRDD[Vector]): JavaRDD[Vector] = { + new ElementwiseProduct(scalingVector).transform(vector) + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. @@ -852,7 +1023,155 @@ class PythonMLLibAPI extends Serializable { RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s) } + /** + * Java stub for the constructor of Python mllib RankingMetrics + */ + def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = { + new RankingMetrics(predictionAndLabels.map( + r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) + } + + /** + * Java stub for the estimate method of KernelDensity + */ + def estimateKernelDensity( + sample: JavaRDD[Double], + bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = { + new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( + points.asScala.toArray) + } + + /** + * Java stub for the update method of StreamingKMeansModel. + */ + def updateStreamingKMeansModel( + clusterCenters: JList[Vector], + clusterWeights: JList[Double], + data: JavaRDD[Vector], + decayFactor: Double, + timeUnit: String): JList[Object] = { + val model = new StreamingKMeansModel( + clusterCenters.asScala.toArray, clusterWeights.asScala.toArray) + .update(data, decayFactor, timeUnit) + List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava + } + + /** + * Wrapper around the generateLinearInput method of LinearDataGenerator. + */ + def generateLinearInputWrapper( + intercept: Double, + weights: JList[Double], + xMean: JList[Double], + xVariance: JList[Double], + nPoints: Int, + seed: Int, + eps: Double): Array[LabeledPoint] = { + LinearDataGenerator.generateLinearInput( + intercept, weights.asScala.toArray, xMean.asScala.toArray, + xVariance.asScala.toArray, nPoints, seed, eps).toArray + } + + /** + * Wrapper around the generateLinearRDD method of LinearDataGenerator. + */ + def generateLinearRDDWrapper( + sc: JavaSparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int, + intercept: Double): JavaRDD[LabeledPoint] = { + LinearDataGenerator.generateLinearRDD( + sc, nexamples, nfeatures, eps, nparts, intercept) + } + + /** + * Java stub for Statistics.kolmogorovSmirnovTest() + */ + def kolmogorovSmirnovTest( + data: JavaRDD[Double], + distName: String, + params: JList[Double]): KolmogorovSmirnovTestResult = { + val paramsSeq = params.asScala.toSeq + Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*) + } + /** + * Wrapper around RowMatrix constructor. + */ + def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { + new RowMatrix(rows.rdd, numRows, numCols) + } + + /** + * Wrapper around IndexedRowMatrix constructor. + */ + def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = { + // We use DataFrames for serialization of IndexedRows from Python, + // so map each Row in the DataFrame back to an IndexedRow. + val indexedRows = rows.map { + case Row(index: Long, vector: Vector) => IndexedRow(index, vector) + } + new IndexedRowMatrix(indexedRows, numRows, numCols) + } + + /** + * Wrapper around CoordinateMatrix constructor. + */ + def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = { + // We use DataFrames for serialization of MatrixEntry entries from + // Python, so map each Row in the DataFrame back to a MatrixEntry. + val entries = rows.map { + case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value) + } + new CoordinateMatrix(entries, numRows, numCols) + } + + /** + * Wrapper around BlockMatrix constructor. + */ + def createBlockMatrix(blocks: DataFrame, rowsPerBlock: Int, colsPerBlock: Int, + numRows: Long, numCols: Long): BlockMatrix = { + // We use DataFrames for serialization of sub-matrix blocks from + // Python, so map each Row in the DataFrame back to a + // ((blockRowIndex, blockColIndex), sub-matrix) tuple. + val blockTuples = blocks.map { + case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => + ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) + } + new BlockMatrix(blockTuples, rowsPerBlock, colsPerBlock, numRows, numCols) + } + + /** + * Return the rows of an IndexedRowMatrix. + */ + def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { + // We use DataFrames for serialization of IndexedRows to Python, + // so return a DataFrame. + val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext) + sqlContext.createDataFrame(indexedRowMatrix.rows) + } + + /** + * Return the entries of a CoordinateMatrix. + */ + def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { + // We use DataFrames for serialization of MatrixEntry entries to + // Python, so return a DataFrame. + val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) + sqlContext.createDataFrame(coordinateMatrix.entries) + } + + /** + * Return the sub-matrix blocks of a BlockMatrix. + */ + def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { + // We use DataFrames for serialization of sub-matrix blocks to + // Python, so return a DataFrame. + val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + sqlContext.createDataFrame(blockMatrix.blocks) + } } /** @@ -905,13 +1224,21 @@ private[spark] object SerDe extends Serializable { out.write(code) } + protected def getBytes(obj: Object): Array[Byte] = { + if (obj.getClass.isArray) { + obj.asInstanceOf[Array[Byte]] + } else { + obj.asInstanceOf[String].getBytes(LATIN1) + } + } + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) } // Pickler for DenseVector private[python] class DenseVectorPickler extends BasePickler[DenseVector] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val vector: DenseVector = obj.asInstanceOf[DenseVector] val bytes = new Array[Byte](8 * vector.size) val bb = ByteBuffer.wrap(bytes) @@ -930,7 +1257,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(0)) val bb = ByteBuffer.wrap(bytes, 0, bytes.length) bb.order(ByteOrder.nativeOrder()) val db = bb.asDoubleBuffer() @@ -943,12 +1270,14 @@ private[spark] object SerDe extends Serializable { // Pickler for DenseMatrix private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] val bytes = new Array[Byte](8 * m.values.size) val order = ByteOrder.nativeOrder() + val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + out.write(Opcodes.MARK) out.write(Opcodes.BININT) out.write(PickleUtils.integer_to_bytes(m.numRows)) out.write(Opcodes.BININT) @@ -956,26 +1285,84 @@ private[spark] object SerDe extends Serializable { out.write(Opcodes.BINSTRING) out.write(PickleUtils.integer_to_bytes(bytes.length)) out.write(bytes) - out.write(Opcodes.TUPLE3) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) } def construct(args: Array[Object]): Object = { - if (args.length != 3) { - throw new PickleException("should be 3") + if (args.length != 4) { + throw new PickleException("should be 4") } - val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) + val isTransposed = args(3).asInstanceOf[Int] == 1 + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) + } + } + + // Pickler for SparseMatrix + private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + val s = obj.asInstanceOf[SparseMatrix] + val order = ByteOrder.nativeOrder() + + val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length) + val indicesBytes = new Array[Byte](4 * s.rowIndices.length) + val valuesBytes = new Array[Byte](8 * s.values.length) + val isTransposed = if (s.isTransposed) 1 else 0 + ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs) + ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices) + ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values) + + out.write(Opcodes.MARK) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(s.numRows)) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(s.numCols)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length)) + out.write(colPtrsBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(indicesBytes.length)) + out.write(indicesBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(valuesBytes.length)) + out.write(valuesBytes) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) + } + + def construct(args: Array[Object]): Object = { + if (args.length != 6) { + throw new PickleException("should be 6") + } + val order = ByteOrder.nativeOrder() + val colPtrsBytes = getBytes(args(2)) + val indicesBytes = getBytes(args(3)) + val valuesBytes = getBytes(args(4)) + val colPtrs = new Array[Int](colPtrsBytes.length / 4) + val rowIndices = new Array[Int](indicesBytes.length / 4) + val values = new Array[Double](valuesBytes.length / 8) + ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs) + ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices) + ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values) + val isTransposed = args(5).asInstanceOf[Int] == 1 + new SparseMatrix( + args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values, + isTransposed) } } // Pickler for SparseVector private[python] class SparseVectorPickler extends BasePickler[SparseVector] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val v: SparseVector = obj.asInstanceOf[SparseVector] val n = v.indices.size val indiceBytes = new Array[Byte](4 * n) @@ -1000,8 +1387,8 @@ private[spark] object SerDe extends Serializable { throw new PickleException("should be 3") } val size = args(0).asInstanceOf[Int] - val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) - val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val indiceBytes = getBytes(args(1)) + val valueBytes = getBytes(args(2)) val n = indiceBytes.length / 4 val indices = new Array[Int](n) val values = new Array[Double](n) @@ -1017,7 +1404,7 @@ private[spark] object SerDe extends Serializable { // Pickler for LabeledPoint private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val point: LabeledPoint = obj.asInstanceOf[LabeledPoint] saveObjects(out, pickler, point.label, point.features) } @@ -1033,7 +1420,7 @@ private[spark] object SerDe extends Serializable { // Pickler for Rating private[python] class RatingPickler extends BasePickler[Rating] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val rating: Rating = obj.asInstanceOf[Rating] saveObjects(out, pickler, rating.user, rating.product, rating.rating) } @@ -1056,6 +1443,7 @@ private[spark] object SerDe extends Serializable { if (!initialized) { new DenseVectorPickler().register() new DenseMatrixPickler().register() + new SparseMatrixPickler().register() new SparseVectorPickler().register() new LabeledPointPickler().register() new RatingPickler().register() @@ -1080,7 +1468,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } @@ -1105,7 +1493,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24d7..85a413243b04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,7 +17,9 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.json4s.{DefaultFormats, JValue} + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -28,6 +30,7 @@ import org.apache.spark.rdd.RDD * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. */ @Experimental +@Since("0.8.0") trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. @@ -35,6 +38,7 @@ trait ClassificationModel extends Serializable { * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -43,6 +47,7 @@ trait ClassificationModel extends Serializable { * @param testData array representing a single data point * @return predicted category from the trained model */ + @Since("1.0.0") def predict(testData: Vector): Double /** @@ -50,6 +55,19 @@ trait ClassificationModel extends Serializable { * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction */ + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object ClassificationModel { + + /** + * Helper method for loading GLM classification model metadata. + * @return (numFeatures, numClasses) + */ + def getNumFeaturesClasses(metadata: JValue): (Int, Int) = { + implicit val formats = DefaultFormats + ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a469315a1b5c..5ceff5b2259e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,33 +17,66 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD + /** * Classification model trained using Multinomial/Binary Logistic Regression. * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression. - * In Multinomial Logistic Regression, the intercepts will not be a single values, + * In Multinomial Logistic Regression, the intercepts will not be a single value, * so the intercepts will be part of the weights.) * @param numFeatures the dimension of the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. By default, it is binary logistic regression * so numClasses will be set to 2. */ -class LogisticRegressionModel ( - override val weights: Vector, - override val intercept: Double, - val numFeatures: Int, - val numClasses: Int) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { +@Since("0.8.0") +class LogisticRegressionModel @Since("1.3.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("1.0.0") override val intercept: Double, + @Since("1.3.0") val numFeatures: Int, + @Since("1.3.0") val numClasses: Int) + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable with PMMLExportable { + + if (numClasses == 2) { + require(weights.size == numFeatures, + s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" + + s" numFeatures = $numFeatures, but weights.size = ${weights.size}") + } else { + val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures + val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1) + require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept, + s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" + + s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" + + s" or $weightsSizeWithIntercept (with intercept)," + + s" but was given weights of length ${weights.size}") + } + + private val dataWithBiasSize: Int = weights.size / (numClasses - 1) + + private val weightsArray: Array[Double] = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + /** + * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. + */ + @Since("1.0.0") def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) @@ -53,61 +86,54 @@ class LogisticRegressionModel ( * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * It is only used for binary classification. */ + @Since("1.0.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + * It is only used for binary classification. + */ + @Since("1.3.0") + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * It is only used for binary classification. */ + @Since("1.0.0") @Experimental def clearThreshold(): this.type = { threshold = None this } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { require(dataMatrix.size == numFeatures) // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { - require(numFeatures == weightMatrix.size) - val margin = dot(weights, dataMatrix) + intercept + val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score > t) 1.0 else 0.0 case None => score } } else { - val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - - val weightsArray = weights match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") - } - - val margins = (0 until numClasses - 1).map { i => - var margin = 0.0 - dataMatrix.foreachActive { (index, value) => - if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) - } - // Intercept is required to be added into margin. - if (dataMatrix.size + 1 == dataWithBiasSize) { - margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) - } - margin - } - /** - * Find the one with maximum margins. If the maxMargin is negative, then the prediction - * result will be the first class. + * Compute and find the one with maximum margins. If the maxMargin is negative, then the + * prediction result will be the first class. * * PS, if you want to compute the probabilities for each outcome instead of the outcome * with maximum probability, remember to subtract the maxMargin from margins if maxMargin @@ -115,17 +141,64 @@ class LogisticRegressionModel ( */ var bestClass = 0 var maxMargin = 0.0 - var i = 0 - while(i < margins.size) { - if (margins(i) > maxMargin) { - maxMargin = margins(i) + val withBias = dataMatrix.size + 1 == dataWithBiasSize + (0 until numClasses - 1).foreach { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin. + if (withBias) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + if (margin > maxMargin) { + maxMargin = margin bestClass = i + 1 } - i += 1 } bestClass.toDouble } } + + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures, numClasses, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" + + override def toString: String = { + s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" + } +} + +@Since("1.3.0") +object LogisticRegressionModel extends Loader[LogisticRegressionModel] { + + @Since("1.3.0") + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + // numFeatures, numClasses, weights are checked in model initialization + val model = + new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"LogisticRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** @@ -136,6 +209,7 @@ class LogisticRegressionModel ( * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ +@Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -145,6 +219,7 @@ class LogisticRegressionWithSGD private[mllib] ( private val gradient = new LogisticGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -156,6 +231,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -167,6 +243,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ +@Since("0.8.0") object LogisticRegressionWithSGD { // NOTE(shivaram): We use multiple train methods instead of default arguments to support // Java programs. @@ -185,6 +262,7 @@ object LogisticRegressionWithSGD { * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -207,6 +285,7 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -228,6 +307,7 @@ object LogisticRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -245,6 +325,7 @@ object LogisticRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int): LogisticRegressionModel = { @@ -258,11 +339,13 @@ object LogisticRegressionWithSGD { * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. */ +@Since("1.1.0") class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { this.setFeatureScaling(true) + @Since("1.1.0") override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(multiLabelValidator) @@ -281,6 +364,7 @@ class LogisticRegressionWithLBFGS * Multinomial Logistic Regression. * By default, it is binary logistic regression so k will be set to 2. */ + @Since("1.3.0") @Experimental def setNumClasses(numClasses: Int): this.type = { require(numClasses > 1) @@ -292,6 +376,10 @@ class LogisticRegressionWithLBFGS } override protected def createModel(weights: Vector, intercept: Double) = { - new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + if (numOfLinearPredictor == 1) { + new LogisticRegressionModel(weights, intercept) + } else { + new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a967df857bed..a956084ae06e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,13 +17,20 @@ package org.apache.spark.mllib.classification -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import java.lang.{Iterable => JIterable} -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} +import scala.collection.JavaConverters._ + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} /** * Model for Naive Bayes Classifiers. @@ -32,28 +39,52 @@ import org.apache.spark.rdd.RDD * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features + * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ -class NaiveBayesModel private[mllib] ( - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { +@Since("0.9.0") +class NaiveBayesModel private[spark] ( + @Since("1.0.0") val labels: Array[Double], + @Since("0.9.0") val pi: Array[Double], + @Since("0.9.0") val theta: Array[Array[Double]], + @Since("1.4.0") val modelType: String) + extends ClassificationModel with Serializable with Saveable { - private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM[Double](theta.length, theta(0).length) + import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} - { - // Need to put an extra pair of braces to prevent Scala treating `i` as a member. - var i = 0 - while (i < theta.length) { - var j = 0 - while (j < theta(i).length) { - brzTheta(i, j) = theta(i)(j) - j += 1 + private val piVector = new DenseVector(pi) + private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true) + + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = + this(labels, pi, theta, NaiveBayes.Multinomial) + + /** A Java-friendly constructor that takes three Iterable parameters. */ + private[mllib] def this( + labels: JIterable[Double], + pi: JIterable[Double], + theta: JIterable[JIterable[Double]]) = + this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) + + require(supportedModelTypes.contains(modelType), + s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.") + + // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // application of this condition (in predict function). + private val (thetaMinusNegTheta, negThetaSum) = modelType match { + case Multinomial => (None, None) + case Bernoulli => + val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) + val thetaMinusNegTheta = thetaMatrix.map { value => + value - math.log(1.0 - math.exp(value)) } - i += 1 - } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: $modelType.") } + @Since("1.0.0") override def predict(testData: RDD[Vector]): RDD[Double] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -62,8 +93,211 @@ class NaiveBayesModel private[mllib] ( } } + @Since("1.0.0") override def predict(testData: Vector): Double = { - labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) + modelType match { + case Multinomial => + labels(multinomialCalculation(testData).argmax) + case Bernoulli => + labels(bernoulliCalculation(testData).argmax) + } + } + + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities, + * in the same order as class labels + */ + @Since("1.5.0") + def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = { + val bcModel = testData.context.broadcast(this) + testData.mapPartitions { iter => + val model = bcModel.value + iter.map(model.predictProbabilities) + } + } + + /** + * Predict posterior class probabilities for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return predicted posterior class probabilities from the trained model, + * in the same order as class labels + */ + @Since("1.5.0") + def predictProbabilities(testData: Vector): Vector = { + modelType match { + case Multinomial => + posteriorProbabilities(multinomialCalculation(testData)) + case Bernoulli => + posteriorProbabilities(bernoulliCalculation(testData)) + } + } + + private def multinomialCalculation(testData: Vector) = { + val prob = thetaMatrix.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + prob + } + + private def bernoulliCalculation(testData: Vector) = { + testData.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + private def posteriorProbabilities(logProb: DenseVector) = { + val logProbArray = logProb.toArray + val maxLog = logProbArray.max + val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog)) + val probSum = scaledProbs.sum + new DenseVector(scaledProbs.map(_ / probSum)) + } + + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) + NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) + } + + override protected def formatVersion: String = "2.0" +} + +@Since("1.3.0") +object NaiveBayesModel extends Loader[NaiveBayesModel] { + + import org.apache.spark.mllib.util.Loader._ + + private[mllib] object SaveLoadV2_0 { + + def thisFormatVersion: String = "2.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.write.parquet(dataPath(path)) + } + + @Since("1.3.0") + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.read.parquet(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + val modelType = data.getString(3) + new NaiveBayesModel(labels, pi, theta, modelType) + } + + } + + private[mllib] object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.write.parquet(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.read.parquet(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) + } + } + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + val classNameV2_0 = SaveLoadV2_0.thisClassName + val (model, numFeatures, numClasses) = (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV1_0.load(sc, path) + (model, numFeatures, numClasses) + case (className, "2.0") if className == classNameV2_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV2_0.load(sc, path) + (model, numFeatures, numClasses) + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + assert(model.pi.length == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.length} elements") + assert(model.theta.length == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.length} elements") + assert(model.theta.forall(_.length == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.length).mkString(",")}") + model } } @@ -75,64 +309,120 @@ class NaiveBayesModel private[mllib] ( * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ -class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { +@Since("0.9.0") +class NaiveBayes private ( + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { + + import NaiveBayes.{Bernoulli, Multinomial} + + @Since("1.4.0") + def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) - def this() = this(1.0) + @Since("0.9.0") + def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ + @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { this.lambda = lambda this } + /** Get the smoothing parameter. */ + @Since("1.4.0") + def getLambda: Double = lambda + + /** + * Set the model type using a string (case-sensitive). + * Supported options: "multinomial" (default) and "bernoulli". + */ + @Since("1.4.0") + def setModelType(modelType: String): NaiveBayes = { + require(NaiveBayes.supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown modelType: $modelType.") + this.modelType = modelType + this + } + + /** Get the model type. */ + @Since("1.4.0") + def getModelType: String = this.modelType + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ - def run(data: RDD[LabeledPoint]) = { + @Since("0.9.0") + def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values } if (!values.forall(_ >= 0.0)) { throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") } } + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( + val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( createCombiner = (v: Vector) => { - requireNonnegativeValues(v) - (1L, v.toBreeze.toDenseVector) + if (modelType == Bernoulli) { + requireZeroOneBernoulliValues(v) + } else { + requireNonnegativeValues(v) + } + (1L, v.copy.toDense) }, - mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + mergeValue = (c: (Long, DenseVector), v: Vector) => { requireNonnegativeValues(v) - (c._1 + 1L, c._2 += v.toBreeze) + BLAS.axpy(1.0, v, c._2) + (c._1 + 1L, c._2) }, - mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => - (c1._1 + c2._1, c1._2 += c2._2) - ).collect() + mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { + BLAS.axpy(1.0, c2._2, c1._2) + (c1._1 + c2._1, c1._2) + } + ).collect().sortBy(_._1) + val numLabels = aggregated.length var numDocuments = 0L aggregated.foreach { case (_, (n, _)) => numDocuments += n } val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) val pi = new Array[Double](numLabels) val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label - val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) pi(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = modelType match { + case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: $modelType.") + } var j = 0 while (j < numFeatures) { theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom @@ -141,27 +431,38 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with i += 1 } - new NaiveBayesModel(labels, pi, theta) + new NaiveBayesModel(labels, pi, theta, modelType) } } /** * Top-level methods for calling naive Bayes. */ +@Since("0.9.0") object NaiveBayes { + + /** String name for multinomial model type. */ + private[spark] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[spark] val Bernoulli: String = "bernoulli" + + /* Set of modelTypes that NaiveBayes supports */ + private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) + /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * This version of the method uses a default smoothing parameter of 1.0. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. */ + @Since("0.9.0") def train(input: RDD[LabeledPoint]): NaiveBayesModel = { new NaiveBayes().run(input) } @@ -169,16 +470,42 @@ object NaiveBayes { /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter */ + @Since("0.9.0") def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda).run(input) + new NaiveBayes(lambda, Multinomial).run(input) } + + /** + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]]) + * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle + * discrete count data and can be called by setting the model type to "multinomial". + * For example, it can be used with word counts or TF_IDF vectors of documents. + * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a + * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as + * Bernoulli NB. + * + * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency + * vector or a count vector. + * @param lambda The smoothing parameter + * + * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be + * multinomial or bernoulli + */ + @Since("1.4.0") + def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { + require(supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown modelType: $modelType.") + new NaiveBayes(lambda, modelType).run(input) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index dd514ff8a37f..896565cd90e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,11 +17,14 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD /** @@ -30,10 +33,12 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class SVMModel ( - override val weights: Vector, - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { +@Since("0.8.0") +class SVMModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable with PMMLExportable { private var threshold: Option[Double] = Some(0.0) @@ -43,16 +48,26 @@ class SVMModel ( * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. */ + @Since("1.0.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Since("1.3.0") + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. */ + @Since("1.0.0") @Experimental def clearThreshold(): this.type = { threshold = None @@ -69,6 +84,48 @@ class SVMModel ( case None => margin } } + + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" + + override def toString: String = { + s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" + } +} + +@Since("1.3.0") +object SVMModel extends Loader[SVMModel] { + + @Since("1.3.0") + override def load(sc: SparkContext, path: String): SVMModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new SVMModel(data.weights, data.intercept) + assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + + s" was given non-matching weights vector of size ${model.weights.size}") + assert(numClasses == 2, + s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes") + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"SVMModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** @@ -76,6 +133,7 @@ class SVMModel ( * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ +@Since("0.8.0") class SVMWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -85,6 +143,7 @@ class SVMWithSGD private ( private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -96,6 +155,7 @@ class SVMWithSGD private ( * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, * regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -106,6 +166,7 @@ class SVMWithSGD private ( /** * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. */ +@Since("0.8.0") object SVMWithSGD { /** @@ -124,6 +185,7 @@ object SVMWithSGD { * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -147,6 +209,7 @@ object SVMWithSGD { * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -168,6 +231,7 @@ object SVMWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -186,6 +250,7 @@ object SVMWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. */ + @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 0.01, 1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index 6a3893d0e41d..75630054d136 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.StreamingLinearAlgorithm @@ -35,14 +35,16 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm * Use a builder pattern to construct a streaming logistic regression * analysis in an application, like: * + * {{{ * val model = new StreamingLogisticRegressionWithSGD() * .setStepSize(0.5) * .setNumIterations(10) * .setInitialWeights(Vectors.dense(...)) * .trainOn(DStream) - * + * }}} */ @Experimental +@Since("1.3.0") class StreamingLogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -57,36 +59,44 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.3.0") def this() = this(0.1, 50, 1.0, 0.0) - val algorithm = new LogisticRegressionWithSGD( + protected val algorithm = new LogisticRegressionWithSGD( stepSize, numIterations, regParam, miniBatchFraction) + protected var model: Option[LogisticRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ + @Since("1.3.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + @Since("1.3.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } /** Set the fraction of each batch to use for updates. Default: 1.0. */ + @Since("1.3.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } /** Set the regularization parameter. Default: 0.0. */ + @Since("1.3.0") def setRegParam(regParam: Double): this.type = { this.algorithm.optimizer.setRegParam(regParam) this } /** Set the initial weights. Default: [0.0, 0.0]. */ + @Since("1.3.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala new file mode 100644 index 000000000000..fe09f6b75d28 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{Row, SQLContext} + +/** + * Helper class for import/export of GLM classification models. + */ +private[classification] object GLMClassificationModel { + + object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Model data for import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + + /** + * Helper method for saving GLM classification model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + * @param numClasses Number of classes label can take, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + numFeatures: Int, + numClasses: Int, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM classification model data. + * + * NOTE: Callers of this method should check numClasses, numFeatures on their own. + * + * @param modelClass String name for model class (used for error messages) + */ + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.read.parquet(datapath) + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") + val (weights, intercept) = data match { + case Row(weights: Vector, intercept: Double, _) => + (weights, intercept) + } + val threshold = if (data.isNullAt(2)) { + None + } else { + Some(data.getDouble(2)) + } + Data(weights, intercept, threshold) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 5c626fde4e65..f82bd82c2037 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -19,51 +19,67 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.IndexedSeq -import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** + * :: Experimental :: + * * This class performs expectation maximization for multivariate Gaussian * Mixture Models (GMMs). A GMM represents a composite distribution of * independent Gaussian distributions with associated "mixing" weights * specifying each's contribution to the composite. * - * Given a set of sample points, this class will maximize the log-likelihood - * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed - * to find a global optimum. - * + * to find a global optimum. + * + * Note: For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * * @param k The number of independent Gaussians in the mixture model * @param convergenceTol The maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations The maximum number of iterations to perform */ +@Experimental +@Since("1.3.0") class GaussianMixture private ( - private var k: Int, - private var convergenceTol: Double, + private var k: Int, + private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - - /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + + /** + * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, + * maxIterations: 100, seed: random}. + */ + @Since("1.3.0") def this() = this(2, 0.01, 100, Utils.random.nextLong()) - + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 - - // an initializing GMM can be provided rather than using the + + // an initializing GMM can be provided rather than using the // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - - /** Set the initial GMM starting point, bypassing the random initialization. - * You must call setK() prior to calling this method, and the condition - * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + + /** + * Set the initial GMM starting point, bypassing the random initialization. + * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException */ + @Since("1.3.0") def setInitialModel(model: GaussianMixtureModel): this.type = { if (model.k == k) { initialModel = Some(model) @@ -72,171 +88,238 @@ class GaussianMixture private ( } this } - - /** Return the user supplied initial GMM, if supplied */ + + /** + * Return the user supplied initial GMM, if supplied + */ + @Since("1.3.0") def getInitialModel: Option[GaussianMixtureModel] = initialModel - - /** Set the number of Gaussians in the mixture model. Default: 2 */ + + /** + * Set the number of Gaussians in the mixture model. Default: 2 + */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this } - - /** Return the number of Gaussians in the mixture model */ + + /** + * Return the number of Gaussians in the mixture model + */ + @Since("1.3.0") def getK: Int = k - - /** Set the maximum number of iterations to run. Default: 100 */ + + /** + * Set the maximum number of iterations to run. Default: 100 + */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - - /** Return the maximum number of iterations to run */ + + /** + * Return the maximum number of iterations to run + */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations - + /** - * Set the largest change in log-likelihood at which convergence is + * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ + @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this } - + /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. */ + @Since("1.3.0") def getConvergenceTol: Double = convergenceTol - /** Set the random seed */ + /** + * Set the random seed + */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this } - /** Return the random seed */ + /** + * Return the random seed + */ + @Since("1.3.0") def getSeed: Long = seed - /** Perform expectation maximization */ + /** + * Perform expectation maximization + */ + @Since("1.3.0") def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext - + // we will operate on the data as breeze data - val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() - + val breezeData = data.map(_.toBreeze).cache() + // Get length of the input vectors val d = breezeData.first().length - + + val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(k, d) + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances - // derived from the samples + // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - + case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) - (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) - }) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + }) } } - - var llh = Double.MinValue // current log-likelihood + + var llh = Double.MinValue // current log-likelihood var llhp = 0.0 // previous log-likelihood - + var iter = 0 - while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) { + while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) { // create and broadcast curried cluster contribution function val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) - + // aggregate the cluster contribution for all sample points val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) - + // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum - var i = 0 - while (i < k) { - val mu = sums.means(i) / sums.weights(i) - BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector], - Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) - weights(i) = sums.weights(i) / sumWeights - gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) - i = i + 1 + + if (shouldDistributeGaussians) { + val numPartitions = math.min(k, 1024) + val tuples = + Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i))) + val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) => + updateWeightsAndGaussians(mean, sigma, weight, sumWeights) + }.collect().unzip + Array.copy(ws.toArray, 0, weights, 0, ws.length) + Array.copy(gs.toArray, 0, gaussians, 0, gs.length) + } else { + var i = 0 + while (i < k) { + val (weight, gaussian) = + updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights) + weights(i) = weight + gaussians(i) = gaussian + i = i + 1 + } } - + llhp = llh // current becomes previous llh = sums.logLikelihood // this is the freshly computed log-likelihood iter += 1 - } - + } + new GaussianMixtureModel(weights, gaussians) } - + + /** + * Java-friendly version of [[run()]] + */ + @Since("1.3.0") + def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) + + private def updateWeightsAndGaussians( + mean: BDV[Double], + sigma: BreezeMatrix[Double], + weight: Double, + sumWeights: Double): (Double, MultivariateGaussian) = { + val mu = (mean /= weight) + BLAS.syr(-weight, Vectors.fromBreeze(mu), + Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix]) + val newWeight = weight / sumWeights + val newGaussian = new MultivariateGaussian(mu, sigma / weight) + (newWeight, newGaussian) + } + /** Average of dense breeze vectors */ - private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { - val v = BreezeVector.zeros[Double](x(0).length) + private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { + val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) - v / x.length.toDouble + v / x.length.toDouble } - + /** * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) */ - private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) - val ss = BreezeVector.zeros[Double](x(0).length) - x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + val ss = BDV.zeros[Double](x(0).length) + x.foreach(xi => ss += (xi - mu) :^ 2.0) diag(ss / x.length.toDouble) } } +private[clustering] object GaussianMixture { + /** + * Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when + * d > 25 except for when k is very small. + * @param k Number of topics + * @param d Number of features + */ + def shouldDistributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25 +} + // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { - new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } - + // compute cluster contributions for each input point // (U, T) => U for aggregation def add( - weights: Array[Double], + weights: Array[Double], dists: Array[MultivariateGaussian]) - (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) } val pSum = p.sum sums.logLikelihood += math.log(pSum) - val xxt = x * new Transpose(x) var i = 0 while (i < sums.k) { p(i) /= pSum sums.weights(i) += p(i) sums.means(i) += x * p(i) - BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector], + BLAS.syr(p(i), Vectors.fromBreeze(x), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) i = i + 1 } sums - } + } } // Aggregation class for partial expectation results private class ExpectationSum( var logLikelihood: Double, val weights: Array[Double], - val means: Array[BreezeVector[Double]], + val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { - + val k = weights.length - + def +=(x: ExpectationSum): ExpectationSum = { var i = 0 while (i < k) { @@ -247,5 +330,5 @@ private class ExpectationSum( } logLikelihood += x.logLikelihood this - } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 1a2178ee7f71..7f6163e04bf1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -19,50 +19,99 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} /** - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are - * the respective mean and covariance for each Gaussian distribution i=1..k. - * - * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is - * the weight for Gaussian i, and weight.sum == 1 - * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i - * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the - * covariance matrix for Gaussian i + * :: Experimental :: + * + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * + * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is + * the weight for Gaussian i, and weights.sum == 1 + * @param gaussians Array of MultivariateGaussian where gaussians(i) represents + * the Multivariate Gaussian (Normal) Distribution for Gaussian i */ -class GaussianMixtureModel( - val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable { - +@Since("1.3.0") +@Experimental +class GaussianMixtureModel @Since("1.3.0") ( + @Since("1.3.0") val weights: Array[Double], + @Since("1.3.0") val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") - - /** Number of gaussians in mixture */ + + override protected def formatVersion = "1.0" + + @Since("1.4.0") + override def save(sc: SparkContext, path: String): Unit = { + GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) + } + + /** + * Number of gaussians in mixture + */ + @Since("1.3.0") def k: Int = weights.length - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.3.0") def predict(points: RDD[Vector]): RDD[Int] = { val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - + + /** + * Maps given point to its cluster index. + */ + @Since("1.5.0") + def predict(point: Vector): Int = { + val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + r.indexOf(r.max) + } + + /** + * Java-friendly version of [[predict()]] + */ + @Since("1.4.0") + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + /** * Given the input vectors, return the membership value of each vector - * to all mixture components. + * to all mixture components. */ + @Since("1.3.0") def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) - points.map { x => + points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } - + + /** + * Given the input vector, return the membership values to all mixture components. + */ + @Since("1.4.0") + def predictSoft(point: Vector): Array[Double] = { + computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + } + /** * Compute the partial assignments for each vector */ @@ -74,10 +123,86 @@ class GaussianMixtureModel( val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) } - val pSum = p.sum + val pSum = p.sum for (i <- 0 until k) { p(i) /= pSum } p - } + } +} + +@Since("1.4.0") +@Experimental +object GaussianMixtureModel extends Loader[GaussianMixtureModel] { + + private object SaveLoadV1_0 { + + case class Data(weight: Double, mu: Vector, sigma: Matrix) + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel" + + def save( + sc: SparkContext, + path: String, + weights: Array[Double], + gaussians: Array[MultivariateGaussian]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(weights.length) { i => + Data(weights(i), gaussians(i).mu, gaussians(i).sigma) + } + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): GaussianMixtureModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.read.parquet(dataPath) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val (weights, gaussians) = dataArray.map { + case Row(weight: Double, mu: Vector, sigma: Matrix) => + (weight, new MultivariateGaussian(mu, sigma)) + }.unzip + + new GaussianMixtureModel(weights.toArray, gaussians.toArray) + } + } + + @Since("1.4.0") + override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val k = (metadata \ "k").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, version) match { + case (classNameV1_0, "1.0") => { + val model = SaveLoadV1_0.load(sc, path) + require(model.weights.length == k, + s"GaussianMixtureModel requires weights of length $k " + + s"got weights of length ${model.weights.length}") + require(model.gaussians.length == k, + s"GaussianMixtureModel requires gaussians of length $k" + + s"got gaussians of length ${model.gaussians.length}") + model + } + case _ => throw new Exception( + s"GaussianMixtureModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 11633e824231..46920fffe6e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -37,6 +37,7 @@ import org.apache.spark.util.random.XORShiftRandom * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given * to it should be cached by the user. */ +@Since("0.8.0") class KMeans private ( private var k: Int, private var maxIterations: Int, @@ -50,39 +51,72 @@ class KMeans private ( * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ + @Since("0.8.0") def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) - /** Set the number of clusters to create (k). Default: 2. */ + /** + * Number of clusters to create (k). + */ + @Since("1.4.0") + def getK: Int = k + + /** + * Set the number of clusters to create (k). Default: 2. + */ + @Since("0.8.0") def setK(k: Int): this.type = { this.k = k this } - /** Set maximum number of iterations to run. Default: 20. */ + /** + * Maximum number of iterations to run. + */ + @Since("1.4.0") + def getMaxIterations: Int = maxIterations + + /** + * Set maximum number of iterations to run. Default: 20. + */ + @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } + /** + * The initialization algorithm. This can be either "random" or "k-means||". + */ + @Since("1.4.0") + def getInitializationMode: String = initializationMode + /** * Set the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ + @Since("0.8.0") def setInitializationMode(initializationMode: String): this.type = { - if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { - throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) - } + KMeans.validateInitMode(initializationMode) this.initializationMode = initializationMode this } + /** + * :: Experimental :: + * Number of runs of the algorithm to execute in parallel. + */ + @Since("1.4.0") + @Experimental + def getRuns: Int = runs + /** * :: Experimental :: * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm * this many times with random starting conditions (configured by the initialization mode), then * return the best clustering found over any run. Default: 1. */ + @Since("0.8.0") @Experimental def setRuns(runs: Int): this.type = { if (runs <= 0) { @@ -92,10 +126,17 @@ class KMeans private ( this } + /** + * Number of steps for the k-means|| initialization mode + */ + @Since("1.4.0") + def getInitializationSteps: Int = initializationSteps + /** * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. */ + @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { throw new IllegalArgumentException("Number of initialization steps must be positive") @@ -104,25 +145,58 @@ class KMeans private ( this } + /** + * The distance threshold within which we've consider centers to have converged. + */ + @Since("1.4.0") + def getEpsilon: Double = epsilon + /** * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. */ + @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon this } - /** Set the random seed for cluster initialization. */ + /** + * The random seed for cluster initialization. + */ + @Since("1.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. + */ + @Since("1.4.0") def setSeed(seed: Long): this.type = { this.seed = seed this } + // Initial cluster centers can be provided as a KMeansModel object rather than using the + // random or k-means|| initializationMode + private var initialModel: Option[KMeansModel] = None + + /** + * Set the initial starting point, bypassing the random initialization or k-means|| + * The condition model.k == this.k must be met, failure results + * in an IllegalArgumentException. + */ + @Since("1.4.0") + def setInitialModel(model: KMeansModel): this.type = { + require(model.k == k, "mismatched cluster count") + initialModel = Some(model) + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ + @Since("0.8.0") def run(data: RDD[Vector]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { @@ -156,20 +230,34 @@ class KMeans private ( val initStartTime = System.nanoTime() - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) + // Only one run is allowed when initialModel is given + val numRuns = if (initialModel.nonEmpty) { + if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") + 1 } else { - initKMeansParallel(data) + runs } + val centers = initialModel match { + case Some(kMeansCenters) => { + Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + } + case None => { + if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + } + } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + " seconds.") - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) + val active = Array.fill(numRuns)(true) + val costs = Array.fill(numRuns)(0.0) - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) var iteration = 0 val iterationStartTime = System.nanoTime() @@ -367,10 +455,13 @@ class KMeans private ( /** * Top-level methods for calling K-means clustering. */ +@Since("0.8.0") object KMeans { // Initialization mode names + @Since("0.8.0") val RANDOM = "random" + @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" /** @@ -383,6 +474,7 @@ object KMeans { * @param initializationMode initialization model, either "random" or "k-means||" (default). * @param seed random seed value for cluster initialization */ + @Since("1.3.0") def train( data: RDD[Vector], k: Int, @@ -407,6 +499,7 @@ object KMeans { * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -423,6 +516,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -433,6 +527,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -484,6 +579,14 @@ object KMeans { v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } + + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } } /** @@ -499,5 +602,5 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable def this(array: Array[Double]) = this(Vectors.dense(array)) /** Converts the vector to a dense vector. */ - def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm) + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3b95a9e6936e..a74158498272 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -17,32 +17,63 @@ package org.apache.spark.mllib.clustering +import scala.collection.JavaConverters._ + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.pmml.PMMLExportable +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { +@Since("0.8.0") +class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) + extends Saveable with Serializable with PMMLExportable { - /** Total number of clusters. */ + /** + * A Java-friendly constructor that takes an Iterable of Vectors. + */ + @Since("1.4.0") + def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) + + /** + * Total number of clusters. + */ + @Since("0.8.0") def k: Int = clusterCenters.length - /** Returns the cluster index that a given point belongs to. */ + /** + * Returns the cluster index that a given point belongs to. + */ + @Since("0.8.0") def predict(point: Vector): Int = { KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = points.context.broadcast(centersWithNorm) points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.0.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -50,6 +81,7 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. */ + @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = data.context.broadcast(centersWithNorm) @@ -58,4 +90,63 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) + + @Since("1.4.0") + override def save(sc: SparkContext, path: String): Unit = { + KMeansModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("1.4.0") +object KMeansModel extends Loader[KMeansModel] { + + @Since("1.4.0") + override def load(sc: SparkContext, path: String): KMeansModel = { + KMeansModel.SaveLoadV1_0.load(sc, path) + } + + private case class Cluster(id: Int, point: Vector) + + private object Cluster { + def apply(r: Row): Cluster = { + Cluster(r.getInt(0), r.getAs[Vector](1)) + } + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => + Cluster(id, point) + }.toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centroids = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.map(Cluster.apply).collect() + assert(k == localCentroids.size) + new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d8f82867a09d..92a321afb0ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,21 +17,16 @@ package org.apache.spark.mllib.clustering -import java.util.Random - -import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils - /** * :: Experimental :: * @@ -42,39 +37,44 @@ import org.apache.spark.util.Utils * - "token": instance of a term appearing in a document * - "topic": multinomial distribution over words representing some concept * - * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented - * according to the Asuncion et al. (2009) paper referenced below. - * * References: * - Original LDA paper (journal version): * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. - * - This class implements their "smoothed" LDA model. - * - Paper which clearly explains several algorithms, including EM: - * Asuncion, Welling, Smyth, and Teh. - * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] */ +@Since("1.3.0") @Experimental class LDA private ( private var k: Int, private var maxIterations: Int, - private var docConcentration: Double, + private var docConcentration: Vector, private var topicConcentration: Double, private var seed: Long, - private var checkpointDir: Option[String], - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var ldaOptimizer: LDAOptimizer) extends Logging { - def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10) + /** + * Constructs a LDA instance with default parameters. + */ + @Since("1.3.0") + def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), + topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, + ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. + * */ + @Since("1.3.0") def getK: Int = k /** * Number of topics to infer. I.e., the number of soft cluster centers. * (default = 10) */ + @Since("1.3.0") def setK(k: Int): this.type = { require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") this.k = k @@ -85,13 +85,26 @@ class LDA private ( * Concentration parameter (commonly named "alpha") for the prior placed on documents' * distributions over topics ("theta"). * - * This is the parameter to a symmetric Dirichlet distribution. + * This is the parameter to a Dirichlet distribution. */ + @Since("1.5.0") + def getAsymmetricDocConcentration: Vector = this.docConcentration + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This method assumes the Dirichlet distribution is symmetric and can be described by a single + * [[Double]] parameter. It should fail if docConcentration is asymmetric. + */ + @Since("1.3.0") def getDocConcentration: Double = { - if (this.docConcentration == -1) { - (50.0 / k) + 1.0 + val parameter = docConcentration(0) + if (docConcentration.size == 1) { + parameter } else { - this.docConcentration + require(docConcentration.toArray.forall(_ == parameter)) + parameter } } @@ -99,31 +112,64 @@ class LDA private ( * Concentration parameter (commonly named "alpha") for the prior placed on documents' * distributions over topics ("theta"). * - * This is the parameter to a symmetric Dirichlet distribution. - * - * This value should be > 1.0, where larger values mean more smoothing (more regularization). - * If set to -1, then docConcentration is set automatically. - * (default = -1 = automatic) + * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing + * (more regularization). * - * Automatic setting of parameter: - * - For EM: default = (50 / k) + 1. - * - The 50/k is common in LDA libraries. - * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * If set to a singleton vector Vector(-1), then docConcentration is set automatically. If set to + * singleton vector Vector(t) where t != -1, then t is replicated to a vector of length k during + * [[LDAOptimizer.initialize()]]. Otherwise, the [[docConcentration]] vector must be length k. + * (default = Vector(-1) = automatic) * - * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), - * but values in (0,1) are not yet supported. + * Optimizer-specific parameter settings: + * - EM + * - Currently only supports symmetric distributions, so all values in the vector should be + * the same. + * - Values should be > 1.0 + * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * from Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Values should be >= 0 + * - default = uniformly (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. */ - def setDocConcentration(docConcentration: Double): this.type = { - require(docConcentration > 1.0 || docConcentration == -1.0, - s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration") + @Since("1.5.0") + def setDocConcentration(docConcentration: Vector): this.type = { + require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration this } - /** Alias for [[getDocConcentration]] */ + /** + * Replicates a [[Double]] docConcentration to create a symmetric prior. + */ + @Since("1.3.0") + def setDocConcentration(docConcentration: Double): this.type = { + this.docConcentration = Vectors.dense(docConcentration) + this + } + + /** + * Alias for [[getAsymmetricDocConcentration]] + */ + @Since("1.5.0") + def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration + + /** + * Alias for [[getDocConcentration]] + */ + @Since("1.3.0") def getAlpha: Double = getDocConcentration - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[setDocConcentration()]] + */ + @Since("1.5.0") + def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) + + /** + * Alias for [[setDocConcentration()]] + */ + @Since("1.3.0") def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) /** @@ -135,13 +181,8 @@ class LDA private ( * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ - def getTopicConcentration: Double = { - if (this.topicConcentration == -1) { - 1.1 - } else { - this.topicConcentration - } - } + @Since("1.3.0") + def getTopicConcentration: Double = this.topicConcentration /** * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' @@ -152,101 +193,123 @@ class LDA private ( * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. * - * This value should be > 0.0. * If set to -1, then topicConcentration is set automatically. * (default = -1 = automatic) * - * Automatic setting of parameter: - * - For EM: default = 0.1 + 1. - * - The 0.1 gives a small amount of smoothing. - * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. - * - * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), - * but values in (0,1) are not yet supported. + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. */ + @Since("1.3.0") def setTopicConcentration(topicConcentration: Double): this.type = { - require(topicConcentration > 1.0 || topicConcentration == -1.0, - s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration") this.topicConcentration = topicConcentration this } - /** Alias for [[getTopicConcentration]] */ + /** + * Alias for [[getTopicConcentration]] + */ + @Since("1.3.0") def getBeta: Double = getTopicConcentration - /** Alias for [[setTopicConcentration()]] */ - def setBeta(beta: Double): this.type = setBeta(beta) + /** + * Alias for [[setTopicConcentration()]] + */ + @Since("1.3.0") + def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Maximum number of iterations for learning. * (default = 20) */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Random seed */ + /** + * Random seed + */ + @Since("1.3.0") def getSeed: Long = seed - /** Random seed */ + /** + * Random seed + */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this } /** - * Directory for storing checkpoint files during learning. - * This is not necessary, but checkpointing helps with recovery (when nodes fail). - * It also helps with eliminating temporary shuffle files on disk, which can be important when - * LDA is run for many iterations. + * Period (in iterations) between checkpoints. */ - def getCheckpointDir: Option[String] = checkpointDir + @Since("1.3.0") + def getCheckpointInterval: Int = checkpointInterval /** - * Directory for storing checkpoint files during learning. - * This is not necessary, but checkpointing helps with recovery (when nodes fail). - * It also helps with eliminating temporary shuffle files on disk, which can be important when - * LDA is run for many iterations. + * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be + * important when LDA is run for many iterations. If the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. * - * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value - * given to LDA is ignored, and the existing directory is kept. - * - * (default = None) + * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ - def setCheckpointDir(checkpointDir: String): this.type = { - this.checkpointDir = Some(checkpointDir) + @Since("1.3.0") + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval this } + /** - * Clear the directory for storing checkpoint files during learning. - * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still - * occur; otherwise, no checkpointing will be used. + * :: DeveloperApi :: + * + * LDAOptimizer used to perform the actual calculation */ - def clearCheckpointDir(): this.type = { - this.checkpointDir = None - this - } + @Since("1.4.0") + @DeveloperApi + def getOptimizer: LDAOptimizer = ldaOptimizer /** - * Period (in iterations) between checkpoints. - * @see [[getCheckpointDir]] + * :: DeveloperApi :: + * + * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) */ - def getCheckpointInterval: Int = checkpointInterval + @Since("1.4.0") + @DeveloperApi + def setOptimizer(optimizer: LDAOptimizer): this.type = { + this.ldaOptimizer = optimizer + this + } /** - * Period (in iterations) between checkpoints. - * (default = 10) - * @see [[getCheckpointDir]] + * Set the LDAOptimizer used to perform the actual calculation by algorithm name. + * Currently "em", "online" are supported. */ - def setCheckpointInterval(checkpointInterval: Int): this.type = { - this.checkpointInterval = checkpointInterval + @Since("1.4.0") + def setOptimizer(optimizerName: String): this.type = { + this.ldaOptimizer = + optimizerName.toLowerCase match { + case "em" => new EMLDAOptimizer + case "online" => new OnlineLDAOptimizer + case other => + throw new IllegalArgumentException(s"Only em, online are supported but got $other.") + } this } @@ -259,9 +322,9 @@ class LDA private ( * Document IDs must be unique and >= 0. * @return Inferred LDA model */ - def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointDir, checkpointInterval) + @Since("1.3.0") + def run(documents: RDD[(Long, Vector)]): LDAModel = { + val state = ldaOptimizer.initialize(documents, this) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -271,12 +334,14 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(state, iterationTimes) + state.getLDAModel(iterationTimes) } - /** Java-friendly version of [[run()]] */ - def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + /** + * Java-friendly version of [[run()]] + */ + @Since("1.3.0") + def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } } @@ -337,102 +402,23 @@ private[clustering] object LDA { * Vector over topics (length k) of token counts. * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. */ - type TopicCounts = BDV[Double] + private[clustering] type TopicCounts = BDV[Double] - type TokenCount = Double + private[clustering] type TokenCount = Double /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ - def term2index(term: Int): Long = -(1 + term.toLong) + private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) - def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt - def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 - def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 - - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. - * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointDir: Option[String], - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointDir, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} - val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) - } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 - } else { - m0._2 + m1._2 - } - (true, sum) - } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this - } - - /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. - */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } - - } + private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 /** * Compute gamma_{wjk}, a distribution over topics k. */ - private def computePTopic( + private[clustering] def computePTopic( docTopicCounts: TopicCounts, termTopicCounts: TopicCounts, totalTopicCounts: TopicCounts, @@ -458,62 +444,4 @@ private[clustering] object LDA { // normalize BDV(gamma_wj) /= sum } - - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointDir: Option[String], - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } - } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). - val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.map { edge => - // Create a random gamma_{wjk} - (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)) - } - } - def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] = - edgesWithGamma.map { case (edge, gamma: TopicCounts) => - (sendToWhere(edge), (edge.attr, gamma)) - } - verticesTMP.aggregateByKey(BDV.zeros[Double](k))( - (sum, t) => { - brzAxpy(t._1, t._2, sum) - sum - }, - (sum0, sum1) => { - sum0 += sum1 - } - ) - } - val docVertices = createVertices(_.srcId) - val termVertices = createVertices(_.dstId) - - // Partition such that edges are grouped by document - val graph = Graph(docVertices ++ termVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) - - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir, - checkpointInterval) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 19e8aab6eabd..15129e0dd5a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,12 +17,21 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} - -import org.apache.spark.annotation.Experimental -import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} +import breeze.numerics.{exp, lgamma} +import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.BoundedPriorityQueue /** @@ -34,33 +43,61 @@ import org.apache.spark.util.BoundedPriorityQueue * including local and distributed data structures. */ @Experimental -abstract class LDAModel private[clustering] { +@Since("1.3.0") +abstract class LDAModel private[clustering] extends Saveable { /** Number of topics */ + @Since("1.3.0") def k: Int /** Vocabulary size (number of terms or terms in the vocabulary) */ + @Since("1.3.0") def vocabSize: Int + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution. + */ + @Since("1.5.0") + def docConcentration: Vector + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + @Since("1.5.0") + def topicConcentration: Double + + /** + * Shape parameter for random initialization of variational parameter gamma. + * Used for variational inference for perplexity and other test-time computations. + */ + protected def gammaShape: Double + /** * Inferred topics, where each topic is represented by a distribution over terms. * This is a matrix of size vocabSize x k, where each column is a topic. * No guarantees are given about the ordering of the topics. */ + @Since("1.3.0") def topicsMatrix: Matrix /** * Return the topics described by weighted terms. * - * This limits the number of terms per topic. - * This is approximate; it may not return exactly the top-weighted terms for each topic. - * To get a more precise set of top terms, increase maxTermsPerTopic. - * * @param maxTermsPerTopic Maximum number of terms to collect for each topic. * @return Array over topics. Each topic is represented as a pair of matching arrays: * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. */ + @Since("1.3.0") def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] /** @@ -72,6 +109,7 @@ abstract class LDAModel private[clustering] { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. */ + @Since("1.3.0") def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) /* TODO (once LDA can be trained with Strings or given a dictionary) @@ -130,7 +168,7 @@ abstract class LDAModel private[clustering] { /* TODO * Compute the estimated topic distribution for each document. - * This is often called “theta” in the literature. + * This is often called 'theta' in the literature. * * @param documents RDD of documents, which are term (word) count vectors paired with IDs. * The term count vectors are "bags of words" with a fixed-size vocabulary @@ -152,19 +190,27 @@ abstract class LDAModel private[clustering] { * This model stores only the inferred topics. * It may be used for computing topics for new documents, but it may give less accurate answers * than the [[DistributedLDAModel]]. - * * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental +@Since("1.3.0") class LocalLDAModel private[clustering] ( - private val topics: Matrix) extends LDAModel with Serializable { + @Since("1.3.0") val topics: Matrix, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel with Serializable { + @Since("1.3.0") override def k: Int = topics.numCols + @Since("1.3.0") override def vocabSize: Int = topics.numRows + @Since("1.3.0") override def topicsMatrix: Matrix = topics + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val brzTopics = topics.toBreeze.toDenseMatrix Range(0, k).map { topicIndex => @@ -175,14 +221,265 @@ class LocalLDAModel private[clustering] ( }.toArray } - // TODO - // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + override protected def formatVersion = "1.0" - // TODO: - // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + @Since("1.5.0") + override def save(sc: SparkContext, path: String): Unit = { + LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, + gammaShape) + } + + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + /** + * Calculates a lower bound on the log likelihood of the entire corpus. + * + * See Equation (16) in original Online LDA paper. + * + * @param documents test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + @Since("1.5.0") + def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents, + docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, + vocabSize) + + /** + * Java-friendly version of [[logLikelihood]] + */ + @Since("1.5.0") + def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + + /** + * Calculate an upper bound bound on perplexity. (Lower is better.) + * See Equation (16) in original Online LDA paper. + * + * @param documents test corpus to use for calculating perplexity + * @return Variational upper bound on log perplexity per token. + */ + @Since("1.5.0") + def logPerplexity(documents: RDD[(Long, Vector)]): Double = { + val corpusTokenCount = documents + .map { case (_, termCounts) => termCounts.toArray.sum } + .sum() + -logLikelihood(documents) / corpusTokenCount + } + + /** Java-friendly version of [[logPerplexity]] */ + @Since("1.5.0") + def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + + /** + * Estimate the variational likelihood bound of from `documents`: + * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] + * This bound is derived by decomposing the LDA model to: + * log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p) + * and noting that the KL-divergence D(q|p) >= 0. + * + * See Equation (16) in original Online LDA paper, as well as Appendix A.3 in the JMLR version of + * the original LDA paper. + * @param documents a subset of the test corpus + * @param alpha document-topic Dirichlet prior parameters + * @param eta topic-word Dirichlet prior parameter + * @param lambda parameters for variational q(beta | lambda) topic-word distributions + * @param gammaShape shape parameter for random initialization of variational q(theta | gamma) + * topic mixture distributions + * @param k number of topics + * @param vocabSize number of unique terms in the entire test corpus + */ + private def logLikelihoodBound( + documents: RDD[(Long, Vector)], + alpha: Vector, + eta: Double, + lambda: BDM[Double], + gammaShape: Double, + k: Int, + vocabSize: Long): Double = { + val brzAlpha = alpha.toBreeze.toDenseVector + // transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t + val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + + // Sum bound components for each document: + // component for prob(tokens) + component for prob(document-topic distribution) + val corpusPart = + documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => + val localElogbeta = ElogbetaBc.value + var docBound = 0.0D + val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) + + // E[log p(doc | theta, beta)] + termCounts.foreachActive { case (idx, count) => + docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t) + } + // E[log p(theta | alpha) - log q(theta | gamma)] + docBound += sum((brzAlpha - gammad) :* Elogthetad) + docBound += sum(lgamma(gammad) - lgamma(brzAlpha)) + docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) + + docBound + }.sum() + + // Bound component for prob(topic-term distributions): + // E[log p(beta | eta) - log q(beta | lambda)] + val sumEta = eta * vocabSize + val topicsPart = sum((eta - lambda) :* Elogbeta) + + sum(lgamma(lambda) - lgamma(eta)) + + sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) + + corpusPart + topicsPart + } + + /** + * Predicts the topic mixture distribution for each document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @param documents documents to predict topic mixture distributions for + * @return An RDD of (document ID, topic mixture distribution for document) + */ + @Since("1.3.0") + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { + // Double transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + documents.map { case (id: Long, termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + (id, Vectors.zeros(k)) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbetaBc.value, + docConcentrationBrz, + gammaShape, + k) + (id, Vectors.dense(normalize(gamma, 1.0).toArray)) + } + } + } + + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.4.1") + def topicDistributions( + documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { + val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } } +@Experimental +@Since("1.5.0") +object LocalLDAModel extends Loader[LocalLDAModel] { + + private object SaveLoadV1_0 { + + val thisFormatVersion = "1.0" + + val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel" + + // Store the distribution of terms of each topic and the column index in topicsMatrix + // as a Row in data. + case class Data(topic: Vector, index: Int) + + def save( + sc: SparkContext, + path: String, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + val k = topicsMatrix.numCols + val metadata = compact(render + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("gammaShape" -> gammaShape))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix + val topics = Range(0, k).map { topicInd => + Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd) + }.toSeq + sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) + } + + def load( + sc: SparkContext, + path: String, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): LocalLDAModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = SQLContext.getOrCreate(sc) + val dataFrame = sqlContext.read.parquet(dataPath) + + Loader.checkSchema[Data](dataFrame.schema) + val topics = dataFrame.collect() + val vocabSize = topics(0).getAs[Vector](0).size + val k = topics.length + + val brzTopics = BDM.zeros[Double](vocabSize, k) + topics.foreach { case Row(vec: Vector, ind: Int) => + brzTopics(::, ind) := vec.toBreeze + } + val topicsMat = Matrices.fromBreeze(brzTopics) + + new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) + } + } + + @Since("1.5.0") + override def load(sc: SparkContext, path: String): LocalLDAModel = { + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedK = (metadata \ "k").extract[Int] + val expectedVocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val gammaShape = (metadata \ "gammaShape").extract[Double] + val classNameV1_0 = SaveLoadV1_0.thisClassName + + val model = (loadedClassName, loadedVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape) + case _ => throw new Exception( + s"LocalLDAModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + + val topicsMatrix = model.topicsMatrix + require(expectedK == topicsMatrix.numCols, + s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics") + require(expectedVocabSize == topicsMatrix.numRows, + s"LocalLDAModel requires $expectedVocabSize terms for each topic, " + + s"but got ${topicsMatrix.numRows}") + model + } +} + /** * :: Experimental :: * @@ -192,28 +489,28 @@ class LocalLDAModel private[clustering] ( * than the [[LocalLDAModel]]. */ @Experimental -class DistributedLDAModel private ( - private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], - private val globalTopicTotals: LDA.TopicCounts, - val k: Int, - val vocabSize: Int, - private val docConcentration: Double, - private val topicConcentration: Double, - private[spark] val iterationTimes: Array[Double]) extends LDAModel { +@Since("1.3.0") +class DistributedLDAModel private[clustering] ( + private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private[clustering] val globalTopicTotals: LDA.TopicCounts, + @Since("1.3.0") val k: Int, + @Since("1.3.0") val vocabSize: Int, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, + private[spark] val iterationTimes: Array[Double], + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel { import LDA._ - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { - this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, - state.topicConcentration, iterationTimes) - } - /** * Convert model to a local model. * The local model stores the inferred topics but not the topic distributions for training * documents. */ - def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + @Since("1.3.0") + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) /** * Inferred topics, where each topic is represented by a distribution over terms. @@ -222,6 +519,7 @@ class DistributedLDAModel private ( * * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. */ + @Since("1.3.0") override lazy val topicsMatrix: Matrix = { // Collect row-major topics val termTopicCounts: Array[(Int, TopicCounts)] = @@ -240,6 +538,7 @@ class DistributedLDAModel private ( Matrices.fromBreeze(brzTopics) } + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val numTopics = k // Note: N_k is not needed to find the top terms, but it is needed to normalize weights @@ -271,6 +570,87 @@ class DistributedLDAModel private ( } } + /** + * Return the top documents for each topic + * + * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic. + * @return Array over topics. Each element represent as a pair of matching arrays: + * (IDs for the documents, weights of the topic in these documents). + * For each topic, documents are sorted in order of decreasing topic weights. + */ + @Since("1.5.0") + def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { + val numTopics = k + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = + topicDistributions.mapPartitions { docVertices => + // For this partition, collect the most common docs for each topic in queues: + // queues(topic) = queue of (doc topic, doc ID). + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic)) + for ((docId, docTopics) <- docVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (docTopics(topic) -> docId) + topic += 1 + } + } + Iterator(queues) + }.treeReduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b } + q1 + } + topicsInQueues.map { q => + val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip + (docs.toArray, docTopics.toArray) + } + } + + /** + * Return the top topic for each (doc, term) pair. I.e., for each document, what is the most + * likely topic generating each term? + * + * @return RDD of (doc ID, assignment of top topic index for each term), + * where the assignment is specified via a pair of zippable arrays + * (term indices, topic indices). Note that terms will be omitted if not present in + * the document. + */ + @Since("1.5.0") + lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { + // For reference, compare the below code with the core part of EMLDAOptimizer.next(). + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration(0) + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit = + (edgeContext) => { + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions). + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) + // For this (doc j, term w), send top topic k to doc vertex. + val topTopic: Int = argmax(scaledTopicDistribution) + val term: Int = index2term(edgeContext.dstId) + edgeContext.sendToSrc((Array(term), Array(topTopic))) + } + val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) = + (terms_topics0, terms_topics1) => { + (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val perDocAssignments = + graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex) + perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) => + // TODO: Avoid zip, which is inefficient. + val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip + (docID, sortedTerms.toArray, sortedTopics.toArray) + } + } + + /** Java-friendly version of [[topicAssignments]] */ + @Since("1.5.0") + lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { + topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -284,9 +664,11 @@ class DistributedLDAModel private ( * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the * hyperparameters. */ + @Since("1.3.0") lazy val logLikelihood: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration assert(eta > 1.0) assert(alpha > 1.0) val N_k = globalTopicTotals @@ -307,11 +689,13 @@ class DistributedLDAModel private ( /** * Log probability of the current parameter estimate: - * log P(topics, topic distributions for docs | alpha, eta) + * log P(topics, topic distributions for docs | alpha, eta) */ + @Since("1.3.0") lazy val logPrior: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration // Term vertices: Compute phi_{wk}. Use to compute prior log probability. // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. val N_k = globalTopicTotals @@ -322,12 +706,12 @@ class DistributedLDAModel private ( val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * brzSum(phi_wk.map(math.log)) + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + (alpha - 1.0) * sum(theta_kj.map(math.log)) } } graph.vertices.aggregate(0.0)(seqOp, _ + _) @@ -335,17 +719,196 @@ class DistributedLDAModel private ( /** * For each document in the training set, return the distribution over topics for that document - * (i.e., "theta_doc"). + * ("theta_doc"). * * @return RDD of (document ID, topic distribution) pairs */ + @Since("1.3.0") def topicDistributions: RDD[(Long, Vector)] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) } } + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.4.1") + def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { + JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } + + /** + * For each document, return the top k weighted topics for that document and their weights. + * @return RDD of (doc ID, topic indices, topic weights) + */ + @Since("1.5.0") + def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + val topIndices = argtopk(topicCounts, k) + val sumCounts = sum(topicCounts) + val weights = if (sumCounts != 0) { + topicCounts(topIndices) / sumCounts + } else { + topicCounts(topIndices) + } + (docID.toLong, topIndices.toArray, weights.toArray) + } + } + + /** + * Java-friendly version of [[topTopicsPerDocument]] + */ + @Since("1.5.0") + def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { + val topics = topTopicsPerDocument(k) + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + override protected def formatVersion = "1.0" + + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.5.0") + override def save(sc: SparkContext, path: String): Unit = { + DistributedLDAModel.SaveLoadV1_0.save( + sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, + iterationTimes, gammaShape) + } +} + + +@Experimental +@Since("1.5.0") +object DistributedLDAModel extends Loader[DistributedLDAModel] { + + private object SaveLoadV1_0 { + + val thisFormatVersion = "1.0" + + val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel" + + // Store globalTopicTotals as a Vector. + case class Data(globalTopicTotals: Vector) + + // Store each term and document vertex with an id and the topicWeights. + case class VertexData(id: Long, topicWeights: Vector) + + // Store each edge with the source id, destination id and tokenCounts. + case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double) + + def save( + sc: SparkContext, + path: String, + graph: Graph[LDA.TopicCounts, LDA.TokenCount], + globalTopicTotals: LDA.TopicCounts, + k: Int, + vocabSize: Int, + docConcentration: Vector, + topicConcentration: Double, + iterationTimes: Array[Double], + gammaShape: Double): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + val metadata = compact(render + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("iterationTimes" -> iterationTimes.toSeq) ~ + ("gammaShape" -> gammaShape))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString + sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF() + .write.parquet(newPath) + + val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString + graph.vertices.map { case (ind, vertex) => + VertexData(ind, Vectors.fromBreeze(vertex)) + }.toDF().write.parquet(verticesPath) + + val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString + graph.edges.map { case Edge(srcId, dstId, prop) => + EdgeData(srcId, dstId, prop) + }.toDF().write.parquet(edgesPath) + } + + def load( + sc: SparkContext, + path: String, + vocabSize: Int, + docConcentration: Vector, + topicConcentration: Double, + iterationTimes: Array[Double], + gammaShape: Double): DistributedLDAModel = { + val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString + val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString + val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString + val sqlContext = SQLContext.getOrCreate(sc) + val dataFrame = sqlContext.read.parquet(dataPath) + val vertexDataFrame = sqlContext.read.parquet(vertexDataPath) + val edgeDataFrame = sqlContext.read.parquet(edgeDataPath) + + Loader.checkSchema[Data](dataFrame.schema) + Loader.checkSchema[VertexData](vertexDataFrame.schema) + Loader.checkSchema[EdgeData](edgeDataFrame.schema) + val globalTopicTotals: LDA.TopicCounts = + dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector + val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map { + case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) + } + + val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map { + case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop) + } + val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) + + new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, + docConcentration, topicConcentration, iterationTimes, gammaShape) + } + + } + + @Since("1.5.0") + override def load(sc: SparkContext, path: String): DistributedLDAModel = { + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedK = (metadata \ "k").extract[Int] + val vocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] + val gammaShape = (metadata \ "gammaShape").extract[Double] + val classNameV1_0 = SaveLoadV1_0.thisClassName + + val model = (loadedClassName, loadedVersion) match { + case (className, "1.0") if className == classNameV1_0 => + DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, + topicConcentration, iterationTimes.toArray, gammaShape) + case _ => throw new Exception( + s"DistributedLDAModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") + } + + require(model.vocabSize == vocabSize, + s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize") + require(model.docConcentration == docConcentration, + s"DistributedLDAModel requires $docConcentration docConcentration, " + + s"got ${model.docConcentration} docConcentration") + require(model.topicConcentration == topicConcentration, + s"DistributedLDAModel requires $topicConcentration docConcentration, " + + s"got ${model.topicConcentration} docConcentration") + require(expectedK == model.k, + s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics") + model + } + } + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala new file mode 100644 index 000000000000..38486e949bbc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum} +import breeze.numerics.{trigamma, abs, exp} +import breeze.stats.distributions.{Gamma, RandBasis} + +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * + * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can + * hold optimizer-specific parameters for users to set. + */ +@Since("1.4.0") +@DeveloperApi +sealed trait LDAOptimizer { + + /* + DEVELOPERS NOTE: + + An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which + stores internal data structure (Graph or Matrix) and other parameters for the algorithm. + The interface is isolated to improve the extensibility of LDA. + */ + + /** + * Initializer for the optimizer. LDA passes the common parameters to the optimizer and + * the internal structure can be initialized properly. + */ + private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer + + private[clustering] def next(): LDAOptimizer + + private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel +} + +/** + * :: DeveloperApi :: + * + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + */ +@Since("1.4.0") +@DeveloperApi +final class EMLDAOptimizer extends LDAOptimizer { + + import LDA._ + + /** + * The following fields will only be initialized through the initialize() method + */ + private[clustering] var graph: Graph[TopicCounts, TokenCount] = null + private[clustering] var k: Int = 0 + private[clustering] var vocabSize: Int = 0 + private[clustering] var docConcentration: Double = 0 + private[clustering] var topicConcentration: Double = 0 + private[clustering] var checkpointInterval: Int = 10 + private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null + + /** + * Compute bipartite term/doc graph. + */ + override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { + // EMLDAOptimizer currently only supports symmetric document-topic priors + val docConcentration = lda.getDocConcentration + + val topicConcentration = lda.getTopicConcentration + val k = lda.getK + + // Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + // but values in (0,1) are not yet supported. + require(docConcentration > 1.0 || docConcentration == -1.0, s"LDA docConcentration must be" + + s" > 1.0 (or -1 for auto) for EM Optimizer, but was set to $docConcentration") + require(topicConcentration > 1.0 || topicConcentration == -1.0, s"LDA topicConcentration " + + s"must be > 1.0 (or -1 for auto) for EM Optimizer, but was set to $topicConcentration") + + this.docConcentration = if (docConcentration == -1) (50.0 / k) + 1.0 else docConcentration + this.topicConcentration = if (topicConcentration == -1) 1.1 else topicConcentration + val randomSeed = lda.getSeed + + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + val docTermVertices: RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + // Partition such that edges are grouped by document + this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) + this.k = k + this.vocabSize = docs.take(1).head._2.size + this.checkpointInterval = lda.getCheckpointInterval + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) + this.globalTopicTotals = computeGlobalTopicTotals() + this + } + + override private[clustering] def next(): EMLDAOptimizer = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // The Boolean is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.update(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + private[clustering] var globalTopicTotals: TopicCounts = null + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + this.graphCheckpointer.deleteAllCheckpoints() + // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in + // LDAModel.toLocal conversion + new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, + Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, + iterationTimes) + } +} + + +/** + * :: DeveloperApi :: + * + * An online optimizer for LDA. The Optimizer implements the Online variational Bayes LDA + * algorithm, which processes a subset of the corpus on each iteration, and updates the term-topic + * distribution adaptively. + * + * Original Online LDA paper: + * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010. + */ +@Since("1.4.0") +@DeveloperApi +final class OnlineLDAOptimizer extends LDAOptimizer { + + // LDA common parameters + private var k: Int = 0 + private var corpusSize: Long = 0 + private var vocabSize: Int = 0 + + /** alias for docConcentration */ + private var alpha: Vector = Vectors.dense(0) + + /** (for debugging) Get docConcentration */ + private[clustering] def getAlpha: Vector = alpha + + /** alias for topicConcentration */ + private var eta: Double = 0 + + /** (for debugging) Get topicConcentration */ + private[clustering] def getEta: Double = eta + + private var randomGenerator: java.util.Random = null + + /** (for debugging) Whether to sample mini-batches with replacement. (default = true) */ + private var sampleWithReplacement: Boolean = true + + // Online LDA specific parameters + // Learning rate is: (tau0 + t)^{-kappa} + private var tau0: Double = 1024 + private var kappa: Double = 0.51 + private var miniBatchFraction: Double = 0.05 + private var optimizeDocConcentration: Boolean = false + + // internal data structure + private var docs: RDD[(Long, Vector)] = null + + /** Dirichlet parameter for the posterior over topics */ + private var lambda: BDM[Double] = null + + /** (for debugging) Get parameter for topics */ + private[clustering] def getLambda: BDM[Double] = lambda + + /** Current iteration (count of invocations of [[next()]]) */ + private var iteration: Int = 0 + private var gammaShape: Double = 100 + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + */ + @Since("1.4.0") + def getTau0: Double = this.tau0 + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + * Default: 1024, following the original Online LDA paper. + */ + @Since("1.4.0") + def setTau0(tau0: Double): this.type = { + require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") + this.tau0 = tau0 + this + } + + /** + * Learning rate: exponential decay rate + */ + @Since("1.4.0") + def getKappa: Double = this.kappa + + /** + * Learning rate: exponential decay rate---should be between + * (0.5, 1.0] to guarantee asymptotic convergence. + * Default: 0.51, based on the original Online LDA paper. + */ + @Since("1.4.0") + def setKappa(kappa: Double): this.type = { + require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") + this.kappa = kappa + this + } + + /** + * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration + */ + @Since("1.4.0") + def getMiniBatchFraction: Double = this.miniBatchFraction + + /** + * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in + * each iteration. + * + * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]] + * so the entire corpus is used. Specifically, set both so that + * maxIterations * miniBatchFraction >= 1. + * + * Default: 0.05, i.e., 5% of total documents. + */ + @Since("1.4.0") + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, + s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction") + this.miniBatchFraction = miniBatchFraction + this + } + + /** + * Optimize docConcentration, indicates whether docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. + */ + @Since("1.5.0") + def getOptimizeDocConcentration: Boolean = this.optimizeDocConcentration + + /** + * Sets whether to optimize docConcentration parameter during training. + * + * Default: false + */ + @Since("1.5.0") + def setOptimizeDocConcentration(optimizeDocConcentration: Boolean): this.type = { + this.optimizeDocConcentration = optimizeDocConcentration + this + } + + /** + * Set the Dirichlet parameter for the posterior over topics. + * This is only used for testing now. In the future, it can help support training stop/resume. + */ + private[clustering] def setLambda(lambda: BDM[Double]): this.type = { + this.lambda = lambda + this + } + + /** + * Used for random initialization of the variational parameters. + * Larger value produces values closer to 1.0. + * This is only used for testing currently. + */ + private[clustering] def setGammaShape(shape: Double): this.type = { + this.gammaShape = shape + this + } + + /** + * Sets whether to sample mini-batches with or without replacement. (default = true) + * This is only used for testing currently. + */ + private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = { + this.sampleWithReplacement = replace + this + } + + override private[clustering] def initialize( + docs: RDD[(Long, Vector)], + lda: LDA): OnlineLDAOptimizer = { + this.k = lda.getK + this.corpusSize = docs.count() + this.vocabSize = docs.first()._2.size + this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) { + if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + else { + require(lda.getAsymmetricDocConcentration(0) >= 0, + s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0))) + } + } else { + require(lda.getAsymmetricDocConcentration.size == k, + s"alpha must have length k, got: $alpha") + lda.getAsymmetricDocConcentration.foreachActive { case (_, x) => + require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") + } + lda.getAsymmetricDocConcentration + } + this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration + this.randomGenerator = new Random(lda.getSeed) + + this.docs = docs + + // Initialize the variational distribution q(beta|lambda) + this.lambda = getGammaMatrix(k, vocabSize) + this.iteration = 0 + this + } + + override private[clustering] def next(): OnlineLDAOptimizer = { + val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction, + randomGenerator.nextLong()) + if (batch.isEmpty()) return this + submitMiniBatch(batch) + } + + /** + * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA + * model, and it will update the topic distribution adaptively for the terms appearing in the + * subset. + */ + private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = { + iteration += 1 + val k = this.k + val vocabSize = this.vocabSize + val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t + val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) + val alpha = this.alpha.toBreeze + val gammaShape = this.gammaShape + + val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + var gammaPart = List[BDV[Double]]() + nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + val ids: List[Int] = termCounts match { + case v: DenseVector => (0 until v.size).toList + case v: SparseVector => v.indices.toList + } + val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k) + stat(::, ids) := stat(::, ids).toDenseMatrix + sstats + gammaPart = gammad :: gammaPart + } + Iterator((stat, gammaPart)) + } + val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + expElogbetaBc.unpersist() + val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( + stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*) + val batchResult = statsSum :* expElogbeta.t + + // Note that this is an optimization to avoid batch.count + updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) + if (optimizeDocConcentration) updateAlpha(gammat) + this + } + + /** + * Update lambda based on the batch submitted. batchSize can be different for each iteration. + */ + private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = { + // weight of the mini-batch. + val weight = rho() + + // Update lambda based on documents. + lambda := (1 - weight) * lambda + + weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) + } + + /** + * Update alpha based on `gammat`, the inferred topic distributions for documents in the + * current mini-batch. Uses Newton-Rhapson method. + * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters + * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf) + */ + private def updateAlpha(gammat: BDM[Double]): Unit = { + val weight = rho() + val N = gammat.rows.toDouble + val alpha = this.alpha.toBreeze.toDenseVector + val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N + val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector) + + val c = N * trigamma(sum(alpha)) + val q = -N * trigamma(alpha) + val b = sum(gradf / q) / (1D / c + sum(1D / q)) + + val dalpha = -(gradf - b) / q + + if (all((weight * dalpha + alpha) :> 0D)) { + alpha :+= weight * dalpha + this.alpha = Vectors.dense(alpha.toArray) + } + } + + + /** Calculate learning rate rho for the current [[iteration]]. */ + private def rho(): Double = { + math.pow(getTau0 + this.iteration, -getKappa) + } + + /** + * Get a random matrix to initialize lambda. + */ + private def getGammaMatrix(row: Int, col: Int): BDM[Double] = { + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister( + randomGenerator.nextLong())) + val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis) + val temp = gammaRandomGenerator.sample(row * col).toArray + new BDM[Double](col, row, temp).t + } + + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + } + +} + +/** + * Serializable companion object containing helper methods and shared code for + * [[OnlineLDAOptimizer]] and [[LocalLDAModel]]. + */ +private[clustering] object OnlineLDAOptimizer { + /** + * Uses variational inference to infer the topic distribution `gammad` given the term counts + * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will + * throw a BLAS error. + * + * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001) + * avoids explicit computation of variational parameter `phi`. + * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]] + */ + private[clustering] def variationalTopicInference( + termCounts: Vector, + expElogbeta: BDM[Double], + alpha: breeze.linalg.Vector[Double], + gammaShape: Double, + k: Int): (BDV[Double], BDM[Double]) = { + val (ids: List[Int], cts: Array[Double]) = termCounts match { + case v: DenseVector => ((0 until v.size).toList, v.values) + case v: SparseVector => (v.indices.toList, v.values) + } + // Initialize the variational distribution q(theta|gamma) for the mini-batch + val gammad: BDV[Double] = + new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K + val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K + + val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + var meanGammaChange = 1D + val ctsVector = new BDV[Double](cts) // ids + + // Iterate between gamma and phi until convergence + while (meanGammaChange > 1e-3) { + val lastgamma = gammad.copy + // K K * ids ids + gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha + expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) + // TODO: Keep more values in log space, and only exponentiate when needed. + phiNorm := expElogbetad * expElogthetad :+ 1e-100 + meanGammaChange = sum(abs(gammad - lastgamma)) / k + } + + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix + (gammad, sstatsd) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala new file mode 100644 index 000000000000..a9ba7b60bad0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.numerics._ + +/** + * Utility methods for LDA. + */ +private[clustering] object LDAUtils { + /** + * Log Sum Exp with overflow protection using the identity: + * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} + */ + private[clustering] def logSumExp(x: BDV[Double]): Double = { + val a = max(x) + a + log(sum(exp(x :- a))) + } + + /** + * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation + * uses [[breeze.numerics.digamma]] which is accurate but expensive. + */ + private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = { + digamma(alpha) - digamma(sum(alpha)) + } + + /** + * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha are + * Dirichlet parameters. + */ + private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { + val rowSum = sum(alpha(breeze.linalg.*, ::)) + val digAlpha = digamma(alpha) + val digRowSum = digamma(rowSum) + val result = digAlpha(::, breeze.linalg.*) - digRowSum + result + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 9b5c155b0a80..6c76e26fd162 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,34 +17,111 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.{Logging, SparkException} +import org.json4s.JsonDSL._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.{Logging, SparkContext, SparkException} /** + * :: Experimental :: + * * Model produced by [[PowerIterationClustering]]. * * @param k number of clusters - * @param assignments an RDD of (vertexID, clusterID) pairs + * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s */ -class PowerIterationClusteringModel( - val k: Int, - val assignments: RDD[(Long, Int)]) extends Serializable +@Since("1.3.0") +@Experimental +class PowerIterationClusteringModel @Since("1.3.0") ( + @Since("1.3.0") val k: Int, + @Since("1.3.0") val assignments: RDD[PowerIterationClustering.Assignment]) + extends Saveable with Serializable { + + @Since("1.4.0") + override def save(sc: SparkContext, path: String): Unit = { + PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("1.4.0") +object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + + @Since("1.4.0") + override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + + @Since("1.4.0") + def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataRDD = model.assignments.toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + @Since("1.4.0") + def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val k = (metadata \ "k").extract[Int] + val assignments = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) + + val assignmentsRDD = assignments.map { + case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) + } + + new PowerIterationClusteringModel(k, assignmentsRDD) + } + } +} /** - * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and - * Cohen (see http://www.icml2010.org/papers/387.pdf). From the abstract: PIC finds a very + * :: Experimental :: + * + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise * similarity matrix of the data. * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. * @param initMode Initialization mode. + * + * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ +@Experimental +@Since("1.3.0") class PowerIterationClustering private[clustering] ( private var k: Int, private var maxIterations: Int, @@ -52,14 +129,17 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ - /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + /** + * Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. */ + @Since("1.3.0") def this() = this(k = 2, maxIterations = 100, initMode = "random") /** * Set the number of clusters. */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this @@ -68,6 +148,7 @@ class PowerIterationClustering private[clustering] ( /** * Set maximum number of iterations of the power iteration loop */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -77,6 +158,7 @@ class PowerIterationClustering private[clustering] ( * Set the initialization mode. This can be either "random" to use a random vector * as vertex properties, or "degree" to use normalized sum similarities. Default: random. */ + @Since("1.3.0") def setInitializationMode(mode: String): this.type = { this.initMode = mode match { case "random" | "degree" => mode @@ -85,17 +167,41 @@ class PowerIterationClustering private[clustering] ( this } + /** + * Run the PIC algorithm on Graph. + * + * @param graph an affinity matrix represented as graph, which is the matrix A in the PIC paper. + * The similarity s,,ij,, represented as the edge between vertices (i, j) must + * be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For + * any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,) + * or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. + * + * @return a [[PowerIterationClusteringModel]] that contains the clustering result + */ + @Since("1.5.0") + def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { + val w = normalize(graph) + val w0 = initMode match { + case "random" => randomInit(w) + case "degree" => initDegreeVector(w) + } + pic(w0) + } + /** * Run the PIC algorithm. * - * @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is - * the matrix A in the PIC paper. The similarity s_ij_ must be nonnegative. - * This is a symmetric matrix and hence s_ij_ = s_ji_. For any (i, j) with - * nonzero similarity, there should be either (i, j, s_ij_) or (j, i, s_ji_) - * in the input. Tuples with i = j are ignored, because we assume s_ij_ = 0.0. + * @param similarities an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix, which is + * the matrix A in the PIC paper. The similarity s,,ij,, must be nonnegative. + * This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For any (i, j) with + * nonzero similarity, there should be either (i, j, s,,ij,,) or + * (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. * * @return a [[PowerIterationClusteringModel]] that contains the clustering result */ + @Since("1.3.0") def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { val w = normalize(similarities) val w0 = initMode match { @@ -105,25 +211,78 @@ class PowerIterationClustering private[clustering] ( pic(w0) } + /** + * A Java-friendly version of [[PowerIterationClustering.run]]. + */ + @Since("1.3.0") + def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) + : PowerIterationClusteringModel = { + run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) + } + /** * Runs the PIC algorithm. * * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with - * w_ij_ = a_ij_ / d_ii_ as its edge properties and the initial vector of the power + * w,,ij,, = a,,ij,, / d,,ii,, as its edge properties and the initial vector of the power * iteration as its vertex properties. */ private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = { val v = powerIter(w, maxIterations) - val assignments = kMeans(v, k) + val assignments = kMeans(v, k).mapPartitions({ iter => + iter.map { case (id, cluster) => + Assignment(id, cluster) + } + }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) } } -private[clustering] object PowerIterationClustering extends Logging { +@Since("1.3.0") +@Experimental +object PowerIterationClustering extends Logging { + + /** + * :: Experimental :: + * Cluster assignment. + * @param id node id + * @param cluster assigned cluster id + */ + @Since("1.3.0") + @Experimental + case class Assignment(id: Long, cluster: Int) + + /** + * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W). + */ + private[clustering] + def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = { + val vD = graph.aggregateMessages[Double]( + sendMsg = ctx => { + val i = ctx.srcId + val j = ctx.dstId + val s = ctx.attr + if (s < 0.0) { + throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + } + if (s > 0.0) { + ctx.sendToSrc(s) + } + }, + mergeMsg = _ + _, + TripletFields.EdgeOnly) + GraphImpl.fromExistingRDDs(vD, graph.edges) + .mapTriplets( + e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), + TripletFields.Src) + } + /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). */ - def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = { + private[clustering] + def normalize(similarities: RDD[(Long, Long, Double)]) + : Graph[Double, Double] = { val edges = similarities.flatMap { case (i, j, s) => if (s < 0.0) { throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") @@ -149,11 +308,12 @@ private[clustering] object PowerIterationClustering extends Logging { /** * Generates random vertex properties (v0) to start power iteration. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm */ + private[clustering] def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = { val r = g.vertices.mapPartitionsWithIndex( (part, iter) => { @@ -171,16 +331,17 @@ private[clustering] object PowerIterationClustering extends Logging { * Generates the degree vector as the vertex properties (v0) to start power iteration. * It is not exactly the node degrees but just the normalized sum similarities. Call it * as degree vector because it is used in the PIC paper. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ + private[clustering] def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { val sum = g.vertices.values.sum() val v0 = g.vertices.mapValues(_ / sum) GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } - + /** * Runs power iteration. * @param g input graph with edges representing the normalized affinity matrix (W) and vertices @@ -188,6 +349,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param maxIterations maximum number of iterations * @return a [[VertexRDD]] representing the pseudo-eigenvector */ + private[clustering] def powerIter( g: Graph[Double, Double], maxIterations: Int): VertexRDD[Double] = { @@ -227,6 +389,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param k number of clusters * @return a [[VertexRDD]] representing the clustering assignments */ + private[clustering] def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { val points = v.mapValues(x => Vectors.dense(x)).cache() val model = new KMeans() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 7752c1988fdd..1d50ffec96fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,16 +20,18 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeansModel extends MLlib's KMeansModel for streaming * algorithms, so it can keep track of a continuously updated weight * associated with each cluster, and also update the model by @@ -39,8 +41,10 @@ import org.apache.spark.util.random.XORShiftRandom * generalized to incorporate forgetfullness (i.e. decay). * The update rule (for each cluster) is: * + * {{{ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] * n_t+t = n_t * a + m_t + * }}} * * Where c_t is the previously estimated centroid for that cluster, * n_t is the number of points assigned to it thus far, x_t is the centroid @@ -59,14 +63,18 @@ import org.apache.spark.util.random.XORShiftRandom * such that at time t + h the discount applied to the data from t is 0.5. * The definition remains the same whether the time unit is given * as batches or points. - * */ -@DeveloperApi -class StreamingKMeansModel( - override val clusterCenters: Array[Vector], - val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { +@Since("1.2.0") +@Experimental +class StreamingKMeansModel @Since("1.2.0") ( + @Since("1.2.0") override val clusterCenters: Array[Vector], + @Since("1.2.0") val clusterWeights: Array[Double]) + extends KMeansModel(clusterCenters) with Logging { - /** Perform a k-means update on a batch of data. */ + /** + * Perform a k-means update on a batch of data. + */ + @Since("1.2.0") def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { // find nearest cluster to each point @@ -78,6 +86,7 @@ class StreamingKMeansModel( (p1._1, p1._2 + p2._2) } val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) .collect() @@ -140,7 +149,8 @@ class StreamingKMeansModel( } /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeans provides methods for configuring a * streaming k-means analysis, training the model on streaming, * and using the model to make predictions on streaming data. @@ -149,35 +159,48 @@ class StreamingKMeansModel( * Use a builder pattern to construct a streaming k-means analysis * in an application, like: * + * {{{ * val model = new StreamingKMeans() * .setDecayFactor(0.5) * .setK(3) * .setRandomCenters(5, 100.0) * .trainOn(DStream) + * }}} */ -@DeveloperApi -class StreamingKMeans( - var k: Int, - var decayFactor: Double, - var timeUnit: String) extends Logging { +@Since("1.2.0") +@Experimental +class StreamingKMeans @Since("1.2.0") ( + @Since("1.2.0") var k: Int, + @Since("1.2.0") var decayFactor: Double, + @Since("1.2.0") var timeUnit: String) extends Logging with Serializable { + @Since("1.2.0") def this() = this(2, 1.0, StreamingKMeans.BATCHES) protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) - /** Set the number of clusters. */ + /** + * Set the number of clusters. + */ + @Since("1.2.0") def setK(k: Int): this.type = { this.k = k this } - /** Set the decay factor directly (for forgetful algorithms). */ + /** + * Set the decay factor directly (for forgetful algorithms). + */ + @Since("1.2.0") def setDecayFactor(a: Double): this.type = { - this.decayFactor = decayFactor + this.decayFactor = a this } - /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + /** + * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + */ + @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) @@ -188,7 +211,10 @@ class StreamingKMeans( this } - /** Specify initial centers directly. */ + /** + * Specify initial centers directly. + */ + @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { model = new StreamingKMeansModel(centers, weights) this @@ -201,6 +227,7 @@ class StreamingKMeans( * @param weight Weight for each center * @param seed Random seed */ + @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) @@ -209,7 +236,10 @@ class StreamingKMeans( this } - /** Return the latest model. */ + /** + * Return the latest model. + */ + @Since("1.2.0") def latestModel(): StreamingKMeansModel = { model } @@ -222,6 +252,7 @@ class StreamingKMeans( * * @param data DStream containing vector data */ + @Since("1.2.0") def trainOn(data: DStream[Vector]) { assertInitialized() data.foreachRDD { (rdd, time) => @@ -229,17 +260,32 @@ class StreamingKMeans( } } + /** + * Java-friendly version of `trainOn`. + */ + @Since("1.4.0") + def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) + /** * Use the clustering model to make predictions on batches of data from a DStream. * * @param data DStream containing vector data * @return DStream containing predictions */ + @Since("1.2.0") def predictOn(data: DStream[Vector]): DStream[Int] = { assertInitialized() data.map(model.predict) } + /** + * Java-friendly version of `predictOn`. + */ + @Since("1.4.0") + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) + } + /** * Use the model to make predictions on the values of a DStream and carry over its keys. * @@ -247,11 +293,23 @@ class StreamingKMeans( * @tparam K key type * @return DStream containing the input keys and the predictions as values */ + @Since("1.2.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { assertInitialized() data.mapValues(model.predict) } + /** + * Java-friendly version of `predictOnValues`. + */ + @Since("1.4.0") + def predictOnValues[K]( + data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]]) + } + /** Check whether cluster centers have been initialized. */ private[this] def assertInitialized(): Unit = { if (model.clusterCenters == null) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index ced042e2f96c..508fe532b130 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -17,11 +17,12 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.binary._ import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.sql.DataFrame /** * :: Experimental :: @@ -41,24 +42,39 @@ import org.apache.spark.rdd.{RDD, UnionRDD} * be smaller as a result, meaning there may be an extra sample at * partition boundaries. */ +@Since("1.0.0") @Experimental -class BinaryClassificationMetrics( - val scoreAndLabels: RDD[(Double, Double)], - val numBins: Int) extends Logging { +class BinaryClassificationMetrics @Since("1.3.0") ( + @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)], + @Since("1.3.0") val numBins: Int) extends Logging { require(numBins >= 0, "numBins must be nonnegative") /** * Defaults `numBins` to 0. */ + @Since("1.0.0") def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) - /** Unpersist intermediate RDDs used in the computation. */ + /** + * An auxiliary constructor taking a DataFrame. + * @param scoreAndLabels a DataFrame with two double columns: score and label + */ + private[mllib] def this(scoreAndLabels: DataFrame) = + this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + + /** + * Unpersist intermediate RDDs used in the computation. + */ + @Since("1.0.0") def unpersist() { cumulativeCounts.unpersist() } - /** Returns thresholds in descending order. */ + /** + * Returns thresholds in descending order. + */ + @Since("1.0.0") def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -67,6 +83,7 @@ class BinaryClassificationMetrics( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.0.0") def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) val sc = confusions.context @@ -78,6 +95,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. */ + @Since("1.0.0") def areaUnderROC(): Double = AreaUnderCurve.of(roc()) /** @@ -85,6 +103,7 @@ class BinaryClassificationMetrics( * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall */ + @Since("1.0.0") def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) val sc = confusions.context @@ -95,6 +114,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. */ + @Since("1.0.0") def areaUnderPR(): Double = AreaUnderCurve.of(pr()) /** @@ -103,15 +123,25 @@ class BinaryClassificationMetrics( * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score */ + @Since("1.0.0") def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) - /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + /** + * Returns the (threshold, F-Measure) curve with beta = 1.0. + */ + @Since("1.0.0") def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) - /** Returns the (threshold, precision) curve. */ + /** + * Returns the (threshold, precision) curve. + */ + @Since("1.0.0") def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) - /** Returns the (threshold, recall) curve. */ + /** + * Returns the (threshold, recall) curve. + */ + @Since("1.0.0") def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 666362ae6739..00e837661dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -20,9 +20,10 @@ package org.apache.spark.mllib.evaluation import scala.collection.Map import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Matrices, Matrix} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame /** * ::Experimental:: @@ -30,8 +31,16 @@ import org.apache.spark.rdd.RDD * * @param predictionAndLabels an RDD of (prediction, label) pairs. */ +@Since("1.1.0") @Experimental -class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { + + /** + * An auxiliary constructor taking a DataFrame. + * @param predictionAndLabels a DataFrame with two double columns: prediction and label + */ + private[mllib] def this(predictionAndLabels: DataFrame) = + this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() private lazy val labelCount: Long = labelCountByClass.values.sum @@ -57,6 +66,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * they are ordered by class label ascending, * as in "labels" */ + @Since("1.1.0") def confusionMatrix: Matrix = { val n = labels.size val values = Array.ofDim[Double](n * n) @@ -76,12 +86,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns true positive rate for a given label (category) * @param label the label. */ + @Since("1.1.0") def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. */ + @Since("1.1.0") def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) fp.toDouble / (labelCount - labelCountByClass(label)) @@ -91,6 +103,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns precision for a given label (category) * @param label the label. */ + @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) val fp = fpByClass.getOrElse(label, 0) @@ -101,6 +114,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns recall for a given label (category) * @param label the label. */ + @Since("1.1.0") def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) /** @@ -108,6 +122,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * @param label the label. * @param beta the beta parameter. */ + @Since("1.1.0") def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) val r = recall(label) @@ -119,11 +134,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns f1-measure for a given label (category) * @param label the label. */ + @Since("1.1.0") def fMeasure(label: Double): Double = fMeasure(label, 1.0) /** * Returns precision */ + @Since("1.1.0") lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount /** @@ -132,23 +149,27 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * because sum of all false positives is equal to sum * of all false negatives) */ + @Since("1.1.0") lazy val recall: Double = precision /** * Returns f-measure * (equals to precision and recall because precision equals recall) */ + @Since("1.1.0") lazy val fMeasure: Double = precision /** * Returns weighted true positive rate * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedTruePositiveRate: Double = weightedRecall /** * Returns weighted false positive rate */ + @Since("1.1.0") lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) => falsePositiveRate(category) * count.toDouble / labelCount }.sum @@ -157,6 +178,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged recall * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => recall(category) * count.toDouble / labelCount }.sum @@ -164,6 +186,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged precision */ + @Since("1.1.0") lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => precision(category) * count.toDouble / labelCount }.sum @@ -172,6 +195,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged f-measure * @param beta the beta parameter. */ + @Since("1.1.0") def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount }.sum @@ -179,6 +203,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f1-measure */ + @Since("1.1.0") lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => fMeasure(category, 1.0) * count.toDouble / labelCount }.sum @@ -186,5 +211,6 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns the sequence of labels in ascending order */ + @Since("1.1.0") lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index ea10bde5fa25..c100b3c9ec14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -17,15 +17,25 @@ package org.apache.spark.mllib.evaluation +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ +import org.apache.spark.sql.DataFrame /** * Evaluator for multilabel classification. * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. */ -class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { +@Since("1.2.0") +class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double], Array[Double])]) { + + /** + * An auxiliary constructor taking a DataFrame. + * @param predictionAndLabels a DataFrame with two double array columns: prediction and label + */ + private[mllib] def this(predictionAndLabels: DataFrame) = + this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) private lazy val numDocs: Long = predictionAndLabels.count() @@ -36,6 +46,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns subset accuracy * (for equal sets of labels) */ + @Since("1.2.0") lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => predictions.deep == labels.deep }.count().toDouble / numDocs @@ -43,6 +54,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns accuracy */ + @Since("1.2.0") lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs @@ -51,6 +63,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns Hamming-loss */ + @Since("1.2.0") lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => labels.size + predictions.size - 2 * labels.intersect(predictions).size }.sum / (numDocs * numLabels) @@ -58,6 +71,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based precision averaged by the number of documents */ + @Since("1.2.0") lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => if (predictions.size > 0) { predictions.intersect(labels).size.toDouble / predictions.size @@ -69,6 +83,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based recall averaged by the number of documents */ + @Since("1.2.0") lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / labels.size }.sum / numDocs @@ -76,6 +91,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based f1-measure averaged by the number of documents */ + @Since("1.2.0") lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) }.sum / numDocs @@ -96,30 +112,33 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns precision for a given label (category) * @param label the label. */ - def precision(label: Double) = { + @Since("1.2.0") + def precision(label: Double): Double = { val tp = tpPerClass(label) val fp = fpPerClass.getOrElse(label, 0L) - if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp) } /** * Returns recall for a given label (category) * @param label the label. */ - def recall(label: Double) = { + @Since("1.2.0") + def recall(label: Double): Double = { val tp = tpPerClass(label) val fn = fnPerClass.getOrElse(label, 0L) - if (tp + fn == 0) 0 else tp.toDouble / (tp + fn) + if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn) } /** * Returns f1-measure for a given label (category) * @param label the label. */ - def f1Measure(label: Double) = { + @Since("1.2.0") + def f1Measure(label: Double): Double = { val p = precision(label) val r = recall(label) - if((p + r) == 0) 0 else 2 * p * r / (p + r) + if((p + r) == 0) 0.0 else 2 * p * r / (p + r) } private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp } @@ -130,7 +149,8 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based precision * (equals to micro-averaged document-based precision) */ - lazy val microPrecision = { + @Since("1.2.0") + lazy val microPrecision: Double = { val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) } @@ -139,7 +159,8 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based recall * (equals to micro-averaged document-based recall) */ - lazy val microRecall = { + @Since("1.2.0") + lazy val microRecall: Double = { val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) } @@ -148,10 +169,12 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based f1-measure * (equals to micro-averaged document-based f1-measure) */ - lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) + @Since("1.2.0") + lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) /** * Returns the sequence of labels in ascending order */ + @Since("1.2.0") lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 93a7353e2c07..a7f43f0b110f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -17,19 +17,25 @@ package org.apache.spark.mllib.evaluation +import java.{lang => jl} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD /** * ::Experimental:: * Evaluator for ranking algorithms. * + * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. + * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. */ +@Since("1.2.0") @Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) extends Logging with Serializable { @@ -51,6 +57,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions */ + @Since("1.2.0") def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -71,7 +78,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] logWarning("Empty ground truth set, check input data") 0.0 } - }.mean + }.mean() } /** @@ -100,7 +107,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] logWarning("Empty ground truth set, check input data") 0.0 } - }.mean + }.mean() } /** @@ -120,6 +127,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions */ + @Since("1.2.0") def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -146,7 +154,24 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] logWarning("Empty ground truth set, check input data") 0.0 } - }.mean + }.mean() } } + +@Experimental +object RankingMetrics { + + /** + * Creates a [[RankingMetrics]] instance (for Java users). + * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs + */ + @Since("1.4.0") + def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { + implicit val tag = JavaSparkContext.fakeClassTag[E] + val rdd = predictionAndLabels.rdd.map { case (predictions, labels) => + (predictions.asScala.toArray, labels.asScala.toArray) + } + new RankingMetrics(rdd) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 693117d82058..799ebb980ef0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -17,11 +17,12 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.sql.DataFrame /** * :: Experimental :: @@ -29,8 +30,18 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Multivariate * * @param predictionAndObservations an RDD of (prediction, observation) pairs. */ +@Since("1.2.0") @Experimental -class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("1.2.0") ( + predictionAndObservations: RDD[(Double, Double)]) extends Logging { + + /** + * An auxiliary constructor taking a DataFrame. + * @param predictionAndObservations a DataFrame with two double columns: + * prediction and observation + */ + private[mllib] def this(predictionAndObservations: DataFrame) = + this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. @@ -44,20 +55,30 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend ) summary } + private lazy val SSerr = math.pow(summary.normL2(1), 2) + private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SSreg = { + val yMean = summary.mean(0) + predictionAndObservations.map { + case (prediction, _) => math.pow(prediction - yMean, 2) + }.sum() + } /** - * Returns the explained variance regression score. - * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) - * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * Returns the variance explained by regression. + * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n + * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ + @Since("1.2.0") def explainedVariance: Double = { - 1 - summary.variance(1) / summary.variance(0) + SSreg / summary.count } /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. */ + @Since("1.2.0") def meanAbsoluteError: Double = { summary.normL1(1) / summary.count } @@ -66,24 +87,26 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. */ + @Since("1.2.0") def meanSquaredError: Double = { - val rmse = summary.normL2(1) / math.sqrt(summary.count) - rmse * rmse + SSerr / summary.count } /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. */ + @Since("1.2.0") def rootMeanSquaredError: Double = { - summary.normL2(1) / math.sqrt(summary.count) + math.sqrt(this.meanSquaredError) } /** - * Returns R^2^, the coefficient of determination. - * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * Returns R^2^, the unadjusted coefficient of determination. + * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ + @Since("1.2.0") def r2: Double = { - 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + 1 - SSerr / SStot } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c6057c7f837b..4743cfd1a2c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics @@ -31,14 +31,17 @@ import org.apache.spark.rdd.RDD * * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ +@Since("1.3.0") @Experimental -class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { +class ChiSqSelectorModel @Since("1.3.0") ( + @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { require(isSorted(selectedFeatures), "Array has to be sorted asc") protected def isSorted(array: Array[Int]): Boolean = { var i = 1 - while (i < array.length) { + val len = array.length + while (i < len) { if (array(i) < array(i-1)) return false i += 1 } @@ -51,6 +54,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.3.0") override def transform(vector: Vector): Vector = { compress(vector, selectedFeatures) } @@ -106,8 +110,10 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) */ +@Since("1.3.0") @Experimental -class ChiSqSelector (val numTopFeatures: Int) { +class ChiSqSelector @Since("1.3.0") ( + @Since("1.3.0") val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. @@ -116,6 +122,7 @@ class ChiSqSelector (val numTopFeatures: Int) { * Real-valued features will be treated as categorical for each distinct value. * Apply feature discretizer before using this function. */ + @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val indices = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala new file mode 100644 index 000000000000..d0a6cf61687a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.linalg._ + +/** + * :: Experimental :: + * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a + * provided "weight" vector. In other words, it scales each column of the dataset by a scalar + * multiplier. + * @param scalingVec The values used to scale the reference vector's individual components. + */ +@Since("1.4.0") +@Experimental +class ElementwiseProduct @Since("1.4.0") ( + @Since("1.4.0") val scalingVec: Vector) extends VectorTransformer { + + /** + * Does the hadamard product transformation. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + @Since("1.4.0") + override def transform(vector: Vector): Vector = { + require(vector.size == scalingVec.size, + s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") + vector match { + case dv: DenseVector => + val values: Array[Double] = dv.values.clone() + val dim = scalingVec.size + var i = 0 + while (i < dim) { + values(i) *= scalingVec(i) + i += 1 + } + Vectors.dense(values) + case SparseVector(size, indices, vs) => + val values = vs.clone() + val dim = values.length + var i = 0 + while (i < dim) { + values(i) *= scalingVec(indices(i)) + i += 1 + } + Vectors.sparse(size, indices, values) + case v => throw new IllegalArgumentException("Does not support vector type " + v.getClass) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index c53475818395..e47d524b6162 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -22,7 +22,7 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -34,19 +34,25 @@ import org.apache.spark.util.Utils * * @param numFeatures number of features (default: 2^20^) */ +@Since("1.1.0") @Experimental class HashingTF(val numFeatures: Int) extends Serializable { + /** + */ + @Since("1.1.0") def this() = this(1 << 20) /** * Returns the index of the input term. */ + @Since("1.1.0") def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) /** * Transforms the input document into a sparse term frequency vector. */ + @Since("1.1.0") def transform(document: Iterable[_]): Vector = { val termFrequencies = mutable.HashMap.empty[Int, Double] document.foreach { term => @@ -59,6 +65,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document into a sparse term frequency vector (Java version). */ + @Since("1.1.0") def transform(document: JavaIterable[_]): Vector = { transform(document.asScala) } @@ -66,6 +73,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors. */ + @Since("1.1.0") def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { dataset.map(this.transform) } @@ -73,6 +81,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors (Java version). */ + @Since("1.1.0") def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { dataset.rdd.map(this.transform).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index a89eea0e21be..68078ccfa3d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -37,9 +37,11 @@ import org.apache.spark.rdd.RDD * @param minDocFreq minimum of documents in which a term * should appear for filtering */ +@Since("1.1.0") @Experimental -class IDF(val minDocFreq: Int) { +class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) { + @Since("1.1.0") def this() = this(0) // TODO: Allow different IDF formulations. @@ -48,6 +50,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: RDD[Vector]): IDFModel = { val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator( minDocFreq = minDocFreq))( @@ -61,6 +64,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } @@ -144,7 +148,7 @@ private object IDF { * Since arrays are initialized to 0 by default, * we just omit changing those entries. */ - if(df(j) >= minDocFreq) { + if (df(j) >= minDocFreq) { inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) } j += 1 @@ -159,7 +163,8 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[mllib] (val idf: Vector) extends Serializable { +@Since("1.1.0") +class IDFModel private[spark] (@Since("1.1.0") val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -171,6 +176,7 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable { * @param dataset an RDD of term frequency vectors * @return an RDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: RDD[Vector]): RDD[Vector] = { val bcIdf = dataset.context.broadcast(idf) dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v))) @@ -182,6 +188,7 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable { * @param v a term frequency vector * @return a TF-IDF vector */ + @Since("1.3.0") def transform(v: Vector): Vector = IDFModel.transform(idf, v) /** @@ -189,6 +196,7 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable { * @param dataset a JavaRDD of term frequency vectors * @return a JavaRDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { transform(dataset.rdd).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 32848e039eb8..8d5a22520d6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** @@ -31,9 +31,11 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors * * @param p Normalization in L^p^ space, p = 2 by default. */ +@Since("1.1.0") @Experimental -class Normalizer(p: Double) extends VectorTransformer { +class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { + @Since("1.1.0") def this() = this(2) require(p >= 1.0) @@ -44,6 +46,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @param vector vector to be normalized. * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { val norm = Vectors.norm(vector, p) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala new file mode 100644 index 000000000000..ecb3c1e6c1c8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.rdd.RDD + +/** + * A feature transformer that projects vectors to a low-dimensional space using PCA. + * + * @param k number of principal components + */ +@Since("1.4.0") +class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { + require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") + + /** + * Computes a [[PCAModel]] that contains the principal components of the input vectors. + * + * @param sources source vectors + */ + @Since("1.4.0") + def fit(sources: RDD[Vector]): PCAModel = { + require(k <= sources.first().size, + s"source vector size is ${sources.first().size} must be greater than k=$k") + + val mat = new RowMatrix(sources) + val pc = mat.computePrincipalComponents(k) match { + case dm: DenseMatrix => + dm + case sm: SparseMatrix => + /* Convert a sparse matrix to dense. + * + * RowMatrix.computePrincipalComponents always returns a dense matrix. + * The following code is a safeguard. + */ + sm.toDense + case m => + throw new IllegalArgumentException("Unsupported matrix format. Expected " + + s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") + + } + new PCAModel(k, pc) + } + + /** + * Java-friendly version of [[fit()]] + */ + @Since("1.4.0") + def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd) +} + +/** + * Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + * + * @param k number of principal components. + * @param pc a principal components Matrix. Each column is one principal component. + */ +@Since("1.4.0") +class PCAModel private[spark] ( + @Since("1.4.0") val k: Int, + @Since("1.4.0") val pc: DenseMatrix) extends VectorTransformer { + /** + * Transform a vector by computed Principal Components. + * + * @param vector vector to be transformed. + * Vector must be the same length as the source vectors given to [[PCA.fit()]]. + * @return transformed vector. Vector will be of length k. + */ + @Since("1.4.0") + override def transform(vector: Vector): Vector = { + vector match { + case dv: DenseVector => + pc.transpose.multiply(dv) + case SparseVector(size, indices, values) => + /* SparseVector -> single row SparseMatrix */ + val sm = Matrices.sparse(size, 1, Array(0, indices.length), indices, values).transpose + val projection = sm.multiply(pc) + Vectors.dense(projection.values) + case _ => + throw new IllegalArgumentException("Unsupported vector format. Expected " + + s"SparseVector or DenseVector. Instead got: ${vector.getClass}") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 6ae6917eae59..f018b453bae7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -32,9 +32,11 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ +@Since("1.1.0") @Experimental -class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { +class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) extends Logging { + @Since("1.1.0") def this() = this(false, true) if (!(withMean || withStd)) { @@ -47,6 +49,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param data The data used to compute the mean and variance to build the transformation model. * @return a StandardScalarModel */ + @Since("1.1.0") def fit(data: RDD[Vector]): StandardScalerModel = { // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( @@ -69,13 +72,17 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param withStd whether to scale the data to have unit standard deviation * @param withMean whether to center the data before scaling */ +@Since("1.1.0") @Experimental -class StandardScalerModel ( - val std: Vector, - val mean: Vector, - var withStd: Boolean, - var withMean: Boolean) extends VectorTransformer { +class StandardScalerModel @Since("1.3.0") ( + @Since("1.3.0") val std: Vector, + @Since("1.1.0") val mean: Vector, + @Since("1.3.0") var withStd: Boolean, + @Since("1.3.0") var withMean: Boolean) extends VectorTransformer { + /** + */ + @Since("1.3.0") def this(std: Vector, mean: Vector) { this(std, mean, withStd = std != null, withMean = mean != null) require(this.withStd || this.withMean, @@ -86,15 +93,18 @@ class StandardScalerModel ( } } + @Since("1.3.0") def this(std: Vector) = this(std, null) + @Since("1.3.0") @DeveloperApi def setWithMean(withMean: Boolean): this.type = { - require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") this.withMean = withMean this } + @Since("1.3.0") @DeveloperApi def setWithStd(withStd: Boolean): this.type = { require(!(withStd && this.std == null), @@ -115,6 +125,7 @@ class StandardScalerModel ( * @return Standardized vector. If the std of a column is zero, it will return default `0.0` * for the column with zero std. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 7358c1c84f79..5778fd1d0925 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD * :: DeveloperApi :: * Trait for transformation of a vector */ +@Since("1.1.0") @DeveloperApi trait VectorTransformer extends Serializable { @@ -35,6 +36,7 @@ trait VectorTransformer extends Serializable { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.1.0") def transform(vector: Vector): Vector /** @@ -43,6 +45,7 @@ trait VectorTransformer extends Serializable { * @param data RDD[Vector] to be transformed. * @return transformed RDD[Vector]. */ + @Since("1.1.0") def transform(data: RDD[Vector]): RDD[Vector] = { // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. // So it should be no longer necessary to explicitly broadcast `this` object. @@ -55,6 +58,7 @@ trait VectorTransformer extends Serializable { * @param data JavaRDD[Vector] to be transformed. * @return transformed JavaRDD[Vector]. */ + @Since("1.1.0") def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = { transform(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a3e40200bc06..36b124c5d296 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -21,48 +21,56 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.sql.{SQLContext, Row} /** - * Entry in vocabulary + * Entry in vocabulary */ private case class VocabWord( var word: String, var cn: Int, var point: Array[Int], var code: Array[Int], - var codeLen:Int + var codeLen: Int ) /** * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus - * and then learns vector representation of words in the vocabulary. - * The vector representation can be used as features in + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in * natural language processing and machine learning algorithms. - * - * We used skip-gram model in our implementation and hierarchical softmax + * + * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation * matches the original C implementation. * - * For original C implementation, see https://code.google.com/p/word2vec/ - * For research papers, see + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see * Efficient Estimation of Word Representations in Vector Space - * and + * and * Distributed Representations of Words and Phrases and their Compositionality. */ +@Since("1.1.0") @Experimental class Word2Vec extends Serializable with Logging { @@ -72,10 +80,11 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 - + /** * Sets vector size (default: 100). */ + @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { this.vectorSize = vectorSize this @@ -84,6 +93,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets initial learning rate (default: 0.025). */ + @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { this.learningRate = learningRate this @@ -92,6 +102,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets number of partitions (default: 1). Use a small number for accuracy. */ + @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") this.numPartitions = numPartitions @@ -102,6 +113,7 @@ class Word2Vec extends Serializable with Logging { * Sets number of iterations (default: 1), which should be smaller than or equal to number of * partitions. */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.numIterations = numIterations this @@ -110,20 +122,22 @@ class Word2Vec extends Serializable with Logging { /** * Sets random seed (default: a random long integer). */ + @Since("1.1.0") def setSeed(seed: Long): this.type = { this.seed = seed this } - /** - * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ + @Since("1.3.0") def setMinCount(minCount: Int): this.type = { this.minCount = minCount this } - + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -143,14 +157,17 @@ class Word2Vec extends Serializable with Logging { .map(x => VocabWord( x._1, x._2, - new Array[Int](MAX_CODE_LENGTH), - new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) - + vocabSize = vocab.length + require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " + + "the setting of minCount, which could be large enough to remove all your words in sentences.") + var a = 0 while (a < vocabSize) { vocabHash += vocab(a).word -> a @@ -188,8 +205,8 @@ class Word2Vec extends Serializable with Logging { } var pos1 = vocabSize - 1 var pos2 = vocabSize - - var min1i = 0 + + var min1i = 0 var min2i = 0 a = 0 @@ -253,26 +270,27 @@ class Word2Vec extends Serializable with Logging { * @param dataset an RDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { val words = dataset.flatMap(x => x) learnVocab(words) - + createBinaryTree() - + val sc = dataset.context val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => new Iterator[Array[Int]] { def hasNext: Boolean = iter.hasNext def next(): Array[Int] = { - var sentence = new ArrayBuffer[Int] + val sentence = ArrayBuilder.make[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { val word = bcVocabHash.value.get(iter.next()) @@ -283,11 +301,11 @@ class Word2Vec extends Serializable with Logging { case None => } } - sentence.toArray + sentence.result() } } } - + val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) @@ -392,18 +410,9 @@ class Word2Vec extends Serializable with Logging { } } newSentences.unpersist() - - val word2VecMap = mutable.HashMap.empty[String, Array[Float]] - var i = 0 - while (i < vocabSize) { - val word = bcVocab.value(i).word - val vector = new Array[Float](vectorSize) - Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) - word2VecMap += word -> vector - i += 1 - } - new Word2VecModel(word2VecMap.toMap) + val wordArray = vocab.map(_.word) + new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) } /** @@ -411,6 +420,7 @@ class Word2Vec extends Serializable with Logging { * @param dataset a JavaRDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { fit(dataset.rdd.map(_.asScala)) } @@ -419,10 +429,45 @@ class Word2Vec extends Serializable with Logging { /** * :: Experimental :: * Word2Vec model + * @param wordIndex maps each word to an index, which can retrieve the corresponding + * vector from wordVectors + * @param wordVectors array of length numWords * vectorSize, vector corresponding + * to the word mapped with index i can be retrieved by the slice + * (i * vectorSize, i * vectorSize + vectorSize) */ @Experimental +@Since("1.1.0") class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable { + private val wordIndex: Map[String, Int], + private val wordVectors: Array[Float]) extends Serializable with Saveable { + + private val numWords = wordIndex.size + // vectorSize: Dimension of each word's vector. + private val vectorSize = wordVectors.length / numWords + + // wordList: Ordered list of words obtained from wordIndex. + private val wordList: Array[String] = { + val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip + wl.toArray + } + + // wordVecNorms: Array of length numWords, each value being the Euclidean norm + // of the wordVector. + private val wordVecNorms: Array[Double] = { + val wordVecNorms = new Array[Double](numWords) + var i = 0 + while (i < numWords) { + val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) + wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) + i += 1 + } + wordVecNorms + } + + @Since("1.5.0") + def this(model: Map[String, Array[Float]]) = { + this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) + } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -430,17 +475,26 @@ class Word2VecModel private[mllib] ( val norm1 = blas.snrm2(n, v1, 1) val norm2 = blas.snrm2(n, v2, 1) if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 + } + + override protected def formatVersion = "1.0" + + @Since("1.4.0") + def save(sc: SparkContext, path: String): Unit = { + Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) } - + /** * Transforms a word to its vector representation - * @param word a word + * @param word a word * @return vector representation of word */ + @Since("1.1.0") def transform(word: String): Vector = { - model.get(word) match { - case Some(vec) => + wordIndex.get(word) match { + case Some(ind) => + val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize) Vectors.dense(vec.map(_.toDouble)) case None => throw new IllegalStateException(s"$word not in vocabulary") @@ -450,36 +504,142 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of a word * @param word a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector,num) + findSynonyms(vector, num) } /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) - model.mapValues(vec => cosineSimilarity(fVector, vec)) + val cosineVec = Array.fill[Float](numWords)(0) + val alpha: Float = 1 + val beta: Float = 0 + + blas.sgemv( + "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) + + // Need not divide with the norm of the given vector since it is constant. + val cosVec = cosineVec.map(_.toDouble) + var ind = 0 + while (ind < numWords) { + cosVec(ind) /= wordVecNorms(ind) + ind += 1 + } + wordList.zip(cosVec) .toSeq .sortBy(- _._2) .take(num + 1) .tail .toArray } - + /** * Returns a map of words to their vector representations. */ + @Since("1.2.0") def getVectors: Map[String, Array[Float]] = { - model + wordIndex.map { case (word, ind) => + (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) + } + } +} + +@Since("1.4.0") +@Experimental +object Word2VecModel extends Loader[Word2VecModel] { + + private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { + model.keys.zipWithIndex.toMap + } + + private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = { + require(model.nonEmpty, "Word2VecMap should be non-empty") + val (vectorSize, numWords) = (model.head._2.size, model.size) + val wordList = model.keys.toArray + val wordVectors = new Array[Float](vectorSize * numWords) + var i = 0 + while (i < numWords) { + Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize) + i += 1 + } + wordVectors + } + + private object SaveLoadV1_0 { + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel" + + case class Data(word: String, vector: Array[Float]) + + def load(sc: SparkContext, path: String): Word2VecModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.read.parquet(dataPath) + + val dataArray = dataFrame.select("word", "vector").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap + new Word2VecModel(word2VecMap) + } + + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val vectorSize = model.values.head.size + val numWords = model.size + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } + sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) + } + } + + @Since("1.4.0") + override def load(sc: SparkContext, path: String): Word2VecModel = { + + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedVectorSize = (metadata \ "vectorSize").extract[Int] + val expectedNumWords = (metadata \ "numWords").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, loadedVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path) + val vectorSize = model.getVectors.values.head.size + val numWords = model.getVectors.size + require(expectedVectorSize == vectorSize, + s"Word2VecModel requires each word to be mapped to a vector of size " + + s"$expectedVectorSize, got vector of size $vectorSize") + require(expectedNumWords == numWords, + s"Word2VecModel requires $expectedNumWords words, but got $numWords") + model + case _ => throw new Exception( + s"Word2VecModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala new file mode 100644 index 000000000000..95c688c86a7e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.AssociationRules.Rule +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * Generates association rules from a [[RDD[FreqItemset[Item]]]. This method only generates + * association rules which have a single item as the consequent. + * + */ +@Since("1.5.0") +@Experimental +class AssociationRules private[fpm] ( + private var minConfidence: Double) extends Logging with Serializable { + + /** + * Constructs a default instance with default parameters {minConfidence = 0.8}. + */ + @Since("1.5.0") + def this() = this(0.8) + + /** + * Sets the minimal confidence (default: `0.8`). + */ + @Since("1.5.0") + def setMinConfidence(minConfidence: Double): this.type = { + require(minConfidence >= 0.0 && minConfidence <= 1.0) + this.minConfidence = minConfidence + this + } + + /** + * Computes the association rules with confidence above [[minConfidence]]. + * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] + * @return a [[Set[Rule[Item]]] containing the assocation rules. + * + */ + @Since("1.5.0") + def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { + // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) + val candidates = freqItemsets.flatMap { itemset => + val items = itemset.items + items.flatMap { item => + items.partition(_ == item) match { + case (consequent, antecedent) if !antecedent.isEmpty => + Some((antecedent.toSeq, (consequent.toSeq, itemset.freq))) + case _ => None + } + } + } + + // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence + candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq))) + .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) => + new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent) + }.filter(_.confidence >= minConfidence) + } + + /** Java-friendly version of [[run]]. */ + @Since("1.5.0") + def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { + val tag = fakeClassTag[Item] + run(freqItemsets.rdd)(tag) + } +} + +@Since("1.5.0") +object AssociationRules { + + /** + * :: Experimental :: + * + * An association rule between sets of items. + * @param antecedent hypotheses of the rule. Java users should call [[Rule#javaAntecedent]] + * instead. + * @param consequent conclusion of the rule. Java users should call [[Rule#javaConsequent]] + * instead. + * @tparam Item item type + * + */ + @Since("1.5.0") + @Experimental + class Rule[Item] private[fpm] ( + @Since("1.5.0") val antecedent: Array[Item], + @Since("1.5.0") val consequent: Array[Item], + freqUnion: Double, + freqAntecedent: Double) extends Serializable { + + /** + * Returns the confidence of the rule. + * + */ + @Since("1.5.0") + def confidence: Double = freqUnion.toDouble / freqAntecedent + + require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { + val sharedItems = antecedent.toSet.intersect(consequent.toSet) + s"A valid association rule must have disjoint antecedent and " + + s"consequent but ${sharedItems} is present in both." + }) + + /** + * Returns antecedent in a Java List. + * + */ + @Since("1.5.0") + def javaAntecedent: java.util.List[Item] = { + antecedent.toList.asJava + } + + /** + * Returns consequent in a Java List. + * + */ + @Since("1.5.0") + def javaConsequent: java.util.List[Item] = { + consequent.toList.asJava + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 9591c7966e06..aea5c4f8a8a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -18,39 +18,80 @@ package org.apache.spark.mllib.fpm import java.{util => ju} +import java.lang.{Iterable => JavaIterable} import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag -import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner} +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.FPGrowth._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable +/** + * :: Experimental :: + * + * Model trained by [[FPGrowth]], which holds frequent itemsets. + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @tparam Item item type + * + */ +@Since("1.3.0") +@Experimental +class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + /** + * Generates association rules for the [[Item]]s in [[freqItemsets]]. + * @param confidence minimal confidence of the rules produced + */ + @Since("1.5.0") + def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { + val associationRules = new AssociationRules(confidence) + associationRules.run(freqItemsets) + } +} /** - * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data. - * Parallel FPGrowth (PFP) partitions computation in such a way that each machine executes an - * independent group of mining tasks. More detail of this algorithm can be found at - * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be - * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]] + * :: Experimental :: + * + * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + * [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query + * Recommendation]]. PFP distributes computation in such a way that each worker executes an + * independent group of mining tasks. The FP-Growth algorithm is described in + * [[http://dx.doi.org/10.1145/335191.335372 Han et al., Mining frequent patterns without candidate + * generation]]. * * @param minSupport the minimal support level of the frequent pattern, any pattern appears * more than (minSupport * size-of-the-dataset) times will be output * @param numPartitions number of partitions used by parallel FP-growth + * + * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning + * (Wikipedia)]] + * */ +@Since("1.3.0") +@Experimental class FPGrowth private ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { /** - * Constructs a FPGrowth instance with default parameters: - * {minSupport: 0.3, numPartitions: auto} + * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same + * as the input data}. + * */ + @Since("1.3.0") def this() = this(0.3, -1) /** - * Sets the minimal support level (default: 0.3). + * Sets the minimal support level (default: `0.3`). + * */ + @Since("1.3.0") def setMinSupport(minSupport: Double): this.type = { this.minSupport = minSupport this @@ -58,7 +99,9 @@ class FPGrowth private ( /** * Sets the number of partitions used by parallel FP-growth (default: same as input data). + * */ + @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { this.numPartitions = numPartitions this @@ -68,8 +111,10 @@ class FPGrowth private ( * Computes an FP-Growth model that contains frequent itemsets. * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] + * */ - def run(data: RDD[Array[String]]): FPGrowthModel = { + @Since("1.3.0") + def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } @@ -82,19 +127,26 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } + /** Java-friendly version of [[run]]. */ + @Since("1.3.0") + def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.toArray)) + } + /** * Generates frequent items by filtering the input data using minimal support level. * @param minCount minimum count for frequent itemsets * @param partitioner partitioner used to distribute items * @return array of frequent pattern ordered by their frequencies */ - private def genFreqItems( - data: RDD[Array[String]], + private def genFreqItems[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, - partitioner: Partitioner): Array[String] = { + partitioner: Partitioner): Array[Item] = { data.flatMap { t => val uniq = t.toSet - if (t.length != uniq.size) { + if (t.size != uniq.size) { throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") } t @@ -114,11 +166,11 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return an RDD of (frequent itemset, count) */ - private def genFreqItemsets( - data: RDD[Array[String]], + private def genFreqItemsets[Item: ClassTag]( + data: RDD[Array[Item]], minCount: Long, - freqItems: Array[String], - partitioner: Partitioner): RDD[(Array[String], Long)] = { + freqItems: Array[Item], + partitioner: Partitioner): RDD[FreqItemset[Item]] = { val itemToRank = freqItems.zipWithIndex.toMap data.flatMap { transaction => genCondTransactions(transaction, itemToRank, partitioner) @@ -128,7 +180,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - (ranks.map(i => freqItems(i)).toArray, count) + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) } } @@ -139,9 +191,9 @@ class FPGrowth private ( * @param partitioner partitioner used to distribute transactions * @return a map of (target partition, conditional transaction) */ - private def genCondTransactions( - transaction: Array[String], - itemToRank: Map[String, Int], + private def genCondTransactions[Item: ClassTag]( + transaction: Array[Item], + itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] // Filter the basket by frequent items pattern and sort their ranks. @@ -160,3 +212,34 @@ class FPGrowth private ( output } } + +/** + * :: Experimental :: + * + */ +@Since("1.3.0") +@Experimental +object FPGrowth { + + /** + * Frequent itemset. + * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. + * @param freq frequency + * @tparam Item item type + * + */ + @Since("1.3.0") + class FreqItemset[Item] @Since("1.3.0") ( + @Since("1.3.0") val items: Array[Item], + @Since("1.3.0") val freq: Long) extends Serializable { + + /** + * Returns items in a Java List. + * + */ + @Since("1.3.0") + def javaItems: java.util.List[Item] = { + items.toList.asJava + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala new file mode 100644 index 000000000000..3ea10779a183 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.collection.mutable + +import org.apache.spark.Logging + +/** + * Calculate all patterns of a projected database in local mode. + * + * @param minCount minimal count for a frequent pattern + * @param maxPatternLength max pattern length for a frequent pattern + */ +private[fpm] class LocalPrefixSpan( + val minCount: Long, + val maxPatternLength: Int) extends Logging with Serializable { + import PrefixSpan.Postfix + import LocalPrefixSpan.ReversedPrefix + + /** + * Generates frequent patterns on the input array of postfixes. + * @param postfixes an array of postfixes + * @return an iterator of (frequent pattern, count) + */ + def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = { + genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix, count) => + (prefix.toSequence, count) + } + } + + /** + * Recursively generates frequent patterns. + * @param prefix current prefix + * @param postfixes projected postfixes w.r.t. the prefix + * @return an iterator of (prefix, count) + */ + private def genFreqPatterns( + prefix: ReversedPrefix, + postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = { + if (maxPatternLength == prefix.length || postfixes.length < minCount) { + return Iterator.empty + } + // find frequent items + val counts = mutable.Map.empty[Int, Long].withDefaultValue(0) + postfixes.foreach { postfix => + postfix.genPrefixItems.foreach { case (x, _) => + counts(x) += 1L + } + } + val freqItems = counts.toSeq.filter { case (_, count) => + count >= minCount + }.sorted + // project and recursively call genFreqPatterns + freqItems.toIterator.flatMap { case (item, count) => + val newPrefix = prefix :+ item + Iterator.single((newPrefix, count)) ++ { + val projected = postfixes.map(_.project(item)).filter(_.nonEmpty) + genFreqPatterns(newPrefix, projected) + } + } + } +} + +private object LocalPrefixSpan { + + /** + * Represents a prefix stored as a list in reversed order. + * @param items items in the prefix in reversed order + * @param length length of the prefix, not counting delimiters + */ + class ReversedPrefix private (val items: List[Int], val length: Int) extends Serializable { + /** + * Expands the prefix by one item. + */ + def :+(item: Int): ReversedPrefix = { + require(item != 0) + if (item < 0) { + new ReversedPrefix(-item :: items, length + 1) + } else { + new ReversedPrefix(item :: 0 :: items, length + 1) + } + } + + /** + * Converts this prefix to a sequence. + */ + def toSequence: Array[Int] = (0 :: items).toArray.reverse + } + + object ReversedPrefix { + /** An empty prefix. */ + val empty: ReversedPrefix = new ReversedPrefix(List.empty, 0) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala new file mode 100644 index 000000000000..97916daa2e9a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import java.{lang => jl, util => ju} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * :: Experimental :: + * + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]). + * + * @param minSupport the minimal support level of the sequential pattern, any pattern appears + * more than (minSupport * size-of-the-dataset) times will be output + * @param maxPatternLength the maximal length of the sequential pattern, any pattern appears + * less than maxPatternLength will be output + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal + * storage format) allowed in a projected database before local + * processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run. + * + * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining + * (Wikipedia)]] + */ +@Experimental +@Since("1.5.0") +class PrefixSpan private ( + private var minSupport: Double, + private var maxPatternLength: Int, + private var maxLocalProjDBSize: Long) extends Logging with Serializable { + import PrefixSpan._ + + /** + * Constructs a default instance with default parameters + * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}. + */ + @Since("1.5.0") + def this() = this(0.1, 10, 32000000L) + + /** + * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered + * frequent). + */ + @Since("1.5.0") + def getMinSupport: Double = minSupport + + /** + * Sets the minimal support level (default: `0.1`). + */ + @Since("1.5.0") + def setMinSupport(minSupport: Double): this.type = { + require(minSupport >= 0 && minSupport <= 1, + s"The minimum support value must be in [0, 1], but got $minSupport.") + this.minSupport = minSupport + this + } + + /** + * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. + */ + @Since("1.5.0") + def getMaxPatternLength: Int = maxPatternLength + + /** + * Sets maximal pattern length (default: `10`). + */ + @Since("1.5.0") + def setMaxPatternLength(maxPatternLength: Int): this.type = { + // TODO: support unbounded pattern length when maxPatternLength = 0 + require(maxPatternLength >= 1, + s"The maximum pattern length value must be greater than 0, but got $maxPatternLength.") + this.maxPatternLength = maxPatternLength + this + } + + /** + * Gets the maximum number of items allowed in a projected database before local processing. + */ + @Since("1.5.0") + def getMaxLocalProjDBSize: Long = maxLocalProjDBSize + + /** + * Sets the maximum number of items (including delimiters used in the internal storage format) + * allowed in a projected database before local processing (default: `32000000L`). + */ + @Since("1.5.0") + def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = { + require(maxLocalProjDBSize >= 0L, + s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize") + this.maxLocalProjDBSize = maxLocalProjDBSize + this + } + + /** + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * @param data sequences of itemsets. + * @return a [[PrefixSpanModel]] that contains the frequent patterns + */ + @Since("1.5.0") + def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = { + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") + } + + val totalCount = data.count() + logInfo(s"number of sequences: $totalCount") + val minCount = math.ceil(minSupport * totalCount).toLong + logInfo(s"minimum count for a frequent pattern: $minCount") + + // Find frequent items. + val freqItemAndCounts = data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach { _.foreach { item => + uniqItems += item + }} + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _) + .filter { case (_, count) => + count >= minCount + }.collect() + val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + logInfo(s"number of frequent items: ${freqItems.length}") + + // Keep only frequent items from input sequences and convert them to internal storage. + val itemToInt = freqItems.zipWithIndex.toMap + val dataInternalRepr = data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + } + allItems += 0 + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + }.persist(StorageLevel.MEMORY_AND_DISK) + + val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) + + def toPublicRepr(pattern: Array[Int]): Array[Array[Item]] = { + val sequenceBuilder = mutable.ArrayBuilder.make[Array[Item]] + val itemsetBuilder = mutable.ArrayBuilder.make[Item] + val n = pattern.length + var i = 1 + while (i < n) { + val x = pattern(i) + if (x == 0) { + sequenceBuilder += itemsetBuilder.result() + itemsetBuilder.clear() + } else { + itemsetBuilder += freqItems(x - 1) // using 1-indexing in internal format + } + i += 1 + } + sequenceBuilder.result() + } + + val freqSequences = results.map { case (seq: Array[Int], count: Long) => + new FreqSequence(toPublicRepr(seq), count) + } + new PrefixSpanModel(freqSequences) + } + + /** + * A Java-friendly version of [[run()]] that reads sequences from a [[JavaRDD]] and returns + * frequent sequences in a [[PrefixSpanModel]]. + * @param data ordered sequences of itemsets stored as Java Iterable of Iterables + * @tparam Item item type + * @tparam Itemset itemset type, which is an Iterable of Items + * @tparam Sequence sequence type, which is an Iterable of Itemsets + * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns + */ + @Since("1.5.0") + def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]]( + data: JavaRDD[Sequence]): PrefixSpanModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.map(_.asScala.toArray).toArray)) + } + +} + +@Experimental +@Since("1.5.0") +object PrefixSpan extends Logging { + + /** + * Find the complete set of frequent sequential patterns in the input sequences. + * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int], + * where each itemset is represented by a contiguous sequence of distinct and ordered + * positive integers. We use 0 as the delimiter at itemset boundaries, including the + * first and the last position. + * @return an RDD of (frequent sequential pattern, count) pairs, + * @see [[Postfix]] + */ + private[fpm] def genFreqPatterns( + data: RDD[Array[Int]], + minCount: Long, + maxPatternLength: Int, + maxLocalProjDBSize: Long): RDD[(Array[Int], Long)] = { + val sc = data.sparkContext + + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") + } + + val postfixes = data.map(items => new Postfix(items)) + + // Local frequent patterns (prefixes) and their counts. + val localFreqPatterns = mutable.ArrayBuffer.empty[(Array[Int], Long)] + // Prefixes whose projected databases are small. + val smallPrefixes = mutable.Map.empty[Int, Prefix] + val emptyPrefix = Prefix.empty + // Prefixes whose projected databases are large. + var largePrefixes = mutable.Map(emptyPrefix.id -> emptyPrefix) + while (largePrefixes.nonEmpty) { + val numLocalFreqPatterns = localFreqPatterns.length + logInfo(s"number of local frequent patterns: $numLocalFreqPatterns") + if (numLocalFreqPatterns > 1000000) { + logWarning( + s""" + | Collected $numLocalFreqPatterns local frequent patterns. You may want to consider: + | 1. increase minSupport, + | 2. decrease maxPatternLength, + | 3. increase maxLocalProjDBSize. + """.stripMargin) + } + logInfo(s"number of small prefixes: ${smallPrefixes.size}") + logInfo(s"number of large prefixes: ${largePrefixes.size}") + val largePrefixArray = largePrefixes.values.toArray + val freqPrefixes = postfixes.flatMap { postfix => + largePrefixArray.flatMap { prefix => + postfix.project(prefix).genPrefixItems.map { case (item, postfixSize) => + ((prefix.id, item), (1L, postfixSize)) + } + } + }.reduceByKey { case ((c0, s0), (c1, s1)) => + (c0 + c1, s0 + s1) + }.filter { case (_, (c, _)) => c >= minCount } + .collect() + val newLargePrefixes = mutable.Map.empty[Int, Prefix] + freqPrefixes.foreach { case ((id, item), (count, projDBSize)) => + val newPrefix = largePrefixes(id) :+ item + localFreqPatterns += ((newPrefix.items :+ 0, count)) + if (newPrefix.length < maxPatternLength) { + if (projDBSize > maxLocalProjDBSize) { + newLargePrefixes += newPrefix.id -> newPrefix + } else { + smallPrefixes += newPrefix.id -> newPrefix + } + } + } + largePrefixes = newLargePrefixes + } + + var freqPatterns = sc.parallelize(localFreqPatterns, 1) + + val numSmallPrefixes = smallPrefixes.size + logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") + if (numSmallPrefixes > 0) { + // Switch to local processing. + val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } + } + // Union local frequent patterns and distributed ones. + freqPatterns = freqPatterns ++ distributedFreqPattern + } + + freqPatterns + } + + /** + * Represents a prefix. + * @param items items in this prefix, using the internal format + * @param length length of this prefix, not counting 0 + */ + private[fpm] class Prefix private (val items: Array[Int], val length: Int) extends Serializable { + + /** A unique id for this prefix. */ + val id: Int = Prefix.nextId + + /** Expands this prefix by the input item. */ + def :+(item: Int): Prefix = { + require(item != 0) + if (item < 0) { + new Prefix(items :+ -item, length + 1) + } else { + new Prefix(items ++ Array(0, item), length + 1) + } + } + } + + private[fpm] object Prefix { + /** Internal counter to generate unique IDs. */ + private val counter: AtomicInteger = new AtomicInteger(-1) + + /** Gets the next unique ID. */ + private def nextId: Int = counter.incrementAndGet() + + /** An empty [[Prefix]] instance. */ + val empty: Prefix = new Prefix(Array.empty, 0) + } + + /** + * An internal representation of a postfix from some projection. + * We use one int array to store the items, which might also contains other items from the + * original sequence. + * Items are represented by positive integers, and items in each itemset must be distinct and + * ordered. + * we use 0 as the delimiter between itemsets. + * For example, a sequence `<(12)(31)1>` is represented by `[0, 1, 2, 0, 1, 3, 0, 1, 0]`. + * The postfix of this sequence w.r.t. to prefix `<1>` is `<(_2)(13)1>`. + * We may reuse the original items array `[0, 1, 2, 0, 1, 3, 0, 1, 0]` to represent the postfix, + * and mark the start index of the postfix, which is `2` in this example. + * So the active items in this postfix are `[2, 0, 1, 3, 0, 1, 0]`. + * We also remember the start indices of partial projections, the ones that split an itemset. + * For example, another possible partial projection w.r.t. `<1>` is `<(_3)1>`. + * We remember the start indices of partial projections, which is `[2, 5]` in this example. + * This data structure makes it easier to do projections. + * + * @param items a sequence stored as `Array[Int]` containing this postfix + * @param start the start index of this postfix in items + * @param partialStarts start indices of possible partial projections, strictly increasing + */ + private[fpm] class Postfix( + val items: Array[Int], + val start: Int = 0, + val partialStarts: Array[Int] = Array.empty) extends Serializable { + + require(items.last == 0, s"The last item in a postfix must be zero, but got ${items.last}.") + if (partialStarts.nonEmpty) { + require(partialStarts.head >= start, + "The first partial start cannot be smaller than the start index," + + s"but got partialStarts.head = ${partialStarts.head} < start = $start.") + } + + /** + * Start index of the first full itemset contained in this postfix. + */ + private[this] def fullStart: Int = { + var i = start + while (items(i) != 0) { + i += 1 + } + i + } + + /** + * Generates length-1 prefix items of this postfix with the corresponding postfix sizes. + * There are two types of prefix items: + * a) The item can be assembled to the last itemset of the prefix. For example, + * the postfix of `<(12)(123)>1` w.r.t. `<1>` is `<(_2)(123)1>`. The prefix items of this + * postfix can be assembled to `<1>` is `_2` and `_3`, resulting new prefixes `<(12)>` and + * `<(13)>`. We flip the sign in the output to indicate that this is a partial prefix item. + * b) The item can be appended to the prefix. Taking the same example above, the prefix items + * can be appended to `<1>` is `1`, `2`, and `3`, resulting new prefixes `<11>`, `<12>`, + * and `<13>`. + * @return an iterator of (prefix item, corresponding postfix size). If the item is negative, it + * indicates a partial prefix item, which should be assembled to the last itemset of the + * current prefix. Otherwise, the item should be appended to the current prefix. + */ + def genPrefixItems: Iterator[(Int, Long)] = { + val n1 = items.length - 1 + // For each unique item (subject to sign) in this sequence, we output exact one split. + // TODO: use PrimitiveKeyOpenHashMap + val prefixes = mutable.Map.empty[Int, Long] + // a) items that can be assembled to the last itemset of the prefix + partialStarts.foreach { start => + var i = start + var x = -items(i) + while (x != 0) { + if (!prefixes.contains(x)) { + prefixes(x) = n1 - i + } + i += 1 + x = -items(i) + } + } + // b) items that can be appended to the prefix + var i = fullStart + while (i < n1) { + val x = items(i) + if (x != 0 && !prefixes.contains(x)) { + prefixes(x) = n1 - i + } + i += 1 + } + prefixes.toIterator + } + + /** Tests whether this postfix is non-empty. */ + def nonEmpty: Boolean = items.length > start + 1 + + /** + * Projects this postfix with respect to the input prefix item. + * @param prefix prefix item. If prefix is positive, we match items in any full itemset; if it + * is negative, we do partial projections. + * @return the projected postfix + */ + def project(prefix: Int): Postfix = { + require(prefix != 0) + val n1 = items.length - 1 + var matched = false + var newStart = n1 + val newPartialStarts = mutable.ArrayBuilder.make[Int] + if (prefix < 0) { + // Search for partial projections. + val target = -prefix + partialStarts.foreach { start => + var i = start + var x = items(i) + while (x != target && x != 0) { + i += 1 + x = items(i) + } + if (x == target) { + i += 1 + if (!matched) { + newStart = i + matched = true + } + if (items(i) != 0) { + newPartialStarts += i + } + } + } + } else { + // Search for items in full itemsets. + // Though the items are ordered in each itemsets, they should be small in practice. + // So a sequential scan is sufficient here, compared to bisection search. + val target = prefix + var i = fullStart + while (i < n1) { + val x = items(i) + if (x == target) { + if (!matched) { + newStart = i + matched = true + } + if (items(i + 1) != 0) { + newPartialStarts += i + 1 + } + } + i += 1 + } + } + new Postfix(items, newStart, newPartialStarts.result()) + } + + /** + * Projects this postfix with respect to the input prefix. + */ + private def project(prefix: Array[Int]): Postfix = { + var partial = true + var cur = this + var i = 0 + val np = prefix.length + while (i < np && cur.nonEmpty) { + val x = prefix(i) + if (x == 0) { + partial = false + } else { + if (partial) { + cur = cur.project(-x) + } else { + cur = cur.project(x) + partial = true + } + } + i += 1 + } + cur + } + + /** + * Projects this postfix with respect to the input prefix. + */ + def project(prefix: Prefix): Postfix = project(prefix.items) + + /** + * Returns the same sequence with compressed storage if possible. + */ + def compressed: Postfix = { + if (start > 0) { + new Postfix(items.slice(start, items.length), 0, partialStarts.map(_ - start)) + } else { + this + } + } + } + + /** + * Represents a frequence sequence. + * @param sequence a sequence of itemsets stored as an Array of Arrays + * @param freq frequency + * @tparam Item item type + */ + @Since("1.5.0") + class FreqSequence[Item] @Since("1.5.0") ( + @Since("1.5.0") val sequence: Array[Array[Item]], + @Since("1.5.0") val freq: Long) extends Serializable { + /** + * Returns sequence as a Java List of lists for Java users. + */ + @Since("1.5.0") + def javaSequence: ju.List[ju.List[Item]] = sequence.map(_.toList.asJava).toList.asJava + } +} + +/** + * Model fitted by [[PrefixSpan]] + * @param freqSequences frequent sequences + * @tparam Item item type + */ +@Since("1.5.0") +class PrefixSpanModel[Item] @Since("1.5.0") ( + @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) + extends Serializable diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala new file mode 100644 index 000000000000..72d3aabc9b1f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.storage.StorageLevel + + +/** + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs + * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to + * the distributed data type (RDD, Graph, etc.). + * + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, + * as well as unpersisting and removing checkpoint files. + * + * Users should call update() when a new Dataset has been created, + * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. + * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Datasets should be + * checkpointed). + * - This class removes checkpoint files once later Datasets have been checkpointed. + * However, references to the older Datasets will still return isCheckpointed = true. + * + * @param checkpointInterval Datasets will be checkpointed at this interval + * @param sc SparkContext for the Datasets given to this checkpointer + * @tparam T Dataset type, such as RDD[Double] + */ +private[mllib] abstract class PeriodicCheckpointer[T]( + val checkpointInterval: Int, + val sc: SparkContext) extends Logging { + + /** FIFO queue of past checkpointed Datasets */ + private val checkpointQueue = mutable.Queue[T]() + + /** FIFO queue of past persisted Datasets */ + private val persistedQueue = mutable.Queue[T]() + + /** Number of times [[update()]] has been called */ + private var updateCount = 0 + + /** + * Update with a new Dataset. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the Dataset + * has been materialized. + * + * @param newData New Dataset created from previous Datasets in the lineage. + */ + def update(newData: T): Unit = { + persist(newData) + persistedQueue.enqueue(newData) + // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: + // Users should call [[update()]] when a new Dataset has been created, + // before the Dataset has been materialized. + while (persistedQueue.size > 3) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + checkpoint(newData) + checkpointQueue.enqueue(newData) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (isCheckpointed(checkpointQueue.head)) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** Checkpoint the Dataset */ + protected def checkpoint(data: T): Unit + + /** Return true iff the Dataset is checkpointed */ + protected def isCheckpointed(data: T): Boolean + + /** + * Persist the Dataset. + * Note: This should handle checking the current [[StorageLevel]] of the Dataset. + */ + protected def persist(data: T): Unit + + /** Unpersist the Dataset */ + protected def unpersist(data: T): Unit + + /** Get list of checkpoint files for this given Dataset */ + protected def getCheckpointFiles(data: T): Iterable[String] + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + getCheckpointFiles(old).foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 76672fe51e83..11a059536c50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -17,11 +17,7 @@ package org.apache.spark.mllib.impl -import scala.collection.mutable - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as * unpersisting and removing checkpoint files. * - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * Users should call update() when a new graph has been created, * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are * responsible for materializing the graph to ensure that persisting and checkpointing actually * occur. * - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * When update() is called, this does the following: * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. * - Unpersist graphs from queue until there are at most 3 persisted graphs. * - If using checkpointing and the checkpoint interval has been reached, @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel * Example usage: * {{{ * val (graph1, graph2, graph3, ...) = ... - * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * val cp = new PeriodicGraphCheckpointer(2, sc) * graph1.vertices.count(); graph1.edges.count() * // persisted: graph1 * cp.updateGraph(graph2) @@ -73,107 +69,30 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param currentGraph Initial graph - * @param checkpointDir The directory for storing checkpoint files * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + * TODO: Move this out of MLlib? */ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( - var currentGraph: Graph[VD, ED], - val checkpointDir: Option[String], - val checkpointInterval: Int) extends Logging { - - /** FIFO queue of past checkpointed RDDs */ - private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() - - /** FIFO queue of past persisted RDDs */ - private val persistedQueue = mutable.Queue[Graph[VD, ED]]() - - /** Number of times [[updateGraph()]] has been called */ - private var updateCount = 0 - - /** - * Spark Context for the Graphs given to this checkpointer. - * NOTE: This code assumes that only one SparkContext is used for the given graphs. - */ - private val sc = currentGraph.vertices.sparkContext - - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) { - sc.setCheckpointDir(checkpointDir.get) - } + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - updateGraph(currentGraph) + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - /** - * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. - * Since this handles persistence and checkpointing, this should be called before the graph - * has been materialized. - * - * @param newGraph New graph created from previous graphs in the lineage. - */ - def updateGraph(newGraph: Graph[VD, ED]): Unit = { - if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { - newGraph.persist() - } - persistedQueue.enqueue(newGraph) - // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: - // Users should call [[updateGraph()]] when a new graph has been created, - // before the graph has been materialized. - while (persistedQueue.size > 3) { - val graphToUnpersist = persistedQueue.dequeue() - graphToUnpersist.unpersist(blocking = false) - } - updateCount += 1 + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { - // Add new checkpoint before removing old checkpoints. - newGraph.checkpoint() - checkpointQueue.enqueue(newGraph) - // Remove checkpoints before the latest one. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (checkpointQueue.get(1).get.isCheckpointed) { - removeCheckpointFile() - } else { - canDelete = false - } - } + override protected def persist(data: Graph[VD, ED]): Unit = { + if (data.vertices.getStorageLevel == StorageLevel.NONE) { + data.persist() } } - /** - * Call this at the end to delete any remaining checkpoint files. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.size > 0) { - removeCheckpointFile() - } - } + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - /** - * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. - * This prints a warning but does not fail if the files cannot be removed. - */ - private def removeCheckpointFile(): Unit = { - val old = checkpointQueue.dequeue() - // Since the old checkpoint is not deleted by Spark, we manually delete it. - val fs = FileSystem.get(sc.hadoopConfiguration) - old.getCheckpointFiles.foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala new file mode 100644 index 000000000000..f31ed2aa90a6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing RDDs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new RDD has been created, + * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are + * responsible for materializing the RDD to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. + * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which RDDs should be + * checkpointed). + * - This class removes checkpoint files once later RDDs have been checkpointed. + * However, references to the older RDDs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (rdd1, rdd2, rdd3, ...) = ... + * val cp = new PeriodicRDDCheckpointer(2, sc) + * rdd1.count(); + * // persisted: rdd1 + * cp.update(rdd2) + * rdd2.count(); + * // persisted: rdd1, rdd2 + * // checkpointed: rdd2 + * cp.update(rdd3) + * rdd3.count(); + * // persisted: rdd1, rdd2, rdd3 + * // checkpointed: rdd2 + * cp.update(rdd4) + * rdd4.count(); + * // persisted: rdd2, rdd3, rdd4 + * // checkpointed: rdd4 + * cp.update(rdd5) + * rdd5.count(); + * // persisted: rdd3, rdd4, rdd5 + * // checkpointed: rdd4 + * }}} + * + * @param checkpointInterval RDDs will be checkpointed at this interval + * @tparam T RDD element type + * + * TODO: Move this out of MLlib? + */ +private[mllib] class PeriodicRDDCheckpointer[T]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { + + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + + override protected def persist(data: RDD[T]): Unit = { + if (data.getStorageLevel == StorageLevel.NONE) { + data.persist() + } + } + + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + data.getCheckpointFile.map(x => x) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 079f7ca564a9..bbbcc8436b7c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -213,9 +213,9 @@ private[spark] object BLAS extends Serializable with Logging { def scal(a: Double, x: Vector): Unit = { x match { case sx: SparseVector => - f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + f2jBLAS.dscal(sx.values.length, a, sx.values, 1) case dx: DenseVector => - f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + f2jBLAS.dscal(dx.values.length, a, dx.values, 1) case _ => throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") } @@ -228,19 +228,31 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } - + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. * @param x the vector x that contains the n elements. * @param A the symmetric matrix A. Size of n x n. */ - def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + def syr(alpha: Double, x: Vector, A: DenseMatrix) { val mA = A.numRows val nA = A.numCols - require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA") + require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA") require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") + x match { + case dv: DenseVector => syr(alpha, dv, A) + case sv: SparseVector => syr(alpha, sv, A) + case _ => + throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.") + } + } + + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val nA = A.numRows + val mA = A.numCols + nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA) // Fill lower triangular part of A @@ -252,7 +264,27 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 } i += 1 - } + } + } + + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + val mA = A.numCols + val xIndices = x.indices + val xValues = x.values + val nnz = xValues.length + val Avalues = A.values + + var i = 0 + while (i < nnz) { + val multiplier = alpha * xValues(i) + val offset = xIndices(i) * mA + var j = 0 + while (j < nnz) { + Avalues(xIndices(j) + offset) += multiplier * xValues(j) + j += 1 + } + i += 1 + } } /** @@ -271,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging { C: DenseMatrix): Unit = { require(!C.isTransposed, "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") - if (alpha == 0.0) { - logDebug("gemm: alpha is equal to 0. Returning C.") + if (alpha == 0.0 && beta == 1.0) { + logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.") } else { A match { case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) @@ -431,37 +463,42 @@ private[spark] object BLAS extends Serializable with Logging { def gemv( alpha: Double, A: Matrix, - x: DenseVector, + x: Vector, beta: Double, y: DenseVector): Unit = { require(A.numCols == x.size, s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, - s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { - A match { - case sparse: SparseMatrix => - gemv(alpha, sparse, x, beta, y) - case dense: DenseMatrix => - gemv(alpha, dense, x, beta, y) + (A, x) match { + case (smA: SparseMatrix, dvx: DenseVector) => + gemv(alpha, smA, dvx, beta, y) + case (smA: SparseMatrix, svx: SparseVector) => + gemv(alpha, smA, svx, beta, y) + case (dmA: DenseMatrix, dvx: DenseVector) => + gemv(alpha, dmA, dvx, beta, y) + case (dmA: DenseMatrix, svx: SparseVector) => + gemv(alpha, dmA, svx, beta, y) case _ => - throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") + throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " + + s"${A.getClass} and vector type ${x.getClass}.") } } } /** * y := alpha * A * x + beta * y - * For `DenseMatrix` A. + * For `DenseMatrix` A and `DenseVector` x. */ private def gemv( alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, - y: DenseVector): Unit = { + y: DenseVector): Unit = { val tStrA = if (A.isTransposed) "T" else "N" val mA = if (!A.isTransposed) A.numRows else A.numCols val nA = if (!A.isTransposed) A.numCols else A.numRows @@ -471,14 +508,134 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * For `SparseMatrix` A. + * For `DenseMatrix` A and `SparseVector` x. + */ + private def gemv( + alpha: Double, + A: DenseMatrix, + x: SparseVector, + beta: Double, + y: DenseVector): Unit = { + val mA: Int = A.numRows + val nA: Int = A.numCols + + val Avals = A.values + + val xIndices = x.indices + val xNnz = xIndices.length + val xValues = x.values + val yValues = y.values + + if (alpha == 0.0) { + scal(beta, y) + return + } + + if (A.isTransposed) { + var rowCounterForA = 0 + while (rowCounterForA < mA) { + var sum = 0.0 + var k = 0 + while (k < xNnz) { + sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA) + k += 1 + } + yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA) + rowCounterForA += 1 + } + } else { + var rowCounterForA = 0 + while (rowCounterForA < mA) { + var sum = 0.0 + var k = 0 + while (k < xNnz) { + sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA) + k += 1 + } + yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA) + rowCounterForA += 1 + } + } + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A and `SparseVector` x. + */ + private def gemv( + alpha: Double, + A: SparseMatrix, + x: SparseVector, + beta: Double, + y: DenseVector): Unit = { + val xValues = x.values + val xIndices = x.indices + val xNnz = xIndices.length + + val yValues = y.values + + val mA: Int = A.numRows + val nA: Int = A.numCols + + val Avals = A.values + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices + + if (alpha == 0.0) { + scal(beta, y) + return + } + + if (A.isTransposed) { + var rowCounter = 0 + while (rowCounter < mA) { + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + var k = 0 + while (k < xNnz && i < indEnd) { + if (xIndices(k) == Acols(i)) { + sum += Avals(i) * xValues(k) + i += 1 + } + k += 1 + } + yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) + rowCounter += 1 + } + } else { + scal(beta, y) + + var colCounterForA = 0 + var k = 0 + while (colCounterForA < nA && k < xNnz) { + if (xIndices(k) == colCounterForA) { + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + + val xTemp = xValues(k) * alpha + while (i < indEnd) { + val rowIndex = Arows(i) + yValues(Arows(i)) += Avals(i) * xTemp + i += 1 + } + k += 1 + } + colCounterForA += 1 + } + } + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A and `DenseVector` x. */ private def gemv( alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, - y: DenseVector): Unit = { + y: DenseVector): Unit = { val xValues = x.values val yValues = y.values val mA: Int = A.numRows @@ -502,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - // Scale vector first if `beta` is not equal to 0.0 - if (beta != 0.0) { - scal(beta, y) - } + scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 3515461b5249..ae3ba3099c87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -79,6 +79,9 @@ private[mllib] object EigenValueDecomposition { // Mode 1: A*x = lambda*x, A symmetric iparam(6) = 1 + require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, + s"k = $k and/or n = $n are too large to compute an eigendecomposition") + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) @@ -114,7 +117,7 @@ private[mllib] object EigenValueDecomposition { info.`val` match { case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " Maximum number of iterations taken. (Refer ARPACK user guide for details)") - case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " No shifts could be applied. Try to increase NCV. " + "(Refer ARPACK user guide for details)") case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index ad7e86827b36..c02ba426fcc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,21 +23,32 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + /** * Trait for a local matrix. */ +@SQLUserDefinedType(udt = classOf[MatrixUDT]) +@Since("1.0.0") sealed trait Matrix extends Serializable { /** Number of rows. */ + @Since("1.0.0") def numRows: Int /** Number of columns. */ + @Since("1.0.0") def numCols: Int /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + @Since("1.3.0") val isTransposed: Boolean = false /** Converts to a dense array in column major. */ + @Since("1.0.0") def toArray: Array[Double] = { val newArray = new Array[Double](numRows * numCols) foreachActive { (i, j, v) => @@ -50,7 +61,8 @@ sealed trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double + @Since("1.3.0") + def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ private[mllib] def index(i: Int, j: Int): Int @@ -59,20 +71,30 @@ sealed trait Matrix extends Serializable { private[mllib] def update(i: Int, j: Int, v: Double): Unit /** Get a deep copy of the matrix. */ + @Since("1.2.0") def copy: Matrix /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + @Since("1.3.0") def transpose: Matrix /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ + @Since("1.2.0") def multiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) BLAS.gemm(1.0, this, y, 0.0, C) C } - /** Convenience method for `Matrix`-`DenseVector` multiplication. */ + /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ + @Since("1.2.0") def multiply(y: DenseVector): DenseVector = { + multiply(y.asInstanceOf[Vector]) + } + + /** Convenience method for `Matrix`-`Vector` multiplication. */ + @Since("1.4.0") + def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) output @@ -81,10 +103,14 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + /** A human readable representation of the matrix with maximum lines and width */ + @Since("1.4.0") + def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ - private[mllib] def map(f: Double => Double): Matrix + private[spark] def map(f: Double => Double): Matrix /** Update all the values of this matrix using the function f. Performed in-place on the * backing array. For example, an operation such as addition or subtraction will only be @@ -100,6 +126,103 @@ sealed trait Matrix extends Serializable { * corresponding value in the matrix with type `Double`. */ private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + + /** + * Find the number of non-zero active values. + */ + @Since("1.5.0") + def numNonzeros: Int + + /** + * Find the number of values stored explicitly. These values can be zero as well. + */ + @Since("1.5.0") + def numActives: Int +} + +@DeveloperApi +private[spark] class MatrixUDT extends UserDefinedType[Matrix] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are + // set as not nullable, except values since in the future, support for binary matrices might + // be added for which values are not needed. + // the sparse matrix needs colPtrs and rowIndices, which are set as + // null, while building the dense matrix. + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("numRows", IntegerType, nullable = false), + StructField("numCols", IntegerType, nullable = false), + StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true), + StructField("isTransposed", BooleanType, nullable = false) + )) + } + + override def serialize(obj: Any): InternalRow = { + val row = new GenericMutableRow(7) + obj match { + case sm: SparseMatrix => + row.setByte(0, 0) + row.setInt(1, sm.numRows) + row.setInt(2, sm.numCols) + row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) + row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) + row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) + row.setBoolean(6, sm.isTransposed) + + case dm: DenseMatrix => + row.setByte(0, 1) + row.setInt(1, dm.numRows) + row.setInt(2, dm.numCols) + row.setNullAt(3) + row.setNullAt(4) + row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) + row.setBoolean(6, dm.isTransposed) + } + row + } + + override def deserialize(datum: Any): Matrix = { + datum match { + case row: InternalRow => + require(row.numFields == 7, + s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7") + val tpe = row.getByte(0) + val numRows = row.getInt(1) + val numCols = row.getInt(2) + val values = row.getArray(5).toDoubleArray() + val isTransposed = row.getBoolean(6) + tpe match { + case 0 => + val colPtrs = row.getArray(3).toIntArray() + val rowIndices = row.getArray(4).toIntArray() + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) + case 1 => + new DenseMatrix(numRows, numCols, values, isTransposed) + } + } + } + + override def userClass: Class[Matrix] = classOf[Matrix] + + override def equals(o: Any): Boolean = { + o match { + case v: MatrixUDT => true + case _ => false + } + } + + // see [SPARK-8647], this achieves the needed constant hash code without constant no. + override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode() + + override def typeName: String = "matrix" + + override def pyUDT: String = "pyspark.mllib.linalg.MatrixUDT" + + private[spark] override def asNullable: MatrixUDT = this } /** @@ -115,15 +238,17 @@ sealed trait Matrix extends Serializable { * * @param numRows number of rows * @param numCols number of columns - * @param values matrix entries in column major + * @param values matrix entries in column major if not transposed or in row major otherwise * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. */ -class DenseMatrix( - val numRows: Int, - val numCols: Int, - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +@Since("1.0.0") +@SQLUserDefinedType(udt = classOf[MatrixUDT]) +class DenseMatrix @Since("1.3.0") ( + @Since("1.0.0") val numRows: Int, + @Since("1.0.0") val numCols: Int, + @Since("1.0.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") @@ -143,15 +268,19 @@ class DenseMatrix( * @param numCols number of columns * @param values matrix entries in column major */ + @Since("1.0.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) - override def equals(o: Any) = o match { - case m: DenseMatrix => - m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze case _ => false } + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray) + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BDM[Double](numRows, numCols, values) @@ -163,7 +292,8 @@ class DenseMatrix( private[mllib] def apply(i: Int): Double = values(i) - private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + @Since("1.3.0") + override def apply(i: Int, j: Int): Double = values(index(i, j)) private[mllib] def index(i: Int, j: Int): Int = { if (!isTransposed) i + numRows * j else j + numCols * i @@ -173,9 +303,11 @@ class DenseMatrix( values(index(i, j)) = v } - override def copy = new DenseMatrix(numRows, numCols, values.clone()) + @Since("1.4.0") + override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) + private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), + isTransposed) private[mllib] def update(f: Double => Double): DenseMatrix = { val len = values.length @@ -187,7 +319,8 @@ class DenseMatrix( this } - override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + @Since("1.3.0") + override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { @@ -217,9 +350,18 @@ class DenseMatrix( } } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed - * set to false. */ - def toSparse(): SparseMatrix = { + @Since("1.5.0") + override def numNonzeros: Int = values.count(_ != 0) + + @Since("1.5.0") + override def numActives: Int = values.length + + /** + * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. + */ + @Since("1.3.0") + def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt @@ -246,6 +388,7 @@ class DenseMatrix( /** * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]]. */ +@Since("1.3.0") object DenseMatrix { /** @@ -254,8 +397,12 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): DenseMatrix = + @Since("1.3.0") + def zeros(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + } /** * Generate a `DenseMatrix` consisting of ones. @@ -263,14 +410,19 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): DenseMatrix = + @Since("1.3.0") + def ones(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + } /** * Generate an Identity Matrix in `DenseMatrix` format. * @param n number of rows and columns of the matrix * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def eye(n: Int): DenseMatrix = { val identity = DenseMatrix.zeros(n, n) var i = 0 @@ -282,24 +434,30 @@ object DenseMatrix { } /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } @@ -309,6 +467,7 @@ object DenseMatrix { * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("1.3.0") def diag(vector: Vector): DenseMatrix = { val n = vector.size val matrix = DenseMatrix.zeros(n, n) @@ -336,21 +495,23 @@ object DenseMatrix { * * @param numRows number of rows * @param numCols number of columns - * @param colPtrs the index corresponding to the start of a new column - * @param rowIndices the row index of the entry. They must be in strictly increasing order for each - * column - * @param values non-zero matrix entries in column major + * @param colPtrs the index corresponding to the start of a new column (if not transposed) + * @param rowIndices the row index of the entry (if not transposed). They must be in strictly + * increasing order for each column + * @param values nonzero matrix entries in column major (if not transposed) * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ -class SparseMatrix( - val numRows: Int, - val numCols: Int, - val colPtrs: Array[Int], - val rowIndices: Array[Int], - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +@Since("1.2.0") +@SQLUserDefinedType(udt = classOf[MatrixUDT]) +class SparseMatrix @Since("1.3.0") ( + @Since("1.2.0") val numRows: Int, + @Since("1.2.0") val numCols: Int, + @Since("1.2.0") val colPtrs: Array[Int], + @Since("1.2.0") val rowIndices: Array[Int], + @Since("1.2.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") @@ -380,6 +541,7 @@ class SparseMatrix( * order for each column * @param values non-zero matrix entries in column major */ + @Since("1.2.0") def this( numRows: Int, numCols: Int, @@ -387,6 +549,11 @@ class SparseMatrix( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze + case _ => false + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) @@ -396,7 +563,8 @@ class SparseMatrix( } } - private[mllib] def apply(i: Int, j: Int): Double = { + @Since("1.3.0") + override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } @@ -419,10 +587,13 @@ class SparseMatrix( } } - override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + @Since("1.4.0") + override def copy: SparseMatrix = { + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + } - private[mllib] def map(f: Double => Double) = - new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) + private[spark] def map(f: Double => Double) = + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): SparseMatrix = { val len = values.length @@ -434,7 +605,8 @@ class SparseMatrix( this } - override def transpose: Matrix = + @Since("1.3.0") + override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { @@ -464,16 +636,27 @@ class SparseMatrix( } } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed - * set to false. */ - def toDense(): DenseMatrix = { + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. + */ + @Since("1.3.0") + def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } + + @Since("1.5.0") + override def numNonzeros: Int = values.count(_ != 0) + + @Since("1.5.0") + override def numActives: Int = values.length + } /** * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]]. */ +@Since("1.3.0") object SparseMatrix { /** @@ -485,6 +668,7 @@ object SparseMatrix { * @param entries Array of (i, j, value) tuples * @return The corresponding `SparseMatrix` */ + @Since("1.3.0") def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) val numEntries = sortedEntries.size @@ -533,6 +717,7 @@ object SparseMatrix { * @param n number of rows and columns of the matrix * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def speye(n: Int): SparseMatrix = { new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) } @@ -593,7 +778,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. The number of non-zero + * Generate a `SparseMatrix` consisting of `i.i.d`. uniform random numbers. The number of non-zero * elements equal the ceiling of `numRows` x `numCols` x `density` * * @param numRows number of rows of the matrix @@ -602,19 +787,21 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextDouble()) } /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d`. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextGaussian()) @@ -626,7 +813,8 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ - def diag(vector: Vector): SparseMatrix = { + @Since("1.3.0") + def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { case sVec: SparseVector => @@ -642,6 +830,7 @@ object SparseMatrix { /** * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]]. */ +@Since("1.0.0") object Matrices { /** @@ -651,6 +840,7 @@ object Matrices { * @param numCols number of columns * @param values matrix entries in column major */ + @Since("1.0.0") def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = { new DenseMatrix(numRows, numCols, values) } @@ -664,6 +854,7 @@ object Matrices { * @param rowIndices the row index of the entry * @param values non-zero matrix entries in column major */ + @Since("1.2.0") def sparse( numRows: Int, numCols: Int, @@ -692,11 +883,12 @@ object Matrices { } /** - * Generate a `DenseMatrix` consisting of zeros. + * Generate a `Matrix` consisting of zeros. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros */ + @Since("1.2.0") def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** @@ -705,6 +897,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of ones */ + @Since("1.2.0") def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** @@ -712,6 +905,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.2.0") def eye(n: Int): Matrix = DenseMatrix.eye(n) /** @@ -719,56 +913,62 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def speye(n: Int): Matrix = SparseMatrix.speye(n) /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.2.0") def rand(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.rand(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprand(numRows, numCols, density, rng) /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.2.0") def randn(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.randn(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprandn(numRows, numCols, density, rng) /** - * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. - * @param vector a `Vector` tat will form the values on the diagonal of the matrix + * Generate a diagonal matrix in `Matrix` format from the supplied values. + * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("1.2.0") def diag(vector: Vector): Matrix = DenseMatrix.diag(vector) /** @@ -778,6 +978,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were horizontally concatenated */ + @Since("1.3.0") def horzcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) @@ -836,6 +1037,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were vertically concatenated */ + @Since("1.3.0") def vertcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 9669c364bad8..4dcf8f28f202 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -17,11 +17,21 @@ package org.apache.spark.mllib.linalg -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Represents singular value decomposition (SVD) factors. */ +@Since("1.0.0") @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) + +/** + * :: Experimental :: + * Represents QR factors. + */ +@Since("1.5.0") +@Experimental +case class QRDecomposition[QType, RType](Q: QType, R: RType) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 8f75e6f46e05..06ebb1586990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,8 +26,9 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException +import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ @@ -37,21 +38,24 @@ import org.apache.spark.sql.types._ * Note: Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) +@Since("1.0.0") sealed trait Vector extends Serializable { /** * Size of the vector. */ + @Since("1.0.0") def size: Int /** * Converts the instance to a double array. */ + @Since("1.0.0") def toArray: Array[Double] override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -62,20 +66,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -84,17 +96,19 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. * @param i index */ + @Since("1.1.0") def apply(i: Int): Double = toBreeze(i) /** * Makes a deep copy of this vector. */ + @Since("1.1.0") def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } @@ -107,13 +121,62 @@ sealed trait Vector extends Serializable { * with type `Double`. */ private[spark] def foreachActive(f: (Int, Double) => Unit) + + /** + * Number of active entries. An "active entry" is an element which is explicitly stored, + * regardless of its value. Note that inactive entries have value 0. + */ + @Since("1.4.0") + def numActives: Int + + /** + * Number of nonzero elements. This scans all active values and count nonzeros. + */ + @Since("1.4.0") + def numNonzeros: Int + + /** + * Converts this vector to a sparse vector with all explicit zeros removed. + */ + @Since("1.4.0") + def toSparse: SparseVector + + /** + * Converts this vector to a dense vector. + */ + @Since("1.4.0") + def toDense: DenseVector = new DenseVector(this.toArray) + + /** + * Returns a vector in either dense or sparse format, whichever uses less storage. + */ + @Since("1.4.0") + def compressed: Vector = { + val nnz = numNonzeros + // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. + if (1.5 * (nnz + 1.0) < size) { + toSparse + } else { + toDense + } + } + + /** + * Find the index of a maximal element. Returns the first maximal element in case of a tie. + * Returns -1 if vector has length 0. + */ + @Since("1.5.0") + def argmax: Int } /** + * :: AlphaComponent :: + * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. */ -private[spark] class VectorUDT extends UserDefinedType[Vector] { +@AlphaComponent +class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense @@ -127,40 +190,39 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) } - override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(4) + override def serialize(obj: Any): InternalRow = { obj match { case SparseVector(size, indices, values) => + val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, indices.toSeq) - row.update(3, values.toSeq) + row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) + row case DenseVector(values) => + val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, values.toSeq) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) + row } - row } override def deserialize(datum: Any): Vector = { datum match { - // TODO: something wrong with UDT serialization - case v: Vector => - v - case row: Row => - require(row.length == 4, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + case row: InternalRow => + require(row.numFields == 4, + s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4") val tpe = row.getByte(0) tpe match { case 0 => val size = row.getInt(1) - val indices = row.getAs[Iterable[Int]](2).toArray - val values = row.getAs[Iterable[Double]](3).toArray + val indices = row.getArray(2).toIntArray() + val values = row.getArray(3).toDoubleArray() new SparseVector(size, indices, values) case 1 => - val values = row.getAs[Iterable[Double]](3).toArray + val values = row.getArray(3).toDoubleArray() new DenseVector(values) } } @@ -169,6 +231,20 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" override def userClass: Class[Vector] = classOf[Vector] + + override def equals(o: Any): Boolean = { + o match { + case v: VectorUDT => true + case _ => false + } + } + + // see [SPARK-8647], this achieves the needed constant hash code without constant no. + override def hashCode(): Int = classOf[VectorUDT].getName.hashCode() + + override def typeName: String = "vector" + + private[spark] override def asNullable: VectorUDT = this } /** @@ -176,11 +252,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { * We don't use the name `Vector` because Scala imports * [[scala.collection.immutable.Vector]] by default. */ +@Since("1.0.0") object Vectors { /** * Creates a dense vector from its values. */ + @Since("1.0.0") @varargs def dense(firstValue: Double, otherValues: Double*): Vector = new DenseVector((firstValue +: otherValues).toArray) @@ -189,6 +267,7 @@ object Vectors { /** * Creates a dense vector from a double array. */ + @Since("1.0.0") def dense(values: Array[Double]): Vector = new DenseVector(values) /** @@ -198,6 +277,7 @@ object Vectors { * @param indices index array, must be strictly increasing. * @param values value array, must have the same length as indices. */ + @Since("1.0.0") def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector = new SparseVector(size, indices, values) @@ -207,8 +287,9 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ + @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0) + require(size > 0, "The size of the requested sparse vector must be greater than 0.") val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 @@ -216,7 +297,8 @@ object Vectors { require(prev < i, s"Found duplicate indices: $i.") prev = i } - require(prev < size) + require(prev < size, s"You may not write an element to index $prev because the declared " + + s"size of your vector is $size") new SparseVector(size, indices.toArray, values.toArray) } @@ -227,6 +309,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ + @Since("1.0.0") def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = { sparse(size, elements.asScala.map { case (i, x) => (i.intValue(), x.doubleValue()) @@ -234,19 +317,20 @@ object Vectors { } /** - * Creates a dense vector of all zeros. + * Creates a vector of all zeros. * * @param size vector size * @return a zero vector */ + @Since("1.1.0") def zeros(size: Int): Vector = { new DenseVector(new Array[Double](size)) } /** - * Parses a string resulted from `Vector#toString` into - * an [[org.apache.spark.mllib.linalg.Vector]]. + * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. */ + @Since("1.1.0") def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) } @@ -265,7 +349,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. */ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -290,14 +374,16 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ + @Since("1.3.0") def norm(vector: Vector, p: Double): Double = { - require(p >= 1.0) + require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + + s"You specified p=$p.") val values = vector match { case DenseVector(vs) => vs case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -341,8 +427,10 @@ object Vectors { * @param v2 second Vector. * @return squared distance between two Vectors. */ + @Since("1.3.0") def sqdist(v1: Vector, v2: Vector): Double = { - require(v1.size == v2.size, "vector dimension mismatch") + require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + + s"=${v2.size}.") var squaredDistance = 0.0 (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => @@ -350,8 +438,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -380,7 +468,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -401,7 +489,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -430,8 +518,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -453,26 +541,32 @@ object Vectors { /** * A dense vector represented by a value array. */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class DenseVector(val values: Array[Double]) extends Vector { +class DenseVector @Since("1.0.0") ( + @Since("1.0.0") val values: Array[Double]) extends Vector { + @Since("1.0.0") override def size: Int = values.length override def toString: String = values.mkString("[", ",", "]") + @Since("1.0.0") override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) - override def apply(i: Int) = values(i) + @Since("1.0.0") + override def apply(i: Int): Double = values(i) + @Since("1.1.0") override def copy: DenseVector = { new DenseVector(values.clone()) } private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -480,9 +574,79 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } + + @Since("1.4.0") + override def numActives: Int = size + + @Since("1.4.0") + override def numNonzeros: Int = { + // same as values.count(_ != 0.0) but faster + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + @Since("1.4.0") + override def toSparse: SparseVector = { + val nnz = numNonzeros + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } + + @Since("1.5.0") + override def argmax: Int = { + if (size == 0) { + -1 + } else { + var maxIdx = 0 + var maxValue = values(0) + var i = 1 + while (i < size) { + if (values(i) > maxValue) { + maxIdx = i + maxValue = values(i) + } + i += 1 + } + maxIdx + } + } } +@Since("1.3.0") object DenseVector { + + /** Extracts the value array from a dense vector. */ + @Since("1.3.0") def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } @@ -493,17 +657,23 @@ object DenseVector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class SparseVector( - override val size: Int, - val indices: Array[Int], - val values: Array[Double]) extends Vector { +class SparseVector @Since("1.0.0") ( + @Since("1.0.0") override val size: Int, + @Since("1.0.0") val indices: Array[Int], + @Since("1.0.0") val values: Array[Double]) extends Vector { - require(indices.length == values.length) + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") + require(indices.length <= size, s"You provided ${indices.length} indices and values, " + + s"which exceeds the specified vector size ${size}.") override def toString: String = - "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) + s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" + @Since("1.0.0") override def toArray: Array[Double] = { val data = new Array[Double](size) var i = 0 @@ -515,15 +685,16 @@ class SparseVector( data } + @Since("1.1.0") override def copy: SparseVector = { new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -532,9 +703,137 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } + + @Since("1.4.0") + override def numActives: Int = values.length + + @Since("1.4.0") + override def numNonzeros: Int = { + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + @Since("1.4.0") + override def toSparse: SparseVector = { + val nnz = numNonzeros + if (nnz == numActives) { + this + } else { + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0.0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } + } + + @Since("1.5.0") + override def argmax: Int = { + if (size == 0) { + -1 + } else { + // Find the max active entry. + var maxIdx = indices(0) + var maxValue = values(0) + var maxJ = 0 + var j = 1 + val na = numActives + while (j < na) { + val v = values(j) + if (v > maxValue) { + maxValue = v + maxIdx = indices(j) + maxJ = j + } + j += 1 + } + + // If the max active entry is nonpositive and there exists inactive ones, find the first zero. + if (maxValue <= 0.0 && na < size) { + if (maxValue == 0.0) { + // If there exists an inactive entry before maxIdx, find it and return its index. + if (maxJ < maxIdx) { + var k = 0 + while (k < maxJ && indices(k) == k) { + k += 1 + } + maxIdx = k + } + } else { + // If the max active value is negative, find and return the first inactive index. + var k = 0 + while (k < na && indices(k) == k) { + k += 1 + } + maxIdx = k + } + } + + maxIdx + } + } + + /** + * Create a slice of this vector based on the given indices. + * @param selectedIndices Unsorted list of indices into the vector. + * This does NOT do bound checking. + * @return New SparseVector with values in the order specified by the given indices. + * + * NOTE: The API needs to be discussed before making this public. + * Also, if we have a version assuming indices are sorted, we should optimize it. + */ + private[spark] def slice(selectedIndices: Array[Int]): SparseVector = { + var currentIdx = 0 + val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx => + val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx) + val i_v = if (iIdx >= 0) { + Iterator((currentIdx, this.values(iIdx))) + } else { + Iterator() + } + currentIdx += 1 + i_v + }.unzip + new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) + } } +@Since("1.3.0") object SparseVector { + @Since("1.3.0") def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 3871152d065a..a33b6137cf9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -21,7 +21,8 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.{SparkException, Logging, Partitioner} +import org.apache.spark.{Logging, Partitioner, SparkException} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -48,7 +49,7 @@ private[mllib] class GridPartitioner( private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt - override val numPartitions = rowPartitions * colPartitions + override val numPartitions: Int = rowPartitions * colPartitions /** * Returns the index of the partition the input coordinate belongs to. @@ -84,6 +85,14 @@ private[mllib] class GridPartitioner( false } } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + rows: java.lang.Integer, + cols: java.lang.Integer, + rowsPerPart: java.lang.Integer, + colsPerPart: java.lang.Integer) + } } private[mllib] object GridPartitioner { @@ -104,6 +113,8 @@ private[mllib] object GridPartitioner { } /** + * :: Experimental :: + * * Represents a distributed matrix in blocks of local matrices. * * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that @@ -118,10 +129,12 @@ private[mllib] object GridPartitioner { * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to * zero, the number of columns will be calculated when `numCols` is invoked. */ -class BlockMatrix( - val blocks: RDD[((Int, Int), Matrix)], - val rowsPerBlock: Int, - val colsPerBlock: Int, +@Since("1.3.0") +@Experimental +class BlockMatrix @Since("1.3.0") ( + @Since("1.3.0") val blocks: RDD[((Int, Int), Matrix)], + @Since("1.3.0") val rowsPerBlock: Int, + @Since("1.3.0") val colsPerBlock: Int, private var nRows: Long, private var nCols: Long) extends DistributedMatrix with Logging { @@ -138,6 +151,7 @@ class BlockMatrix( * @param colsPerBlock Number of columns that make up each block. The blocks forming the final * columns are not required to have the given number of columns */ + @Since("1.3.0") def this( blocks: RDD[((Int, Int), Matrix)], rowsPerBlock: Int, @@ -145,17 +159,21 @@ class BlockMatrix( this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) } + @Since("1.3.0") override def numRows(): Long = { if (nRows <= 0L) estimateDim() nRows } + @Since("1.3.0") override def numCols(): Long = { if (nCols <= 0L) estimateDim() nCols } + @Since("1.3.0") val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + @Since("1.3.0") val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt private[mllib] def createPartitioner(): GridPartitioner = @@ -177,6 +195,11 @@ class BlockMatrix( assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") } + /** + * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if + * any error is found. + */ + @Since("1.3.0") def validate(): Unit = { logDebug("Validating BlockMatrix...") // check if the matrix is larger than the claimed dimensions @@ -213,18 +236,21 @@ class BlockMatrix( } /** Caches the underlying RDD. */ + @Since("1.3.0") def cache(): this.type = { blocks.cache() this } /** Persists the underlying RDD with the specified storage level. */ + @Since("1.3.0") def persist(storageLevel: StorageLevel): this.type = { blocks.persist(storageLevel) this } /** Converts to CoordinateMatrix. */ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => val rowStart = blockRowIndex.toLong * rowsPerBlock @@ -239,6 +265,7 @@ class BlockMatrix( } /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.3.0") def toIndexedRowMatrix(): IndexedRowMatrix = { require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + s"numCols: ${numCols()}") @@ -247,6 +274,7 @@ class BlockMatrix( } /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + @Since("1.3.0") def toLocalMatrix(): Matrix = { require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + s"Int.MaxValue. Currently numRows: ${numRows()}") @@ -271,8 +299,11 @@ class BlockMatrix( new DenseMatrix(m, n, values) } - /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the - * same underlying data. Is a lazy operation. */ + /** + * Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. + */ + @Since("1.3.0") def transpose: BlockMatrix = { val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => ((blockColIndex, blockRowIndex), mat.transpose) @@ -286,12 +317,14 @@ class BlockMatrix( new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) } - /** Adds two block matrices together. The matrices must have the same size and matching - * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are - * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even - * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will - * also be a [[DenseMatrix]]. - */ + /** + * Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + @Since("1.3.0") def add(other: BlockMatrix): BlockMatrix = { require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") @@ -319,12 +352,14 @@ class BlockMatrix( } } - /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` - * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains - * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output - * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause - * some performance issues until support for multiplying two sparse matrices is added. - */ + /** + * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. + */ + @Since("1.3.0") def multiply(other: BlockMatrix): BlockMatrix = { require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " + @@ -351,7 +386,7 @@ class BlockMatrix( if (a.nonEmpty && b.nonEmpty) { val C = b.head match { case dense: DenseMatrix => a.head.multiply(dense) - case sparse: SparseMatrix => a.head.multiply(sparse.toDense()) + case sparse: SparseMatrix => a.head.multiply(sparse.toDense) case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") } Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 078d1fac4444..644f293d88a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} * @param j column index * @param value value of the entry */ +@Since("1.0.0") @Experimental case class MatrixEntry(i: Long, j: Long, value: Double) @@ -43,16 +44,19 @@ case class MatrixEntry(i: Long, j: Long, value: Double) * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the max column index plus one. */ +@Since("1.0.0") @Experimental -class CoordinateMatrix( - val entries: RDD[MatrixEntry], +class CoordinateMatrix @Since("1.0.0") ( + @Since("1.0.0") val entries: RDD[MatrixEntry], private var nRows: Long, private var nCols: Long) extends DistributedMatrix { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L) /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0L) { computeSize() @@ -61,6 +65,7 @@ class CoordinateMatrix( } /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { computeSize() @@ -69,11 +74,13 @@ class CoordinateMatrix( } /** Transposes this CoordinateMatrix. */ + @Since("1.3.0") def transpose(): CoordinateMatrix = { new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) } /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.0.0") def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() if (nl > Int.MaxValue) { @@ -93,11 +100,13 @@ class CoordinateMatrix( * Converts to RowMatrix, dropping row indices after grouping by row index. * The number of columns must be within the integer range. */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { toIndexedRowMatrix().toRowMatrix() } /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -110,6 +119,7 @@ class CoordinateMatrix( * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { require(rowsPerBlock > 0, s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala index a0e26ce3bc46..db3433a5e245 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -19,15 +19,20 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.annotation.Since + /** * Represents a distributively stored matrix backed by one or more RDDs. */ +@Since("1.0.0") trait DistributedMatrix extends Serializable { /** Gets or computes the number of rows. */ + @Since("1.0.0") def numRows(): Long /** Gets or computes the number of columns. */ + @Since("1.0.0") def numCols(): Long /** Collects data and assembles a local dense breeze matrix (for test only). */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 3be530fa0753..b20ea0dc50da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition @@ -28,6 +28,7 @@ import org.apache.spark.mllib.linalg.SingularValueDecomposition * :: Experimental :: * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. */ +@Since("1.0.0") @Experimental case class IndexedRow(index: Long, vector: Vector) @@ -42,15 +43,18 @@ case class IndexedRow(index: Long, vector: Vector) * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. */ +@Since("1.0.0") @Experimental -class IndexedRowMatrix( - val rows: RDD[IndexedRow], +class IndexedRowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[IndexedRow], private var nRows: Long, private var nCols: Int) extends DistributedMatrix { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0) + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { // Calling `first` will throw an exception if `rows` is empty. @@ -59,6 +63,7 @@ class IndexedRowMatrix( nCols } + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { // Reduce will throw an exception if `rows` is empty. @@ -71,11 +76,13 @@ class IndexedRowMatrix( * Drops row indices and converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]]. */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { new RowMatrix(rows.map(_.vector), 0L, nCols) } /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -88,6 +95,7 @@ class IndexedRowMatrix( * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { // TODO: This implementation may be optimized toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) @@ -97,6 +105,7 @@ class IndexedRowMatrix( * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. */ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entries = rows.flatMap { row => val rowIndex = row.index @@ -133,6 +142,7 @@ class IndexedRowMatrix( * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V) */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -146,7 +156,7 @@ class IndexedRowMatrix( val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => IndexedRow(i, v) } - new IndexedRowMatrix(indexedRows, nRows, nCols) + new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt) } else { null } @@ -159,6 +169,7 @@ class IndexedRowMatrix( * @param B a local matrix whose number of rows must match the number of columns of this matrix * @return an IndexedRowMatrix representing the product, which preserves partitioning */ + @Since("1.0.0") def multiply(B: Matrix): IndexedRowMatrix = { val mat = toRowMatrix().multiply(B) val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => @@ -170,6 +181,7 @@ class IndexedRowMatrix( /** * Computes the Gramian matrix `A^T A`. */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { toRowMatrix().computeGramianMatrix() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 961111507f2c..9a423ddafdc0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -22,13 +22,13 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd} + svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD @@ -45,16 +45,19 @@ import org.apache.spark.storage.StorageLevel * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. */ +@Since("1.0.0") @Experimental -class RowMatrix( - val rows: RDD[Vector], +class RowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[Vector], private var nRows: Long, private var nCols: Int) extends DistributedMatrix with Logging { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[Vector]) = this(rows, 0L, 0) /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { try { @@ -70,6 +73,7 @@ class RowMatrix( } /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { nRows = rows.count() @@ -108,6 +112,7 @@ class RowMatrix( /** * Computes the Gramian matrix `A^T A`. */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -178,6 +183,7 @@ class RowMatrix( * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -219,7 +225,7 @@ class RowMatrix( val computeMode = mode match { case "auto" => - if(k > 5000) { + if (k > 5000) { logWarning(s"computing svd with k=$k and n=$n, please check necessity") } @@ -318,6 +324,7 @@ class RowMatrix( * Computes the covariance matrix, treating each row as an observation. * @return a local dense matrix of size n x n */ + @Since("1.0.0") def computeCovariance(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -371,6 +378,7 @@ class RowMatrix( * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components */ + @Since("1.0.0") def computePrincipalComponents(k: Int): Matrix = { val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") @@ -389,6 +397,7 @@ class RowMatrix( /** * Computes column-wise summary statistics. */ + @Since("1.0.0") def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), @@ -404,6 +413,7 @@ class RowMatrix( * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product, * which preserves partitioning */ + @Since("1.0.0") def multiply(B: Matrix): RowMatrix = { val n = numCols().toInt val k = B.numCols @@ -436,6 +446,7 @@ class RowMatrix( * @return An n x n sparse upper-triangular matrix of cosine similarities between * columns of this matrix. */ + @Since("1.2.0") def columnSimilarities(): CoordinateMatrix = { columnSimilarities(0.0) } @@ -479,6 +490,7 @@ class RowMatrix( * @return An n x n sparse upper-triangular matrix of cosine similarities * between columns of this matrix. */ + @Since("1.2.0") def columnSimilarities(threshold: Double): CoordinateMatrix = { require(threshold >= 0, s"Threshold cannot be negative: $threshold") @@ -497,6 +509,51 @@ class RowMatrix( columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) } + /** + * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR + * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. + * Reference: + * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce + * architectures" ([[http://dx.doi.org/10.1145/1996092.1996103]]) + * + * @param computeQ whether to computeQ + * @return QRDecomposition(Q, R), Q = null if computeQ = false. + */ + @Since("1.5.0") + def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { + val col = numCols().toInt + // split rows horizontally into smaller matrices, and compute QR for each of them + val blockQRs = rows.glom().map { partRows => + val bdm = BDM.zeros[Double](partRows.length, col) + var i = 0 + partRows.foreach { row => + bdm(i, ::) := row.toBreeze.t + i += 1 + } + breeze.linalg.qr.reduced(bdm).r + } + + // combine the R part from previous results vertically into a tall matrix + val combinedR = blockQRs.treeReduce{ (r1, r2) => + val stackedR = BDM.vertcat(r1, r2) + breeze.linalg.qr.reduced(stackedR).r + } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) + val finalQ = if (computeQ) { + try { + val invR = inv(combinedR) + this.multiply(Matrices.fromBreeze(invR)) + } catch { + case err: MatrixSingularException => + logWarning("R is not invertible and return Q as null") + null + } + } else { + null + } + QRDecomposition(finalQ, finalR) + } + /** * Find all similar columns using the DIMSUM sampling algorithm, described in two papers * @@ -531,7 +588,6 @@ class RowMatrix( val rand = new XORShiftRandom(indx) val scaled = new Array[Double](p.size) iter.flatMap { row => - val buf = new ListBuffer[((Int, Int), Double)]() row match { case SparseVector(size, indices, values) => val nnz = indices.size @@ -540,8 +596,9 @@ class RowMatrix( scaled(k) = values(k) / q(indices(k)) k += 1 } - k = 0 - while (k < nnz) { + + Iterator.tabulate (nnz) { k => + val buf = new ListBuffer[((Int, Int), Double)]() val i = indices(k) val iVal = scaled(k) if (iVal != 0 && rand.nextDouble() < p(i)) { @@ -555,8 +612,8 @@ class RowMatrix( l += 1 } } - k += 1 - } + buf + }.flatten case DenseVector(values) => val n = values.size var i = 0 @@ -564,8 +621,8 @@ class RowMatrix( scaled(i) = values(i) / q(i) i += 1 } - i = 0 - while (i < n) { + Iterator.tabulate (n) { i => + val buf = new ListBuffer[((Int, Int), Double)]() val iVal = scaled(i) if (iVal != 0 && rand.nextDouble() < p(i)) { var j = i + 1 @@ -577,10 +634,9 @@ class RowMatrix( j += 1 } } - i += 1 - } + buf + }.flatten } - buf } }.reduceByKey(_ + _).map { case ((i, j), sim) => MatrixEntry(i.toLong, j.toLong, sim) @@ -613,6 +669,7 @@ class RowMatrix( } } +@Since("1.0.0") @Experimental object RowMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 0acdab797e8f..240baeb5a158 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, @@ -63,10 +67,12 @@ abstract class Gradient extends Serializable { * http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of * multinomial logistic regression model. A simple calculation shows that * + * {{{ * P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i)) * P(y=1|x, w) = exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i)) * ... * P(y=K-1|x, w) = exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i)) + * }}} * * for K classes multiclass classification problem. * @@ -75,9 +81,11 @@ abstract class Gradient extends Serializable { * will be (K-1) * N. * * As a result, the loss of objective function for a single instance of data can be written as + * {{{ * l(w, x) = -log P(y|x, w) = -\alpha(y) log P(y=0|x, w) - (1-\alpha(y)) log P(y|x, w) * = log(1 + \sum_i^{K-1}\exp(x w_i)) - (1-\alpha(y)) x w_{y-1} * = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} + * }}} * * where \alpha(i) = 1 if i != 0, and * \alpha(i) = 0 if i == 0, @@ -86,14 +94,16 @@ abstract class Gradient extends Serializable { * For optimization, we have to calculate the first derivative of the loss function, and * a simple calculation shows that * + * {{{ * \frac{\partial l(w, x)}{\partial w_{ij}} * = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - (1-\alpha(y)\delta_{y, i+1})) * x_j * = multiplier_i * x_j + * }}} * * where \delta_{i, j} = 1 if i == j, * \delta_{i, j} = 0 if i != j, and - * multiplier - * = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) + * multiplier = + * \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) * * If any of margins is larger than 709.78, the numerical computation of multiplier and loss * function will be suffered from arithmetic overflow. This issue occurs when there are outliers @@ -103,10 +113,12 @@ abstract class Gradient extends Serializable { * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be * easily rewritten into the following equivalent numerically stable formula. * + * {{{ * l(w, x) = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} * = log(\exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin)) + maxMargin * - (1-\alpha(y)) margins_{y-1} * = log(1 + sum) + maxMargin - (1-\alpha(y)) margins_{y-1} + * }}} * * where sum = \exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin) - 1. * @@ -115,8 +127,10 @@ abstract class Gradient extends Serializable { * * For multiplier, similar trick can be applied as the following, * + * {{{ * multiplier = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) * = \exp(margins_i - maxMargin) / (1 + sum) - (1-\alpha(y)\delta_{y, i+1}) + * }}} * * where each term in \exp is also smaller than zero, so overflow is not a concern. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 4b7d0589c973..3b663b5defb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,25 +19,27 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV} +import breeze.linalg.{DenseVector => BDV, norm} import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} + /** * Class used to solve an optimization problem using Gradient Descent. * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 private var numIterations: Int = 100 private var regParam: Double = 0.0 private var miniBatchFraction: Double = 1.0 + private var convergenceTol: Double = 0.001 /** * Set the initial step size of SGD for the first step. Default 1.0. @@ -75,6 +77,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va this } + /** + * Set the convergence tolerance. Default 0.001 + * convergenceTol is a condition which decides iteration termination. + * The end of iteration is decided based on below logic. + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * Must be between 0.0 and 1.0 inclusively. + */ + def setConvergenceTol(tolerance: Double): this.type = { + require(0.0 <= tolerance && tolerance <= 1.0) + this.convergenceTol = tolerance + this + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for SGD. @@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va numIterations, regParam, miniBatchFraction, - initialWeights) + initialWeights, + convergenceTol) weights } @@ -131,17 +151,20 @@ object GradientDescent extends Logging { * Sampling, and averaging the subgradients over this subset is performed using one standard * spark map-reduce in each iteration. * - * @param data - Input data for SGD. RDD of the set of data examples, each of - * the form (label, [feature values]). - * @param gradient - Gradient object (used to compute the gradient of the loss function of - * one single data example) - * @param updater - Updater function to actually perform a gradient step in a given direction. - * @param stepSize - initial step size for the first step - * @param numIterations - number of iterations that SGD should be run. - * @param regParam - regularization parameter - * @param miniBatchFraction - fraction of the input data set that should be used for - * one iteration of SGD. Default value 1.0. - * + * @param data Input data for SGD. RDD of the set of data examples, each of + * the form (label, [feature values]). + * @param gradient Gradient object (used to compute the gradient of the loss function of + * one single data example) + * @param updater Updater function to actually perform a gradient step in a given direction. + * @param stepSize initial step size for the first step + * @param numIterations number of iterations that SGD should be run. + * @param regParam regularization parameter + * @param miniBatchFraction fraction of the input data set that should be used for + * one iteration of SGD. Default value 1.0. + * @param convergenceTol Minibatch iteration will end before numIterations if the relative + * difference between the current weight and the previous weight is less + * than this value. In measuring convergence, L2 norm is calculated. + * Default value 0.001. Must be between 0.0 and 1.0 inclusively. * @return A tuple containing two elements. The first element is a column matrix containing * weights for every feature, and the second element is an array containing the * stochastic loss computed for every iteration. @@ -154,9 +177,20 @@ object GradientDescent extends Logging { numIterations: Int, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): (Vector, Array[Double]) = { + initialWeights: Vector, + convergenceTol: Double): (Vector, Array[Double]) = { + + // convergenceTol should be set with non minibatch settings + if (miniBatchFraction < 1.0 && convergenceTol > 0.0) { + logWarning("Testing against a convergenceTol when using miniBatchFraction " + + "< 1.0 can be unstable because of the stochasticity in sampling.") + } val stochasticLossHistory = new ArrayBuffer[Double](numIterations) + // Record previous weight and current one to calculate solution vector difference + + var previousWeights: Option[Vector] = None + var currentWeights: Option[Vector] = None val numExamples = data.count() @@ -179,9 +213,11 @@ object GradientDescent extends Logging { * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 - for (i <- 1 to numIterations) { + var converged = false // indicates whether converged based on convergenceTol + var i = 1 + while (!converged && i <= numIterations) { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) @@ -199,17 +235,26 @@ object GradientDescent extends Logging { if (miniBatchSize > 0) { /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * lossSum is computed using the weights from the previous iteration * and regVal is the regularization value computed in the previous iteration as well. */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), + stepSize, i, regParam) weights = update._1 regVal = update._2 + + previousWeights = currentWeights + currentWeights = Some(weights) + if (previousWeights != None && currentWeights != None) { + converged = isConverged(previousWeights.get, + currentWeights.get, convergenceTol) + } } else { logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") } + i += 1 } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( @@ -218,4 +263,35 @@ object GradientDescent extends Logging { (weights, stochasticLossHistory.toArray) } + + /** + * Alias of [[runMiniBatchSGD]] with convergenceTol set to default value of 0.001. + */ + def runMiniBatchSGD( + data: RDD[(Double, Vector)], + gradient: Gradient, + updater: Updater, + stepSize: Double, + numIterations: Int, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Vector): (Vector, Array[Double]) = + GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations, + regParam, miniBatchFraction, initialWeights, 0.001) + + + private def isConverged( + previousWeights: Vector, + currentWeights: Vector, + convergenceTol: Double): Boolean = { + // To compare with convergence tolerance. + val previousBDV = previousWeights.toBreeze.toDenseVector + val currentBDV = currentWeights.toBreeze.toDenseVector + + // This represents the difference of updated weights in the iteration. + val solutionVecDiff: Double = norm(previousBDV - currentBDV) + + solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index d5e4f4ccbff1..efedc112d380 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -60,6 +61,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. + * This value must be nonnegative. Lower convergence values are less tolerant + * and therefore generally cause more iterations to be run. */ def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance @@ -142,7 +145,9 @@ object LBFGS extends Logging { * one single data example) * @param updater - Updater function to actually perform a gradient step in a given direction. * @param numCorrections - The number of corrections used in the L-BFGS update. - * @param convergenceTol - The convergence tolerance of iterations for L-BFGS + * @param convergenceTol - The convergence tolerance of iterations for L-BFGS which is must be + * nonnegative. Lower values are less tolerant and therefore generally + * cause more iterations to be run. * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run. * @param regParam - Regularization parameter * @@ -160,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -177,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index ccd93b318bc2..64d52bae0090 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -17,7 +17,9 @@ package org.apache.spark.mllib.optimization -import org.jblas.{DoubleMatrix, SimpleBlas} +import java.{util => ju} + +import com.github.fommil.netlib.BLAS.{getInstance => blas} /** * Object used to solve nonnegative least squares problems using a modified @@ -25,20 +27,20 @@ import org.jblas.{DoubleMatrix, SimpleBlas} */ private[spark] object NNLS { class Workspace(val n: Int) { - val scratch = new DoubleMatrix(n, 1) - val grad = new DoubleMatrix(n, 1) - val x = new DoubleMatrix(n, 1) - val dir = new DoubleMatrix(n, 1) - val lastDir = new DoubleMatrix(n, 1) - val res = new DoubleMatrix(n, 1) - - def wipe() { - scratch.fill(0.0) - grad.fill(0.0) - x.fill(0.0) - dir.fill(0.0) - lastDir.fill(0.0) - res.fill(0.0) + val scratch = new Array[Double](n) + val grad = new Array[Double](n) + val x = new Array[Double](n) + val dir = new Array[Double](n) + val lastDir = new Array[Double](n) + val res = new Array[Double](n) + + def wipe(): Unit = { + ju.Arrays.fill(scratch, 0.0) + ju.Arrays.fill(grad, 0.0) + ju.Arrays.fill(x, 0.0) + ju.Arrays.fill(dir, 0.0) + ju.Arrays.fill(lastDir, 0.0) + ju.Arrays.fill(res, 0.0) } } @@ -60,18 +62,18 @@ private[spark] object NNLS { * direction, however, while this method only uses a conjugate gradient direction if the last * iteration did not cause a previously-inactive constraint to become active. */ - def solve(ata: DoubleMatrix, atb: DoubleMatrix, ws: Workspace): Array[Double] = { + def solve(ata: Array[Double], atb: Array[Double], ws: Workspace): Array[Double] = { ws.wipe() - val n = atb.rows + val n = atb.length val scratch = ws.scratch // find the optimal unconstrained step - def steplen(dir: DoubleMatrix, res: DoubleMatrix): Double = { - val top = SimpleBlas.dot(dir, res) - SimpleBlas.gemv(1.0, ata, dir, 0.0, scratch) + def steplen(dir: Array[Double], res: Array[Double]): Double = { + val top = blas.ddot(n, dir, 1, res, 1) + blas.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1) // Push the denominator upward very slightly to avoid infinities and silliness - top / (SimpleBlas.dot(scratch, dir) + 1e-20) + top / (blas.ddot(n, scratch, 1, dir, 1) + 1e-20) } // stopping condition @@ -89,59 +91,59 @@ private[spark] object NNLS { val dir = ws.dir val lastDir = ws.lastDir val res = ws.res - val iterMax = Math.max(400, 20 * n) + val iterMax = math.max(400, 20 * n) var lastNorm = 0.0 var iterno = 0 var lastWall = 0 // Last iteration when we hit a bound constraint. var i = 0 while (iterno < iterMax) { // find the residual - SimpleBlas.gemv(1.0, ata, x, 0.0, res) - SimpleBlas.axpy(-1.0, atb, res) - SimpleBlas.copy(res, grad) + blas.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1) + blas.daxpy(n, -1.0, atb, 1, res, 1) + blas.dcopy(n, res, 1, grad, 1) // project the gradient i = 0 while (i < n) { - if (grad.data(i) > 0.0 && x.data(i) == 0.0) { - grad.data(i) = 0.0 + if (grad(i) > 0.0 && x(i) == 0.0) { + grad(i) = 0.0 } i = i + 1 } - val ngrad = SimpleBlas.dot(grad, grad) + val ngrad = blas.ddot(n, grad, 1, grad, 1) - SimpleBlas.copy(grad, dir) + blas.dcopy(n, grad, 1, dir, 1) // use a CG direction under certain conditions var step = steplen(grad, res) var ndir = 0.0 - val nx = SimpleBlas.dot(x, x) + val nx = blas.ddot(n, x, 1, x, 1) if (iterno > lastWall + 1) { val alpha = ngrad / lastNorm - SimpleBlas.axpy(alpha, lastDir, dir) + blas.daxpy(n, alpha, lastDir, 1, dir, 1) val dstep = steplen(dir, res) - ndir = SimpleBlas.dot(dir, dir) + ndir = blas.ddot(n, dir, 1, dir, 1) if (stop(dstep, ndir, nx)) { // reject the CG step if it could lead to premature termination - SimpleBlas.copy(grad, dir) - ndir = SimpleBlas.dot(dir, dir) + blas.dcopy(n, grad, 1, dir, 1) + ndir = blas.ddot(n, dir, 1, dir, 1) } else { step = dstep } } else { - ndir = SimpleBlas.dot(dir, dir) + ndir = blas.ddot(n, dir, 1, dir, 1) } // terminate? if (stop(step, ndir, nx)) { - return x.data.clone + return x.clone } // don't run through the walls i = 0 while (i < n) { - if (step * dir.data(i) > x.data(i)) { - step = x.data(i) / dir.data(i) + if (step * dir(i) > x(i)) { + step = x(i) / dir(i) } i = i + 1 } @@ -149,19 +151,19 @@ private[spark] object NNLS { // take the step i = 0 while (i < n) { - if (step * dir.data(i) > x.data(i) * (1 - 1e-14)) { - x.data(i) = 0 + if (step * dir(i) > x(i) * (1 - 1e-14)) { + x(i) = 0 lastWall = iterno } else { - x.data(i) -= step * dir.data(i) + x(i) -= step * dir(i) } i = i + 1 } iterno = iterno + 1 - SimpleBlas.copy(dir, lastDir) + blas.dcopy(n, dir, 1, lastDir, 1) lastNorm = ngrad } - x.data.clone + x.clone } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 3ed3a5b9b384..9f463e0cafb6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -116,7 +116,8 @@ class L1Updater extends Updater { // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize var i = 0 - while (i < brzWeights.length) { + val len = brzWeights.length + while (i < len) { val wi = brzWeights(i) brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal) i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala new file mode 100644 index 000000000000..274ac7c99553 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml + +import java.io.{File, OutputStream, StringWriter} +import javax.xml.transform.stream.StreamResult + +import org.jpmml.model.JAXBUtil + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory + +/** + * :: DeveloperApi :: + * Export model to the PMML format + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +@DeveloperApi +@Since("1.4.0") +trait PMMLExportable { + + /** + * Export the model to the stream result in PMML format + */ + private def toPMML(streamResult: StreamResult): Unit = { + val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) + JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) + } + + /** + * :: Experimental :: + * Export the model to a local file in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(localPath: String): Unit = { + toPMML(new StreamResult(new File(localPath))) + } + + /** + * :: Experimental :: + * Export the model to a directory on a distributed file system in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(sc: SparkContext, path: String): Unit = { + val pmml = toPMML() + sc.parallelize(Array(pmml), 1).saveAsTextFile(path) + } + + /** + * :: Experimental :: + * Export the model to the OutputStream in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(outputStream: OutputStream): Unit = { + toPMML(new StreamResult(outputStream)) + } + + /** + * :: Experimental :: + * Export the model to a String in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(): String = { + val writer = new StringWriter + toPMML(new StreamResult(writer)) + writer.toString + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala new file mode 100644 index 000000000000..622b53a252ac --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel + */ +private[mllib] class BinaryClassificationPMMLModelExport( + model : GeneralizedLinearModel, + description : String, + normalizationMethod : RegressionNormalizationMethodType, + threshold: Double) + extends PMMLModelExport { + + populateBinaryClassificationPMML() + + /** + * Export the input LogisticRegressionModel or SVMModel to PMML format. + */ + private def populateBinaryClassificationPMML(): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + var interceptNO = threshold + if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { + if (threshold <= 0) { + interceptNO = Double.MinValue + } else if (threshold >= 1) { + interceptNO = Double.MaxValue + } else { + interceptNO = -math.log(1 / threshold - 1) + } + } + val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.CLASSIFICATION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withNormalizationMethod(normalizationMethod) + .withRegressionTables(regressionTableYES, regressionTableNO) + + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // add target field + val targetField = FieldName.create("target") + dataDictionary + .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala new file mode 100644 index 000000000000..1874786af000 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel abstract class + */ +private[mllib] class GeneralizedLinearPMMLModelExport( + model: GeneralizedLinearModel, + description: String) + extends PMMLModelExport { + + populateGeneralizedLinearPMML(model) + + /** + * Export the input GeneralizedLinearModel model to PMML format. + */ + private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTable = new RegressionTable(model.intercept) + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.REGRESSION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withRegressionTables(regressionTable) + + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // for completeness add target field + val targetField = FieldName.create("target") + dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala new file mode 100644 index 000000000000..069e7afc9fca --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.clustering.KMeansModel + +/** + * PMML Model Export for KMeansModel class + */ +private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ + + populateKMeansPMML(model) + + /** + * Export the input KMeansModel model to PMML format. + */ + private def populateKMeansPMML(model : KMeansModel): Unit = { + pmml.getHeader.setDescription("k-means clustering") + + if (model.clusterCenters.length > 0) { + val clusterCenter = model.clusterCenters(0) + val fields = new SArray[FieldName](clusterCenter.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val comparisonMeasure = new ComparisonMeasure() + .withKind(ComparisonMeasure.Kind.DISTANCE) + .withMeasure(new SquaredEuclidean()) + val clusteringModel = new ClusteringModel() + .withModelName("k-means") + .withMiningSchema(miningSchema) + .withComparisonMeasure(comparisonMeasure) + .withFunctionName(MiningFunctionType.CLUSTERING) + .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .withNumberOfClusters(model.clusterCenters.length) + + for (i <- 0 until clusterCenter.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + clusteringModel.withClusteringFields( + new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + } + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + for (i <- 0 until model.clusterCenters.length) { + val cluster = new Cluster() + .withName("cluster_" + i) + .withArray(new org.dmg.pmml.Array() + .withType(Array.Type.REAL) + .withN(clusterCenter.size) + .withValue(model.clusterCenters(i).toArray.mkString(" "))) + // we don't have the size of the single cluster but only the centroids (withValue) + // .withSize(value) + clusteringModel.withClusters(cluster) + } + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(clusteringModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala new file mode 100644 index 000000000000..c5fdecd3ca17 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import java.text.SimpleDateFormat +import java.util.Date + +import scala.beans.BeanProperty + +import org.dmg.pmml.{Application, Header, PMML, Timestamp} + +private[mllib] trait PMMLModelExport { + + /** + * Holder of the exported model in PMML format + */ + @BeanProperty + val pmml: PMML = new PMML + + setHeader(pmml) + + private def setHeader(pmml: PMML): Unit = { + val version = getClass.getPackage.getImplementationVersion + val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val timestamp = new Timestamp() + .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + val header = new Header() + .withApplication(app) + .withTimestamp(timestamp) + pmml.setHeader(header) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala new file mode 100644 index 000000000000..29bd689e1185 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionNormalizationMethodType + +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.regression.LassoModel +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.RidgeRegressionModel + +private[mllib] object PMMLModelExportFactory { + + /** + * Factory object to help creating the necessary PMMLModelExport implementation + * taking as input the machine learning model (for example KMeansModel). + */ + def createPMMLModelExport(model: Any): PMMLModelExport = { + model match { + case kmeans: KMeansModel => + new KMeansPMMLModelExport(kmeans) + case linear: LinearRegressionModel => + new GeneralizedLinearPMMLModelExport(linear, "linear regression") + case ridge: RidgeRegressionModel => + new GeneralizedLinearPMMLModelExport(ridge, "ridge regression") + case lasso: LassoModel => + new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") + case svm: SVMModel => + new BinaryClassificationPMMLModelExport( + svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm.getThreshold.getOrElse(0.0)) + case logistic: LogisticRegressionModel => + if (logistic.numClasses == 2) { + new BinaryClassificationPMMLModelExport( + logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT, + logistic.getThreshold.getOrElse(0.5)) + } else { + throw new IllegalArgumentException( + "PMML Export not supported for Multinomial Logistic Regression") + } + case _ => + throw new IllegalArgumentException( + "PMML Export not supported for model: " + model.getClass.getName) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 405bae62ee8b..a2d85a68cd32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution.{ExponentialDistribution, GammaDistribution, LogNormalDistribution, PoissonDistribution} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** @@ -28,17 +28,20 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} * Trait for random data generators that generate i.i.d. data. */ @DeveloperApi +@Since("1.1.0") trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** * Returns an i.i.d. sample as a generic type from an underlying distribution. */ + @Since("1.1.0") def nextValue(): T /** * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ + @Since("1.1.0") def copy(): RandomDataGenerator[T] } @@ -47,17 +50,21 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @DeveloperApi +@Since("1.1.0") class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextDouble() } - override def setSeed(seed: Long) = random.setSeed(seed) + @Since("1.1.0") + override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): UniformGenerator = new UniformGenerator() } @@ -66,17 +73,21 @@ class UniformGenerator extends RandomDataGenerator[Double] { * Generates i.i.d. samples from the standard normal distribution. */ @DeveloperApi +@Since("1.1.0") class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextGaussian() } - override def setSeed(seed: Long) = random.setSeed(seed) + @Since("1.1.0") + override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): StandardNormalGenerator = new StandardNormalGenerator() } @@ -87,16 +98,21 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { * @param mean mean for the Poisson distribution. */ @DeveloperApi -class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.1.0") +class PoissonGenerator @Since("1.1.0") ( + @Since("1.1.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new PoissonDistribution(mean) + @Since("1.1.0") override def nextValue(): Double = rng.sample() + @Since("1.1.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.1.0") override def copy(): PoissonGenerator = new PoissonGenerator(mean) } @@ -107,16 +123,21 @@ class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { * @param mean mean for the exponential distribution. */ @DeveloperApi -class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class ExponentialGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new ExponentialDistribution(mean) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): ExponentialGenerator = new ExponentialGenerator(mean) } @@ -128,16 +149,22 @@ class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] * @param scale scale for the gamma distribution */ @DeveloperApi -class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class GammaGenerator @Since("1.3.0") ( + @Since("1.3.0") val shape: Double, + @Since("1.3.0") val scale: Double) extends RandomDataGenerator[Double] { private val rng = new GammaDistribution(shape, scale) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): GammaGenerator = new GammaGenerator(shape, scale) } @@ -150,15 +177,21 @@ class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGen * @param std standard deviation for the log normal distribution */ @DeveloperApi -class LogNormalGenerator(val mean: Double, val std: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class LogNormalGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double, + @Since("1.3.0") val std: Double) extends RandomDataGenerator[Double] { private val rng = new LogNormalDistribution(mean, std) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 955c593a085d..4dd5ea214d67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} @@ -29,13 +29,14 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. + * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ @Experimental +@Since("1.1.0") object RandomRDDs { /** - * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. + * Generates an RDD comprised of `i.i.d.` samples from the uniform distribution `U(0.0, 1.0)`. * * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. @@ -44,20 +45,22 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformRDD( sc: SparkContext, size: Long, numPartitions: Int = 0, seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** * Java-friendly version of [[RandomRDDs#uniformRDD]]. */ + @Since("1.1.0") def uniformJavaRDD( jsc: JavaSparkContext, size: Long, @@ -69,6 +72,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default seed. */ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) } @@ -76,12 +80,13 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) } /** - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * Generates an RDD comprised of `i.i.d.` samples from the standard normal distribution. * * To transform the distribution in the generated RDD from standard normal to some other normal * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. @@ -90,8 +95,9 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ + @Since("1.1.0") def normalRDD( sc: SparkContext, size: Long, @@ -104,6 +110,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalRDD]]. */ + @Since("1.1.0") def normalJavaRDD( jsc: JavaSparkContext, size: Long, @@ -115,6 +122,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) } @@ -122,20 +130,23 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) } /** - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of `i.i.d.` samples from the Poisson distribution with the input + * mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonRDD( sc: SparkContext, mean: Double, @@ -149,6 +160,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonRDD]]. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -161,6 +173,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -172,12 +185,13 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) } /** - * Generates an RDD comprised of i.i.d. samples from the exponential distribution with + * Generates an RDD comprised of `i.i.d.` samples from the exponential distribution with * the input mean. * * @param sc SparkContext used to create the RDD. @@ -185,8 +199,9 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def exponentialRDD( sc: SparkContext, mean: Double, @@ -200,6 +215,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialRDD]]. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -212,6 +228,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -223,22 +240,24 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def exponentialJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(exponentialRDD(jsc.sc, mean, size)) } /** - * Generates an RDD comprised of i.i.d. samples from the gamma distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the gamma distribution with the input * shape and scale. * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution - * @param scale scale parameter (> 0) for the gamma distribution + * @param scale scale parameter (> 0) for the gamma distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def gammaRDD( sc: SparkContext, shape: Double, @@ -253,6 +272,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaRDD]]. */ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -266,6 +286,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default seed. */ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -278,26 +299,28 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaRDD( - jsc: JavaSparkContext, - shape: Double, - scale: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + shape: Double, + scale: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(gammaRDD(jsc.sc, shape, scale, size)) } /** - * Generates an RDD comprised of i.i.d. samples from the log normal distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the log normal distribution with the input * mean and standard deviation * * @param sc SparkContext used to create the RDD. * @param mean mean for the log normal distribution - * @param std standard deviation for the log normal distribution + * @param std standard deviation for the log normal distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def logNormalRDD( sc: SparkContext, mean: Double, @@ -312,6 +335,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalRDD]]. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -325,6 +349,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -337,27 +362,29 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( - jsc: JavaSparkContext, - mean: Double, - std: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + mean: Double, + std: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(logNormalRDD(jsc.sc, mean, std, size)) } /** * :: DeveloperApi :: - * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. * @param generator RandomDataGenerator used to populate the RDD. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomRDD[T: ClassTag]( sc: SparkContext, generator: RandomDataGenerator[T], @@ -370,7 +397,7 @@ object RandomRDDs { // TODO Generate RDD[Vector] from multivariate distributions. /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * uniform distribution on `U(0.0, 1.0)`. * * @param sc SparkContext used to create the RDD. @@ -380,6 +407,7 @@ object RandomRDDs { * @param seed Seed for the RNG that generates the seed for the generator in each partition. * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformVectorRDD( sc: SparkContext, numRows: Long, @@ -393,6 +421,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -405,6 +434,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -416,6 +446,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -424,7 +455,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -432,8 +463,9 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. */ + @Since("1.1.0") def normalVectorRDD( sc: SparkContext, numRows: Long, @@ -447,6 +479,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -459,6 +492,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -470,6 +504,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -478,7 +513,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from a + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from a * log normal distribution. * * @param sc SparkContext used to create the RDD. @@ -488,8 +523,9 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples. + * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ + @Since("1.3.0") def logNormalVectorRDD( sc: SparkContext, mean: Double, @@ -506,6 +542,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalVectorRDD]]. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -520,6 +557,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -534,6 +572,7 @@ object RandomRDDs { * [[RandomRDDs#logNormalJavaVectorRDD]] with the default number of partitions and * the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -544,7 +583,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -553,8 +592,9 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonVectorRDD( sc: SparkContext, mean: Double, @@ -569,6 +609,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -582,6 +623,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -594,6 +636,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -603,7 +646,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * exponential distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -612,8 +655,9 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def exponentialVectorRDD( sc: SparkContext, mean: Double, @@ -629,6 +673,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialVectorRDD]]. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -642,6 +687,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -655,6 +701,7 @@ object RandomRDDs { * [[RandomRDDs#exponentialJavaVectorRDD]] with the default number of partitions * and the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -665,18 +712,19 @@ object RandomRDDs { /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * gamma distribution with the input shape and scale. * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution. - * @param scale scale parameter (> 0) for the gamma distribution. + * @param scale scale parameter (> 0) for the gamma distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def gammaVectorRDD( sc: SparkContext, shape: Double, @@ -692,6 +740,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaVectorRDD]]. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -706,6 +755,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -719,6 +769,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -731,7 +782,7 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples produced by the * input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. @@ -740,9 +791,10 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomVectorRDD(sc: SparkContext, generator: RandomDataGenerator[Double], numRows: Long, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala new file mode 100644 index 000000000000..1b93e2d764c6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * Machine learning specific Pair RDD functions. + */ +@DeveloperApi +class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Serializable { + /** + * Returns the top k (largest) elements for each key from this RDD as defined by the specified + * implicit Ordering[T]. + * If the number of elements for a certain key is less than k, all of them will be returned. + * + * @param num k, the number of top elements to return + * @param ord the implicit ordering for T + * @return an RDD that contains the top k values for each key + */ + def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = { + self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))( + seqOp = (queue, item) => { + queue += item + }, + combOp = (queue1, queue2) => { + queue1 ++= queue2 + } + ).mapValues(_.toArray.sorted(ord.reverse)) // This is an min-heap, so we reverse the order. + } +} + +@DeveloperApi +object MLPairRDDFunctions { + /** Implicit conversion from a pair RDD to MLPairRDDFunctions. */ + implicit def fromPairRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): MLPairRDDFunctions[K, V] = + new MLPairRDDFunctions[K, V](rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 35e81fcb3de0..1facf83d806d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -72,7 +72,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int val w1 = windowSize - 1 // Get the first w1 items of each partition, starting from the second partition. val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true) + parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() var i = 0 var partitionIndex = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index f4f51f2ac521..b27ef1b949e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -18,17 +18,16 @@ package org.apache.spark.mllib.recommendation import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ -@Experimental +@Since("0.8.0") case class Rating(user: Int, product: Int, rating: Double) /** @@ -84,6 +83,9 @@ class ALS private ( private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** checkpoint interval */ + private var checkpointInterval: Int = 10 + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. @@ -135,10 +137,8 @@ class ALS private ( } /** - * :: Experimental :: * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ - @Experimental def setAlpha(alpha: Double): this.type = { this.alpha = alpha this @@ -176,7 +176,7 @@ class ALS private ( /** * :: DeveloperApi :: * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default - * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. + * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement, * at the cost of speed. */ @@ -186,6 +186,19 @@ class ALS private ( this } + /** + * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with + * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps + * with eliminating temporary shuffle files on disk, which can be important when there are many + * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]], + * this setting is ignored. + */ + @DeveloperApi + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -216,6 +229,7 @@ class ALS private ( nonnegative = nonnegative, intermediateRDDStorageLevel = intermediateRDDStorageLevel, finalRDDStorageLevel = StorageLevel.NONE, + checkpointInterval = checkpointInterval, seed = seed) val userFactors = floatUserFactors @@ -242,6 +256,7 @@ class ALS private ( /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. */ +@Since("0.8.0") object ALS { /** * Train a matrix factorization model given an RDD of ratings given by users to some products, @@ -257,6 +272,7 @@ object ALS { * @param blocks level of parallelism to split computation into * @param seed random seed */ + @Since("0.9.1") def train( ratings: RDD[Rating], rank: Int, @@ -281,6 +297,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into */ + @Since("0.8.0") def train( ratings: RDD[Rating], rank: Int, @@ -303,6 +320,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { train(ratings, rank, iterations, lambda, -1) @@ -319,6 +337,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { train(ratings, rank, iterations, 0.01, -1) @@ -339,6 +358,7 @@ object ALS { * @param alpha confidence parameter * @param seed random seed */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -365,6 +385,7 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -389,6 +410,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, lambda, -1, alpha) @@ -406,6 +428,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ed2f8b41bcae..ba4cfdcd9f1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,26 @@ package org.apache.spark.mllib.recommendation +import java.io.IOException import java.lang.{Integer => JavaInteger} -import org.jblas.DoubleMatrix +import scala.collection.mutable -import org.apache.spark.Logging +import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.storage.StorageLevel /** @@ -38,10 +51,12 @@ import org.apache.spark.storage.StorageLevel * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. */ +@Since("0.8.0") class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + val productFeatures: RDD[(Int, Array[Double])]) + extends Saveable with Serializable with Logging { require(rank > 0) validateFeatures("User", userFeatures) @@ -49,7 +64,7 @@ class MatrixFactorizationModel( /** Validates factors and warns users if there are performance concerns. */ private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = { - require(features.first()._2.size == rank, + require(features.first()._2.length == rank, s"$name feature dimension does not match the rank $rank.") if (features.partitioner.isEmpty) { logWarning(s"$name factor does not have a partitioner. " @@ -61,35 +76,78 @@ class MatrixFactorizationModel( } /** Predict the rating of one user for one product. */ + @Since("0.8.0") def predict(user: Int, product: Int): Double = { - val userVector = new DoubleMatrix(userFeatures.lookup(user).head) - val productVector = new DoubleMatrix(productFeatures.lookup(product).head) - userVector.dot(productVector) + val userVector = userFeatures.lookup(user).head + val productVector = productFeatures.lookup(product).head + blas.ddot(rank, userVector, 1, productVector, 1) } /** - * Predict the rating of many users for many products. - * The output RDD has an element per each element in the input RDD (including all duplicates) - * unless a user or product is missing in the training set. - * - * @param usersProducts RDD of (user, product) pairs. - * @return RDD of Ratings. - */ + * Return approximate numbers of users and products in the given usersProducts tuples. + * This method is based on `countApproxDistinct` in class `RDD`. + * + * @param usersProducts RDD of (user, product) pairs. + * @return approximate numbers of users and products. + */ + private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = { + val zeroCounterUser = new HyperLogLogPlus(4, 0) + val zeroCounterProduct = new HyperLogLogPlus(4, 0) + val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))( + (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => { + hllTuple._1.offer(v._1) + hllTuple._2.offer(v._2) + hllTuple + }, + (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => { + h1._1.addAll(h2._1) + h1._2.addAll(h2._2) + h1 + }) + (aggregated._1.cardinality(), aggregated._2.cardinality()) + } + + /** + * Predict the rating of many users for many products. + * The output RDD has an element per each element in the input RDD (including all duplicates) + * unless a user or product is missing in the training set. + * + * @param usersProducts RDD of (user, product) pairs. + * @return RDD of Ratings. + */ + @Since("0.9.0") def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { - val users = userFeatures.join(usersProducts).map{ - case (user, (uFeatures, product)) => (product, (user, uFeatures)) - } - users.join(productFeatures).map { - case (product, ((user, uFeatures), pFeatures)) => - val userVector = new DoubleMatrix(uFeatures) - val productVector = new DoubleMatrix(pFeatures) - Rating(user, product, userVector.dot(productVector)) + // Previously the partitions of ratings are only based on the given products. + // So if the usersProducts given for prediction contains only few products or + // even one product, the generated ratings will be pushed into few or single partition + // and can't use high parallelism. + // Here we calculate approximate numbers of users and products. Then we decide the + // partitions should be based on users or products. + val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts) + + if (usersCount < productsCount) { + val users = userFeatures.join(usersProducts).map { + case (user, (uFeatures, product)) => (product, (user, uFeatures)) + } + users.join(productFeatures).map { + case (product, ((user, uFeatures), pFeatures)) => + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + } + } else { + val products = productFeatures.join(usersProducts.map(_.swap)).map { + case (product, (pFeatures, user)) => (user, (product, pFeatures)) + } + products.join(userFeatures).map { + case (user, ((product, pFeatures), uFeatures)) => + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + } } } /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. */ + @Since("1.2.0") def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() } @@ -105,8 +163,9 @@ class MatrixFactorizationModel( * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. */ + @Since("1.1.0") def recommendProducts(user: Int, num: Int): Array[Rating] = - recommend(userFeatures.lookup(user).head, productFeatures, num) + MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) /** @@ -121,18 +180,213 @@ class MatrixFactorizationModel( * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. */ + @Since("1.1.0") def recommendUsers(product: Int, num: Int): Array[Rating] = - recommend(productFeatures.lookup(product).head, userFeatures, num) + MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + protected override val formatVersion: String = "1.0" + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + MatrixFactorizationModel.SaveLoadV1_0.save(this, path) + } + + /** + * Recommends topK products for all users. + * + * @param num how many products to return for every user. + * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of + * rating objects which contains the same userId, recommended productID and a "score" in the + * rating field. Semantics of score is same as recommendProducts API + */ + @Since("1.4.0") + def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { + MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { + case (user, top) => + val ratings = top.map { case (product, rating) => Rating(user, product, rating) } + (user, ratings) + } + } + + + /** + * Recommends topK users for all products. + * + * @param num how many users to return for every product. + * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array + * of rating objects which contains the recommended userId, same productID and a "score" in the + * rating field. Semantics of score is same as recommendUsers API + */ + @Since("1.4.0") + def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { + MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { + case (product, top) => + val ratings = top.map { case (user, rating) => Rating(user, product, rating) } + (product, ratings) + } + } +} + +@Since("1.3.0") +object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { + + import org.apache.spark.mllib.util.Loader._ + + /** + * Makes recommendations for a single user (or product). + */ private def recommend( recommendToFeatures: Array[Double], recommendableFeatures: RDD[(Int, Array[Double])], num: Int): Array[(Int, Double)] = { - val recommendToVector = new DoubleMatrix(recommendToFeatures) - val scored = recommendableFeatures.map { case (id,features) => - (id, recommendToVector.dot(new DoubleMatrix(features))) + val scored = recommendableFeatures.map { case (id, features) => + (id, blas.ddot(features.length, recommendToFeatures, 1, features, 1)) } scored.top(num)(Ordering.by(_._2)) } + + /** + * Makes recommendations for all users (or products). + * @param rank rank + * @param srcFeatures src features to receive recommendations + * @param dstFeatures dst features used to make recommendations + * @param num number of recommendations for each record + * @return an RDD of (srcId: Int, recommendations), where recommendations are stored as an array + * of (dstId, rating) pairs. + */ + private def recommendForAll( + rank: Int, + srcFeatures: RDD[(Int, Array[Double])], + dstFeatures: RDD[(Int, Array[Double])], + num: Int): RDD[(Int, Array[(Int, Double)])] = { + val srcBlocks = blockify(rank, srcFeatures) + val dstBlocks = blockify(rank, dstFeatures) + val ratings = srcBlocks.cartesian(dstBlocks).flatMap { + case ((srcIds, srcFactors), (dstIds, dstFactors)) => + val m = srcIds.length + val n = dstIds.length + val ratings = srcFactors.transpose.multiply(dstFactors) + val output = new Array[(Int, (Int, Double))](m * n) + var k = 0 + ratings.foreachActive { (i, j, r) => + output(k) = (srcIds(i), (dstIds(j), r)) + k += 1 + } + output.toSeq + } + ratings.topByKey(num)(Ordering.by(_._2)) + } + + /** + * Blockifies features to use Level-3 BLAS. + */ + private def blockify( + rank: Int, + features: RDD[(Int, Array[Double])]): RDD[(Array[Int], DenseMatrix)] = { + val blockSize = 4096 // TODO: tune the block size + val blockStorage = rank * blockSize + features.mapPartitions { iter => + iter.grouped(blockSize).map { grouped => + val ids = mutable.ArrayBuilder.make[Int] + ids.sizeHint(blockSize) + val factors = mutable.ArrayBuilder.make[Double] + factors.sizeHint(blockStorage) + var i = 0 + grouped.foreach { case (id, factor) => + ids += id + factors ++= factor + i += 1 + } + (ids.result(), new DenseMatrix(rank, i, factors.result())) + } + } + } + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") + override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => + throw new IOException("MatrixFactorizationModel.load did not recognize model with" + + s"(class: $loadedClassName, version: $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private[recommendation] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[recommendation] + val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + + /** + * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. + */ + def save(model: MatrixFactorizationModel, path: String): Unit = { + val sc = model.userFeatures.sparkContext + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + model.userFeatures.toDF("id", "features").write.parquet(userPath(path)) + model.productFeatures.toDF("id", "features").write.parquet(productPath(path)) + } + + def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rank = (metadata \ "rank").extract[Int] + val userFeatures = sqlContext.read.parquet(userPath(path)) + .map { case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + val productFeatures = sqlContext.read.parquet(productPath(path)) + .map { case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + new MatrixFactorizationModel(rank, userFeatures, productFeatures) + } + + private def userPath(path: String): String = { + new Path(dataPath(path), "user").toUri.toString + } + + private def productPath(path: String): String = { + new Path(dataPath(path), "product").toUri.toString + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 17de215b97f9..7e3b4d5648fe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD @@ -34,9 +34,13 @@ import org.apache.spark.storage.StorageLevel * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ +@Since("0.8.0") @DeveloperApi -abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) +abstract class GeneralizedLinearModel @Since("1.0.0") ( + @Since("1.0.0") val weights: Vector, + @Since("0.8.0") val intercept: Double) extends Serializable { /** @@ -53,7 +57,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. @@ -71,26 +77,39 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * * @param testData array representing a single data point * @return Double prediction from the trained model + * */ + @Since("1.0.0") def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } - override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) + /** + * Print a summary of the model. + */ + override def toString: String = { + s"${this.getClass.getName}: intercept = ${intercept}, numFeatures = ${weights.size}" + } } /** * :: DeveloperApi :: * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM). * This class should be extended with an Optimizer to create a new GLM. + * */ +@Since("0.8.0") @DeveloperApi abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] extends Logging with Serializable { protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() - /** The optimizer to solve the problem. */ + /** + * The optimizer to solve the problem. + * + */ + @Since("0.8.0") def optimizer: Optimizer /** Whether to add intercept (default: false). */ @@ -123,10 +142,17 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ private var useFeatureScaling = false + /** + * The dimension of training features. + * + */ + @Since("1.4.0") + def getNumFeatures: Int = this.numFeatures + /** * The dimension of training features. */ - protected var numFeatures: Int = 0 + protected var numFeatures: Int = -1 /** * Set if the algorithm should use feature scaling to improve the convergence during optimization. @@ -141,10 +167,19 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ protected def createModel(weights: Vector, intercept: Double): M + /** + * Get if the algorithm uses addIntercept + * + */ + @Since("1.4.0") + def isAddIntercept: Boolean = this.addIntercept + /** * Set if the algorithm should add an intercept. Default false. * We set the default to false because adding the intercept will cause memory allocation. + * */ + @Since("0.8.0") def setIntercept(addIntercept: Boolean): this.type = { this.addIntercept = addIntercept this @@ -152,7 +187,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Set if the algorithm should validate data before training. Default true. + * */ + @Since("0.8.0") def setValidateData(validateData: Boolean): this.type = { this.validateData = validateData this @@ -161,9 +198,13 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. + * */ + @Since("0.8.0") def run(input: RDD[LabeledPoint]): M = { - numFeatures = input.first().features.size + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } /** * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights, @@ -178,11 +219,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ val initialWeights = { if (numOfLinearPredictor == 1) { - Vectors.dense(new Array[Double](numFeatures)) + Vectors.zeros(numFeatures) } else if (addIntercept) { - Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) } else { - Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + Vectors.zeros(numFeatures * numOfLinearPredictor) } } run(input, initialWeights) @@ -191,9 +232,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. + * */ + @Since("1.0.0") def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - numFeatures = input.first().features.size + + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" @@ -225,26 +271,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Currently, it's only enabled in LogisticRegressionWithLBFGS */ val scaler = if (useFeatureScaling) { - (new StandardScaler(withStd = true, withMean = false)).fit(input.map(x => x.features)) + new StandardScaler(withStd = true, withMean = false).fit(input.map(_.features)) } else { null } // Prepend an extra variable consisting of all 1.0's for the intercept. - val data = if (addIntercept) { - if (useFeatureScaling) { - input.map(labeledPoint => - (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) - } else { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) - } - } else { - if (useFeatureScaling) { - input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + // TODO: Apply feature scaling to the weight vector instead of input data. + val data = + if (addIntercept) { + if (useFeatureScaling) { + input.map(lp => (lp.label, appendBias(scaler.transform(lp.features)))).cache() + } else { + input.map(lp => (lp.label, appendBias(lp.features))).cache() + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(lp => (lp.label, scaler.transform(lp.features))).cache() + } else { + input.map(lp => (lp.label, lp.features)) + } } - } /** * TODO: For better convergence, in logistic regression, the intercepts should be computed diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 5ed6477bae3b..877d31ba4130 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -21,12 +21,24 @@ import java.io.Serializable import java.lang.{Double => JDouble} import java.util.Arrays.binarySearch +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext /** + * :: Experimental :: + * * Regression model for isotonic regression. * * @param boundaries Array of boundaries for which predictions are known. @@ -34,11 +46,14 @@ import org.apache.spark.rdd.RDD * @param predictions Array of predictions associated to the boundaries at the same index. * Results of isotonic regression and therefore monotone. * @param isotonic indicates whether this is isotonic or antitonic. + * */ -class IsotonicRegressionModel ( - val boundaries: Array[Double], - val predictions: Array[Double], - val isotonic: Boolean) extends Serializable { +@Since("1.3.0") +@Experimental +class IsotonicRegressionModel @Since("1.3.0") ( + @Since("1.3.0") val boundaries: Array[Double], + @Since("1.3.0") val predictions: Array[Double], + @Since("1.3.0") val isotonic: Boolean) extends Serializable with Saveable { private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse @@ -46,10 +61,21 @@ class IsotonicRegressionModel ( assertOrdered(boundaries) assertOrdered(predictions)(predictionOrd) + /** + * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. + */ + @Since("1.4.0") + def this(boundaries: java.lang.Iterable[Double], + predictions: java.lang.Iterable[Double], + isotonic: java.lang.Boolean) = { + this(boundaries.asScala.toArray, predictions.asScala.toArray, isotonic) + } + /** Asserts the input array is monotone with the given ordering. */ private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = { var i = 1 - while (i < xs.length) { + val len = xs.length + while (i < len) { require(ord.compare(xs(i - 1), xs(i)) <= 0, s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.") i += 1 @@ -62,7 +88,9 @@ class IsotonicRegressionModel ( * * @param testData Features to be labeled. * @return Predicted labels. + * */ + @Since("1.3.0") def predict(testData: RDD[Double]): RDD[Double] = { testData.map(predict) } @@ -73,7 +101,9 @@ class IsotonicRegressionModel ( * * @param testData Features to be labeled. * @return Predicted labels. + * */ + @Since("1.3.0") def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) } @@ -93,7 +123,9 @@ class IsotonicRegressionModel ( * 3) If testData falls between two values in boundary array then prediction is treated * as piecewise linear function and interpolated value is returned. In case there are * multiple values with the same boundary then the same rules as in 2) are used. + * */ + @Since("1.3.0") def predict(testData: Double): Double = { def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { @@ -120,9 +152,89 @@ class IsotonicRegressionModel ( predictions(foundIndex) } } + + /** A convenient method for boundaries called by the Python API. */ + private[mllib] def boundaryVector: Vector = Vectors.dense(boundaries) + + /** A convenient method for boundaries called by the Python API. */ + private[mllib] def predictionVector: Vector = Vectors.dense(predictions) + + @Since("1.4.0") + override def save(sc: SparkContext, path: String): Unit = { + IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("1.4.0") +object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { + + import org.apache.spark.mllib.util.Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel" + + /** Model data for model import/export */ + case class Data(boundary: Double, prediction: Double) + + def save( + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean): Unit = { + val sqlContext = new SQLContext(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("isotonic" -> isotonic))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + sqlContext.createDataFrame( + boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } + ).write.parquet(dataPath(path)) + } + + def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.read.parquet(dataPath(path)) + + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("boundary", "prediction").collect() + val (boundaries, predictions) = dataArray.map { x => + (x.getDouble(0), x.getDouble(1)) + }.toList.sortBy(_._1).unzip + (boundaries.toArray, predictions.toArray) + } + } + + @Since("1.4.0") + override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val isotonic = (metadata \ "isotonic").extract[Boolean] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (boundaries, predictions) = SaveLoadV1_0.load(sc, path) + new IsotonicRegressionModel(boundaries, predictions, isotonic) + case _ => throw new Exception( + s"IsotonicRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)" + ) + } + } } /** + * :: Experimental :: + * * Isotonic regression. * Currently implemented using parallelized pool adjacent violators algorithm. * Only univariate (single feature) algorithm supported. @@ -130,14 +242,18 @@ class IsotonicRegressionModel ( * Sequential PAV implementation based on: * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. - * Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf + * Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] * * Sequential PAV parallelization based on: * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. * "An approach to parallelizing isotonic regression." * Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. - * Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf + * Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] + * + * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] */ +@Experimental +@Since("1.3.0") class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { /** @@ -145,6 +261,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * * @return New instance of IsotonicRegression. */ + @Since("1.3.0") def this() = this(true) /** @@ -153,6 +270,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. * @return This instance of IsotonicRegression. */ + @Since("1.3.0") def setIsotonic(isotonic: Boolean): this.type = { this.isotonic = isotonic this @@ -168,6 +286,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { val preprocessedInput = if (isotonic) { input @@ -193,6 +312,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) } @@ -229,11 +349,12 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali } var i = 0 - while (i < input.length) { + val len = input.length + while (i < len) { var j = i // Find monotonicity violating sequence, if any. - while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) { + while (j < len - 1 && input(j)._1 > input(j + 1)._1) { j = j + 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 2067b36f246b..c284ad232537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -29,21 +30,28 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. */ +@Since("0.8.0") @BeanInfo -case class LabeledPoint(label: Double, features: Vector) { +case class LabeledPoint @Since("1.0.0") ( + @Since("0.8.0") label: Double, + @Since("1.0.0") features: Vector) { override def toString: String = { - "(%s,%s)".format(label, features) + s"($label,$features)" } } /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. + * */ +@Since("1.1.0") object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. + * */ + @Since("1.1.0") def parse(s: String): LabeledPoint = { if (s.startsWith("(")) { NumericParser.parse(s) match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 8ecd5c6ad93c..a9aba173fa0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,9 +17,13 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -27,12 +31,14 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ -class LassoModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class LassoModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, @@ -40,16 +46,45 @@ class LassoModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("1.3.0") +object LassoModel extends Loader[LassoModel] { + + @Since("1.3.0") + override def load(sc: SparkContext, path: String): LassoModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LassoModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LassoModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LassoWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -59,6 +94,7 @@ class LassoWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new L1Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -69,6 +105,7 @@ class LassoWithSGD private ( * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -78,7 +115,9 @@ class LassoWithSGD private ( /** * Top-level methods for calling Lasso. + * */ +@Since("0.8.0") object LassoWithSGD { /** @@ -95,7 +134,9 @@ object LassoWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -118,7 +159,9 @@ object LassoWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -139,7 +182,9 @@ object LassoWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -157,7 +202,9 @@ object LassoWithSGD { * matrix A as well as the corresponding right hand side label y * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377ff..4996ace5df85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,20 +17,28 @@ package org.apache.spark.mllib.regression -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ -class LinearRegressionModel ( - override val weights: Vector, - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { +@Since("0.8.0") +class LinearRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable + with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, @@ -38,17 +46,46 @@ class LinearRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("1.3.0") +object LinearRegressionModel extends Loader[LinearRegressionModel] { + + @Since("1.3.0") + override def load(sc: SparkContext, path: String): LinearRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LinearRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LinearRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a linear regression model with no regularization using Stochastic Gradient Descent. * This solves the least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + * f(weights) = 1/n ||A weights-y||^2^ * (which is the mean squared error). * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -57,6 +94,7 @@ class LinearRegressionWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new SimpleUpdater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -66,6 +104,7 @@ class LinearRegressionWithSGD private[mllib] ( * Construct a LinearRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -75,7 +114,9 @@ class LinearRegressionWithSGD private[mllib] ( /** * Top-level methods for calling LinearRegression. + * */ +@Since("0.8.0") object LinearRegressionWithSGD { /** @@ -91,7 +132,9 @@ object LinearRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -112,7 +155,9 @@ object LinearRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -131,7 +176,9 @@ object LinearRegressionWithSGD { * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -148,7 +195,9 @@ object LinearRegressionWithSGD { * matrix A as well as the corresponding right hand side label y * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LinearRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e7a..0e72d6591ce8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,11 +17,14 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.json4s.{DefaultFormats, JValue} + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +@Since("0.8.0") @Experimental trait RegressionModel extends Serializable { /** @@ -29,7 +32,9 @@ trait RegressionModel extends Serializable { * * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -37,14 +42,30 @@ trait RegressionModel extends Serializable { * * @param testData array representing a single data point * @return Double prediction from the trained model + * */ + @Since("1.0.0") def predict(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object RegressionModel { + + /** + * Helper method for loading GLM regression model metadata. + * @return numFeatures + */ + def getNumFeatures(metadata: JValue): Int = { + implicit val formats = DefaultFormats + (metadata \ "numFeatures").extract[Int] + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 076ba35051c9..0a44ff559d55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,22 +17,29 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.optimization._ +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD + /** * Regression model trained using RidgeRegression. * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ -class RidgeRegressionModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class RidgeRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, @@ -40,16 +47,45 @@ class RidgeRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("1.3.0") +object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + + @Since("1.3.0") + override def load(sc: SparkContext, path: String): RidgeRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new RidgeRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"RidgeRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. - * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + * This solves the l2-regularized least squares regression formulation + * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class RidgeRegressionWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -59,7 +95,7 @@ class RidgeRegressionWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new SquaredL2Updater() - + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -70,6 +106,7 @@ class RidgeRegressionWithSGD private ( * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -79,7 +116,9 @@ class RidgeRegressionWithSGD private ( /** * Top-level methods for calling RidgeRegression. + * */ +@Since("0.8.0") object RidgeRegressionWithSGD { /** @@ -95,7 +134,9 @@ object RidgeRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -117,7 +158,9 @@ object RidgeRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -137,13 +180,15 @@ object RidgeRegressionWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 0.01) + train(input, numIterations, stepSize, regParam, 1.0) } /** @@ -154,7 +199,9 @@ object RidgeRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 44a8dbb994cf..73948b2d9851 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -20,8 +20,10 @@ package org.apache.spark.mllib.regression import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream /** @@ -39,31 +41,37 @@ import org.apache.spark.streaming.dstream.DStream * * For example usage, see `StreamingLinearRegressionWithSGD`. * - * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn + * NOTE: In some use cases, the order in which trainOn and predictOn * are called in an application will affect the results. When called on * the same DStream, if trainOn is called before predictOn, when new data * arrive the model will update and the prediction will be based on the new * model. Whereas if predictOn is called first, the prediction will use the model * from the previous update. * - * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this + * NOTE: It is ok to call predictOn repeatedly on multiple streams; this * will generate predictions for each one all using the current model. * It is also ok to call trainOn on different streams; this will update * the model using each of the different sources, in sequence. * + * */ +@Since("1.1.0") @DeveloperApi abstract class StreamingLinearAlgorithm[ M <: GeneralizedLinearModel, A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = None + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A - /** Return the latest model. */ + /** + * Return the latest model. + * + */ + @Since("1.1.0") def latestModel(): M = { model.get } @@ -76,40 +84,52 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing labeled data */ - def trainOn(data: DStream[LabeledPoint]) { + @Since("1.1.0") + def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - val initialWeights = - model match { - case Some(m) => - m.weights - case None => - val numFeatures = rdd.first().features.size - Vectors.dense(numFeatures) + if (!rdd.isEmpty) { + model = Some(algorithm.run(rdd, model.get.weights)) + logInfo(s"Model updated at time ${time.toString}") + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") } - model = Some(algorithm.run(rdd, initialWeights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + logInfo(s"Current model: weights, ${display}") } - logInfo("Current model: weights, %s".format (display)) } } + /** + * Java-friendly version of `trainOn`. + */ + @Since("1.3.0") + def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) + /** * Use the model to make predictions on batches of data from a DStream * * @param data DStream containing feature vectors * @return DStream containing predictions + * */ + @Since("1.1.0") def predictOn(data: DStream[Vector]): DStream[Double] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.get.predict) + data.map{x => model.get.predict(x)} + } + + /** + * Java-friendly version of `predictOn`. + * + */ + @Since("1.3.0") + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } /** @@ -117,11 +137,25 @@ abstract class StreamingLinearAlgorithm[ * @param data DStream containing feature vectors * @tparam K key type * @return DStream containing the input keys and the predictions as values + * */ + @Since("1.1.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.get.predict) + data.mapValues{x => model.get.predict(x)} + } + + + /** + * Java-friendly version of `predictOnValues`. + * + */ + @Since("1.3.0") + def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Double)]]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e5e6301127a2..fe1d487cdd07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.Vector /** @@ -39,9 +39,9 @@ import org.apache.spark.mllib.linalg.Vector * .setNumIterations(10) * .setInitialWeights(Vectors.dense(...)) * .trainOn(DStream) - * */ @Experimental +@Since("1.1.0") class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -55,32 +55,56 @@ class StreamingLinearRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.1.0") def this() = this(0.1, 50, 1.0) + @Since("1.1.0") val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) - /** Set the step size for gradient descent. Default: 0.1. */ + protected var model: Option[LinearRegressionModel] = None + + /** + * Set the step size for gradient descent. Default: 0.1. + */ + @Since("1.1.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } - /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + /** + * Set the number of iterations of gradient descent to run per update. Default: 50. + */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } - /** Set the fraction of each batch to use for updates. Default: 1.0. */ + /** + * Set the fraction of each batch to use for updates. Default: 1.0. + */ + @Since("1.1.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } - /** Set the initial weights. Default: [0.0, 0.0]. */ + /** + * Set the initial weights. + */ + @Since("1.1.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } + /** + * Set the convergence tolerance. Default: 0.001. + */ + @Since("1.5.0") + def setConvergenceTol(tolerance: Double): this.type = { + this.algorithm.optimizer.setConvergenceTol(tolerance) + this + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala new file mode 100644 index 000000000000..317d3a570263 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * Helper methods for import/export of GLM regression models. + */ +private[regression] object GLMRegressionModel { + + object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) + + /** + * Helper method for saving GLM regression model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> weights.size))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.write.parquet(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM regression model data. + * @param modelClass String name for model class (used for error messages) + * @param numFeatures Number of features, to be checked against loaded data. + * The length of the weights vector should equal numFeatures. + */ + def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.read.parquet(datapath) + val dataArray = dataRDD.select("weights", "intercept").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") + data match { + case Row(weights: Vector, intercept: Double) => + assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + + s" found ${weights.size} features when loading $modelClass weights from $datapath") + Data(weights, intercept) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala new file mode 100644 index 000000000000..4a856f7f3434 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Kernel density estimation. Given a sample from a population, estimate its probability density + * function at each of the given evaluation points using kernels. Only Gaussian kernel is supported. + * + * Scala example: + * + * {{{ + * val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0)) + * val kd = new KernelDensity() + * .setSample(sample) + * .setBandwidth(3.0) + * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) + * }}} + */ +@Since("1.4.0") +@Experimental +class KernelDensity extends Serializable { + + import KernelDensity._ + + /** Bandwidth of the kernel function. */ + private var bandwidth: Double = 1.0 + + /** A sample from a population. */ + private var sample: RDD[Double] = _ + + /** + * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). + */ + @Since("1.4.0") + def setBandwidth(bandwidth: Double): this.type = { + require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") + this.bandwidth = bandwidth + this + } + + /** + * Sets the sample to use for density estimation. + */ + @Since("1.4.0") + def setSample(sample: RDD[Double]): this.type = { + this.sample = sample + this + } + + /** + * Sets the sample to use for density estimation (for Java users). + */ + @Since("1.4.0") + def setSample(sample: JavaRDD[java.lang.Double]): this.type = { + this.sample = sample.rdd.asInstanceOf[RDD[Double]] + this + } + + /** + * Estimates probability density function at the given array of points. + */ + @Since("1.4.0") + def estimate(points: Array[Double]): Array[Double] = { + val sample = this.sample + val bandwidth = this.bandwidth + + require(sample != null, "Must set sample before calling estimate.") + + val n = points.length + // This gets used in each Gaussian PDF computation, so compute it up front + val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi) + val (densities, count) = sample.aggregate((new Array[Double](n), 0L))( + (x, y) => { + var i = 0 + while (i < n) { + x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i)) + i += 1 + } + (x._1, x._2 + 1) + }, + (x, y) => { + blas.daxpy(n, 1.0, y._1, 1, x._1, 1) + (x._1, x._2 + y._2) + }) + blas.dscal(n, 1.0 / count, densities, 1) + densities + } +} + +private object KernelDensity { + + /** Evaluates the PDF of a normal distribution. */ + def normPdf( + mean: Double, + standardDeviation: Double, + logStandardDeviationPlusHalfLog2Pi: Double, + x: Double): Double = { + val x0 = x - mean + val x1 = x0 / standardDeviation + val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi + math.exp(logDensity) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fcc2a148791b..51b713e263e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector} /** @@ -34,6 +34,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. */ +@Since("1.1.0") @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -53,6 +54,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. */ + @Since("1.1.0") def add(sample: Vector): this.type = { if (n == 0) { require(sample.size > 0, s"Vector should have dimension larger than zero.") @@ -70,23 +72,30 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") + val localCurrMean = currMean + val localCurrM2n = currM2n + val localCurrM2 = currM2 + val localCurrL1 = currL1 + val localNnz = nnz + val localCurrMax = currMax + val localCurrMin = currMin sample.foreachActive { (index, value) => if (value != 0.0) { - if (currMax(index) < value) { - currMax(index) = value + if (localCurrMax(index) < value) { + localCurrMax(index) = value } - if (currMin(index) > value) { - currMin(index) = value + if (localCurrMin(index) > value) { + localCurrMin(index) = value } - val prevMean = currMean(index) + val prevMean = localCurrMean(index) val diff = value - prevMean - currMean(index) = prevMean + diff / (nnz(index) + 1.0) - currM2n(index) += (value - currMean(index)) * diff - currM2(index) += value * value - currL1(index) += math.abs(value) + localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) + localCurrM2n(index) += (value - localCurrMean(index)) * diff + localCurrM2(index) += value * value + localCurrL1(index) += math.abs(value) - nnz(index) += 1.0 + localNnz(index) += 1.0 } } @@ -101,6 +110,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. */ + @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + @@ -130,18 +140,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.clone - this.currM2n = other.currM2n.clone - this.currM2 = other.currM2.clone - this.currL1 = other.currL1.clone + this.currMean = other.currMean.clone() + this.currM2n = other.currM2n.clone() + this.currM2 = other.currM2.clone() + this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt - this.nnz = other.nnz.clone - this.currMax = other.currMax.clone - this.currMin = other.currMin.clone + this.nnz = other.nnz.clone() + this.currMax = other.currMax.clone() + this.currMin = other.currMin.clone() } this } + /** + * Sample mean of each dimension. + * + */ + @Since("1.1.0") override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -154,6 +169,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMean) } + /** + * Sample variance of each dimension. + * + */ + @Since("1.1.0") override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -165,7 +185,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if (denominator > 0.0) { val deltaMean = currMean var i = 0 - while (i < currM2n.size) { + val len = currM2n.length + while (i < len) { realVariance(i) = currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt realVariance(i) /= denominator @@ -175,14 +196,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realVariance) } + /** + * Sample size. + * + */ + @Since("1.1.0") override def count: Long = totalCnt + /** + * Number of nonzero elements in each dimension. + * + */ + @Since("1.1.0") override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } + /** + * Maximum value of each dimension. + * + */ + @Since("1.1.0") override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -194,6 +230,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMax) } + /** + * Minimum value of each dimension. + * + */ + @Since("1.1.0") override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -205,19 +246,30 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMin) } + /** + * L2 (Euclidian) norm of each dimension. + * + */ + @Since("1.2.0") override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) var i = 0 - while (i < currM2.size) { + val len = currM2.length + while (i < len) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } Vectors.dense(realMagnitude) } + /** + * L1 norm of each dimension. + * + */ + @Since("1.2.0") override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 6a364c93284a..39a16fb743d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -17,50 +17,60 @@ package org.apache.spark.mllib.stat +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. */ +@Since("1.0.0") trait MultivariateStatisticalSummary { /** * Sample mean vector. */ + @Since("1.0.0") def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. */ + @Since("1.0.0") def variance: Vector /** * Sample size. */ + @Since("1.0.0") def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. */ + @Since("1.0.0") def numNonzeros: Vector /** * Maximum value of each column. */ + @Since("1.0.0") def max: Vector /** * Minimum value of each column. */ + @Since("1.0.0") def min: Vector /** * Euclidean magnitude of each column */ + @Since("1.2.0") def normL2: Vector /** * L1 norm of each column */ + @Since("1.2.0") def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 3cf4e807b4cf..84d64a5bfb38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -17,45 +17,48 @@ package org.apache.spark.mllib.stat -import org.apache.spark.annotation.Experimental +import scala.annotation.varargs + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations -import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest, + KolmogorovSmirnovTestResult} import org.apache.spark.rdd.RDD /** + * :: Experimental :: * API for statistical functions in MLlib. */ +@Since("1.1.0") @Experimental object Statistics { /** - * :: Experimental :: * Computes column-wise summary statistics for the input RDD[Vector]. * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. */ - @Experimental + @Since("1.1.0") def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } /** - * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ - @Experimental + @Since("1.1.0") def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** - * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -69,11 +72,10 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. */ - @Experimental + @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** - * :: Experimental :: * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * @@ -84,11 +86,17 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ - @Experimental + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** - * :: Experimental :: + * Java-friendly version of [[corr()]] + */ + @Since("1.4.1") + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) + + /** * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -99,14 +107,20 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` - *@return A Double containing the correlation between the two input RDD[Double]s using the + * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ - @Experimental + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** - * :: Experimental :: + * Java-friendly version of [[corr()]] + */ + @Since("1.4.1") + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) + + /** * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * @@ -120,13 +134,12 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental + @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * @@ -136,11 +149,10 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental + @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** - * :: Experimental :: * Conduct Pearson's independence test on the input contingency matrix, which cannot contain * negative entries or columns or rows that sum up to 0. * @@ -148,11 +160,10 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental + @Since("1.1.0") def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** - * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. * For each feature, the (feature, label) pairs are converted into a contingency matrix for which * the chi-squared statistic is computed. All label and feature values must be categorical. @@ -162,8 +173,59 @@ object Statistics { * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. */ - @Experimental + @Since("1.1.0") def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } + + /** Java-friendly version of [[chiSqTest()]] */ + @Since("1.5.0") + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) + + /** + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * For more information on KS Test: + * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]] + * + * @param data an `RDD[Double]` containing the sample of data to test + * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value + * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test + * statistic, p-value, and null hypothesis. + */ + @Since("1.5.0") + def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double) + : KolmogorovSmirnovTestResult = { + KolmogorovSmirnovTest.testOneSample(data, cdf) + } + + /** + * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability + * distribution equality. Currently supports the normal distribution, taking as parameters + * the mean and standard deviation. + * (distName = "norm") + * @param data an `RDD[Double]` containing the sample of data to test + * @param distName a `String` name for a theoretical distribution + * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test + * statistic, p-value, and null hypothesis. + */ + @Since("1.5.0") + @varargs + def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*) + : KolmogorovSmirnovTestResult = { + KolmogorovSmirnovTest.testOneSample(data, distName, params: _*) + } + + /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @Since("1.5.0") + @varargs + def kolmogorovSmirnovTest( + data: JavaDoubleRDD, + distName: String, + params: Double*): KolmogorovSmirnovTestResult = { + kolmogorovSmirnovTest(data.rdd.asInstanceOf[RDD[Double]], distName, params: _*) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala index 23b291eee070..8a821d1b23ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -101,7 +101,7 @@ private[stat] object PearsonCorrelation extends Correlation with Logging { Matrices.fromBreeze(cov) } - private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = { - math.abs(value) <= threshhold + private def closeToZero(value: Double, threshold: Double = 1e-12): Boolean = { + math.abs(value) <= threshold } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index fd186b5ee6f7..92a5af708d04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,9 +17,9 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym} +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} -import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.mllib.util.MLUtils @@ -29,102 +29,107 @@ import org.apache.spark.mllib.util.MLUtils * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]]) - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ +@Since("1.3.0") @DeveloperApi -class MultivariateGaussian ( - val mu: Vector, - val sigma: Matrix) extends Serializable { +class MultivariateGaussian @Since("1.3.0") ( + @Since("1.3.0") val mu: Vector, + @Since("1.3.0") val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - + private val breezeMu = mu.toBreeze.toDenseVector - + /** * private[mllib] constructor - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) } - + /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t - * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - - /** Returns density of this multivariate Gaussian at given point, x */ + + /** Returns density of this multivariate Gaussian at given point, x + */ + @Since("1.3.0") def pdf(x: Vector): Double = { - pdf(x.toBreeze.toDenseVector) + pdf(x.toBreeze) } - - /** Returns the log-density of this multivariate Gaussian at given point, x */ + + /** Returns the log-density of this multivariate Gaussian at given point, x + */ + @Since("1.3.0") def logpdf(x: Vector): Double = { - logpdf(x.toBreeze.toDenseVector) + logpdf(x.toBreeze) } - + /** Returns density of this multivariate Gaussian at given point, x */ - private[mllib] def pdf(x: DBV[Double]): Double = { + private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ - private[mllib] def logpdf(x: DBV[Double]): Double = { + private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 } - + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. - * - * We here compute distribution-fixed parts + * + * We here compute distribution-fixed parts * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and * D^(-1/2)^ * U, where sigma = U * D * U.t - * + * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, * we can use the eigendecomposition. We also do not compute the inverse directly; noting - * that - * + * that + * * sigma = U * D * U.t - * inv(Sigma) = U * inv(D) * U.t + * inv(Sigma) = U * inv(D) * U.t * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * + * * and thus - * + * * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ - * - * To guard against singular covariance matrices, this method computes both the + * + * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t - + // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length - + try { // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - - // calculate the root-pseudo-inverse of the diagonal matrix of singular values + + // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index ea82d39b72c0..23c8d7c7c807 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging { * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ - def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = { val method = methodFromString(methodName) val numRows = counts.numRows val numCols = counts.numCols @@ -205,8 +205,10 @@ private[stat] object ChiSqTest extends Logging { val colSums = new Array[Double](numCols) val rowSums = new Array[Double](numRows) val colMajorArr = counts.toArray + val colMajorArrLen = colMajorArr.length + var i = 0 - while (i < colMajorArr.size) { + while (i < colMajorArrLen) { val elem = colMajorArr(i) if (elem < 0.0) { throw new IllegalArgumentException("Contingency table cannot contain negative entries.") @@ -220,7 +222,7 @@ private[stat] object ChiSqTest extends Logging { // second pass to collect statistic var statistic = 0.0 var j = 0 - while (j < colMajorArr.size) { + while (j < colMajorArrLen) { val col = j / numRows val colSum = colSums(col) if (colSum == 0.0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala new file mode 100644 index 000000000000..2b3ed6df486c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import scala.annotation.varargs + +import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution} +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest} + +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD + +/** + * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * For more information on KS Test: + * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]] + * + * Implementation note: We seek to implement the KS test with a minimal number of distributed + * passes. We sort the RDD, and then perform the following operations on a per-partition basis: + * calculate an empirical cumulative distribution value for each observation, and a theoretical + * cumulative distribution value. We know the latter to be correct, while the former will be off by + * a constant (how large the constant is depends on how many values precede it in other partitions). + * However, given that this constant simply shifts the empirical CDF upwards, but doesn't + * change its shape, and furthermore, that constant is the same within a given partition, we can + * pick 2 values in each partition that can potentially resolve to the largest global distance. + * Namely, we pick the minimum distance and the maximum distance. Additionally, we keep track of how + * many elements are in each partition. Once these three values have been returned for every + * partition, we can collect and operate locally. Locally, we can now adjust each distance by the + * appropriate constant (the cumulative sum of number of elements in the prior partitions divided by + * thedata set size). Finally, we take the maximum absolute value, and this is the statistic. + */ +private[stat] object KolmogorovSmirnovTest extends Logging { + + // Null hypothesis for the type of KS test to be included in the result. + object NullHypothesis extends Enumeration { + type NullHypothesis = Value + val OneSampleTwoSided = Value("Sample follows theoretical distribution") + } + + /** + * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution + * @param data `RDD[Double]` data on which to run test + * @param cdf `Double => Double` function to calculate the theoretical CDF + * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test + * results (p-value, statistic, and null hypothesis) + */ + def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = { + val n = data.count().toDouble + val localData = data.sortBy(x => x).mapPartitions { part => + val partDiffs = oneSampleDifferences(part, n, cdf) // local distances + searchOneSampleCandidates(partDiffs) // candidates: local extrema + }.collect() + val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme + evalOneSampleP(ksStat, n.toLong) + } + + /** + * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution + * @param data `RDD[Double]` data on which to run test + * @param distObj `RealDistribution` a theoretical distribution + * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test + * results (p-value, statistic, and null hypothesis) + */ + def testOneSample(data: RDD[Double], distObj: RealDistribution): KolmogorovSmirnovTestResult = { + val cdf = (x: Double) => distObj.cumulativeProbability(x) + testOneSample(data, cdf) + } + + /** + * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a + * partition + * @param partData `Iterator[Double]` 1 partition of a sorted RDD + * @param n `Double` the total size of the RDD + * @param cdf `Double => Double` a function the calculates the theoretical CDF of a value + * @return `Iterator[(Double, Double)] `Unadjusted (ie. off by a constant) potential extrema + * in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF, + * the second element corresponds to empirical CDF - CDF. We can then search the resulting + * iterator for the minimum of the first and the maximum of the second element, and provide + * this as a partition's candidate extrema + */ + private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double) + : Iterator[(Double, Double)] = { + // zip data with index (within that partition) + // calculate local (unadjusted) empirical CDF and subtract CDF + partData.zipWithIndex.map { case (v, ix) => + // dp and dl are later adjusted by constant, when global info is available + val dp = (ix + 1) / n + val dl = ix / n + val cdfVal = cdf(v) + (dl - cdfVal, dp - cdfVal) + } + } + + /** + * Search the unadjusted differences in a partition and return the + * two extrema (furthest below and furthest above CDF), along with a count of elements in that + * partition + * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF + * and CDFin a partition, which come as a tuple of + * (empirical CDF - 1/N - CDF, empirical CDF - CDF) + * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements + */ + private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)]) + : Iterator[(Double, Double, Double)] = { + val initAcc = (Double.MaxValue, Double.MinValue, 0.0) + val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) => + (math.min(pMin, dl), math.max(pMax, dp), pCt + 1) + } + val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults) + results.iterator + } + + /** + * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after + * adjusting local extrema estimates from individual partitions with the amount of elements in + * preceding partitions + * @param localData `Array[(Double, Double, Double)]` A local array containing the collected + * results of `searchOneSampleCandidates` across all partitions + * @param n `Double`The size of the RDD + * @return The one-sample Kolmogorov Smirnov Statistic + */ + private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double) + : Double = { + val initAcc = (Double.MinValue, 0.0) + // adjust differences based on the number of elements preceding it, which should provide + // the correct distance between empirical CDF and CDF + val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) => + val adjConst = prevCt / n + val dist1 = math.abs(minCand + adjConst) + val dist2 = math.abs(maxCand + adjConst) + val maxVal = Array(prevMax, dist1, dist2).max + (maxVal, prevCt + ct) + } + results._1 + } + + /** + * A convenience function that allows running the KS test for 1 set of sample data against + * a named distribution + * @param data the sample data that we wish to evaluate + * @param distName the name of the theoretical distribution + * @param params Variable length parameter for distribution's parameters + * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the + * test results (p-value, statistic, and null hypothesis) + */ + @varargs + def testOneSample(data: RDD[Double], distName: String, params: Double*) + : KolmogorovSmirnovTestResult = { + val distObj = + distName match { + case "norm" => { + if (params.nonEmpty) { + // parameters are passed, then can only be 2 + require(params.length == 2, "Normal distribution requires mean and standard " + + "deviation as parameters") + new NormalDistribution(params(0), params(1)) + } else { + // if no parameters passed in initializes to standard normal + logInfo("No parameters specified for normal distribution," + + "initialized to standard normal (i.e. N(0, 1))") + new NormalDistribution(0, 1) + } + } + case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" + + s" convenience method. Current options are:['norm'].") + } + + testOneSample(data, distObj) + } + + private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = { + val pval = 1 - new CommonMathKolmogorovSmirnovTest().cdf(ksStat, n.toInt) + new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index 4784f9e94790..d01b3707be94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.test -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: @@ -25,28 +25,33 @@ import org.apache.spark.annotation.Experimental * @tparam DF Return type of `degreesOfFreedom`. */ @Experimental +@Since("1.1.0") trait TestResult[DF] { /** * The probability of obtaining a test statistic result at least as extreme as the one that was * actually observed, assuming that the null hypothesis is true. */ + @Since("1.1.0") def pValue: Double /** * Returns the degree(s) of freedom of the hypothesis test. * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. */ + @Since("1.1.0") def degreesOfFreedom: DF /** * Test statistic. */ + @Since("1.1.0") def statistic: Double /** * Null hypothesis of the test. */ + @Since("1.1.0") def nullHypothesis: String /** @@ -78,11 +83,12 @@ trait TestResult[DF] { * Object containing the test results for the chi-squared hypothesis test. */ @Experimental +@Since("1.1.0") class ChiSqTestResult private[stat] (override val pValue: Double, - override val degreesOfFreedom: Int, - override val statistic: Double, - val method: String, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.1.0") override val degreesOfFreedom: Int, + @Since("1.1.0") override val statistic: Double, + @Since("1.1.0") val method: String, + @Since("1.1.0") override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { "Chi squared test summary:\n" + @@ -90,3 +96,22 @@ class ChiSqTestResult private[stat] (override val pValue: Double, super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for the Kolmogorov-Smirnov test. + */ +@Experimental +@Since("1.5.0") +class KolmogorovSmirnovTestResult private[stat] ( + @Since("1.5.0") override val pValue: Double, + @Since("1.5.0") override val statistic: Double, + @Since("1.5.0") override val nullHypothesis: String) extends TestResult[Int] { + + @Since("1.5.0") + override val degreesOfFreedom = 0 + + override def toString: String = { + "Kolmogorov-Smirnov test summary:\n" + super.toString + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b3e8ed9af8c5..4a77d4adcd86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder - -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.SparkContext._ - /** * :: Experimental :: @@ -48,8 +44,10 @@ import org.apache.spark.SparkContext._ * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. */ +@Since("1.0.0") @Experimental -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { +class DecisionTree @Since("1.0.0") (private val strategy: Strategy) + extends Serializable with Logging { strategy.assertValid() @@ -58,6 +56,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) @@ -66,6 +65,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } +@Since("1.0.0") object DecisionTree extends Serializable with Logging { /** @@ -83,7 +83,8 @@ object DecisionTree extends Serializable with Logging { * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction - */ + */ + @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).run(input) } @@ -105,6 +106,7 @@ object DecisionTree extends Serializable with Logging { * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -132,6 +134,7 @@ object DecisionTree extends Serializable with Logging { * @param numClasses number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -165,6 +168,7 @@ object DecisionTree extends Serializable with Logging { * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -173,7 +177,7 @@ object DecisionTree extends Serializable with Logging { numClasses: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).run(input) @@ -197,6 +201,7 @@ object DecisionTree extends Serializable with Logging { * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ + @Since("1.1.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -212,6 +217,7 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] */ + @Since("1.1.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -241,6 +247,7 @@ object DecisionTree extends Serializable with Logging { * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ + @Since("1.1.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -254,6 +261,7 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] */ + @Since("1.1.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], @@ -331,14 +339,14 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (feature, bin). * @param treePoint Data point being aggregated. - * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param splits possible splits indexed (numFeatures)(numSplits) * @param unorderedFeatures Set of indices of unordered features. * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - bins: Array[Array[Bin]], + splits: Array[Array[Split]], unorderedFeatures: Set[Int], instanceWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { @@ -366,7 +374,7 @@ object DecisionTree extends Serializable with Logging { val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { - if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { + if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } else { @@ -510,8 +518,8 @@ object DecisionTree extends Serializable with Logging { if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, instanceWeight, featuresForNode) } } } @@ -772,7 +780,7 @@ object DecisionTree extends Serializable with Logging { */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) val predict = calculatePredict(parentNodeAgg) @@ -1028,35 +1036,15 @@ object DecisionTree extends Serializable with Logging { // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { - // TODO: The second half of the bins are unused. Actually, we could just use - // splits and not build bins for unordered features. That should be part of - // a later PR since it will require changing other code (using splits instead - // of bins in a few places). // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations + // 2^(maxFeatureValue - 1) - 1 combinations splits(featureIndex) = new Array[Split](numSplits) - bins(featureIndex) = new Array[Bin](numBins) var splitIndex = 0 while (splitIndex < numSplits) { val categories: List[Double] = extractMultiClassCategories(splitIndex + 1, featureArity) splits(featureIndex)(splitIndex) = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(splitIndex) = { - if (splitIndex == 0) { - new Bin( - new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), - Categorical, - Double.MinValue) - } else { - new Bin( - splits(featureIndex)(splitIndex - 1), - splits(featureIndex)(splitIndex), - Categorical, - Double.MinValue) - } - } splitIndex += 1 } } else { @@ -1064,8 +1052,11 @@ object DecisionTree extends Serializable with Logging { // Bins correspond to feature values, so we do not need to compute splits or bins // beforehand. Splits are constructed as needed during training. splits(featureIndex) = new Array[Split](0) - bins(featureIndex) = new Array[Bin](0) } + // For ordered features, bins correspond to feature values. + // For unordered categorical features, there is no need to construct the bins. + // since there is a one-to-one correspondence between the splits and the bins. + bins(featureIndex) = new Array[Bin](0) } featureIndex += 1 } @@ -1140,7 +1131,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) // iterate `valueCount` to find splits - val splits = new ArrayBuffer[Double] + val splitsBuilder = ArrayBuilder.make[Double] var index = 1 // currentCount: sum of counts of values that have been visited var currentCount = valueCounts(0)._2 @@ -1158,17 +1149,20 @@ object DecisionTree extends Serializable with Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splits.append(valueCounts(index - 1)._1) + splitsBuilder += valueCounts(index - 1)._1 targetCount += stride } index += 1 } - splits.toArray + splitsBuilder.result() } } - assert(splits.length > 0) + // TODO: Do not fail; just ignore the useless feature. + assert(splits.length > 0, + s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + + " Please remove this feature and then try again.") // set number of splits accordingly metadata.setNumSplits(featureIndex, splits.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 61f6b1313f82..95ed48cea671 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -18,8 +18,9 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -48,8 +49,9 @@ import org.apache.spark.storage.StorageLevel * * @param boostingStrategy Parameters for the gradient boosting algorithm. */ +@Since("1.2.0") @Experimental -class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) +class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { /** @@ -57,14 +59,16 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, boostingStrategy) + case Regression => + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -73,12 +77,54 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. */ + @Since("1.2.0") def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } -} + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. + * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @return a gradient boosted trees model that can be used for prediction + */ + @Since("1.4.0") + def runWithValidation( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidationInput = validationInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, + validate = true) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. + */ + @Since("1.4.0") + def runWithValidation( + input: JavaRDD[LabeledPoint], + validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + runWithValidation(input.rdd, validationInput.rdd) + } +} +@Since("1.2.0") object GradientBoostedTrees extends Logging { /** @@ -90,6 +136,7 @@ object GradientBoostedTrees extends Logging { * @param boostingStrategy Configuration options for the boosting algorithm. * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { @@ -99,6 +146,7 @@ object GradientBoostedTrees extends Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] */ + @Since("1.2.0") def train( input: JavaRDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { @@ -108,13 +156,16 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. * @param input training dataset + * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { - + validationInput: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy, + validate: Boolean): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -129,57 +180,95 @@ object GradientBoostedTrees extends Logging { val learningRate = boostingStrategy.learningRate // Prepare strategy for individual trees, which use regression with variance impurity. val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol treeStrategy.algo = Regression treeStrategy.impurity = Variance treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) + true + } else { + false } + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + timer.stop("init") logDebug("##########") logDebug("Building tree 0") logDebug("##########") - var data = input // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeModel = new DecisionTree(treeStrategy).run(input) + val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = 1.0 - val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) + baseLearnerWeights(0) = firstTreeWeight + + var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) + logDebug("error of gbt = " + predError.values.mean()) + // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - // psuedo-residual for second iteration - data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), - point.features)) + var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) + var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 + var bestM = 1 var m = 1 - while (m < numIterations) { + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") - // Create partial model + // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) - // Update data with pseudo-residuals - data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), - point.features)) + + predError = GradientBoostedTreesModel.updatePredictionError( + input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) + logDebug("error of gbt = " + predError.values.mean()) + + if (validate) { + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. + + validatePredError = GradientBoostedTreesModel.updatePredictionError( + validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) + val currentValidateError = validatePredError.values.mean() + if (bestValidateError - currentValidateError < validationTol) { + doneLearning = true + } else if (currentValidateError < bestValidateError) { + bestValidateError = currentValidateError + bestM = m + 1 + } + } m += 1 } @@ -188,7 +277,19 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() + if (persistedInput) input.unpersist() + + if (validate) { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + } } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 482dd4b272d1..63a902f3eb51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,11 +17,13 @@ package org.apache.spark.mllib.tree +import java.io.IOException + import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy @@ -34,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import org.apache.spark.util.random.SamplingUtils /** * :: Experimental :: @@ -202,7 +205,6 @@ private class RandomForest ( Some(NodeIdCache.init( data = baggedInput, numTrees = numTrees, - checkpointDir = strategy.checkpointDir, checkpointInterval = strategy.checkpointInterval, initVal = 1)) } else { @@ -244,7 +246,12 @@ private class RandomForest ( // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { - nodeIdCache.get.deleteAllCheckpoints() + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e: IOException => + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + } } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) @@ -253,6 +260,7 @@ private class RandomForest ( } +@Since("1.2.0") object RandomForest extends Serializable with Logging { /** @@ -270,6 +278,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, @@ -307,6 +316,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -326,6 +336,7 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] */ + @Since("1.2.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -356,6 +367,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, @@ -392,6 +404,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -410,6 +423,7 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] */ + @Since("1.2.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], @@ -427,6 +441,7 @@ object RandomForest extends Serializable with Logging { /** * List of supported feature subset sampling strategies. */ + @Since("1.2.0") val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "sqrt", "log2", "onethird") @@ -467,9 +482,8 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) - Some(rng.shuffle(Range(0, metadata.numFeatures).toList) - .take(metadata.numFeaturesPerNode).toArray) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index b6099259971b..853c7319ec44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -17,15 +17,18 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to select the algorithm for the decision tree */ +@Since("1.0.0") @Experimental object Algo extends Enumeration { + @Since("1.0.0") type Algo = Value + @Since("1.0.0") val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index ed8e6a796f8c..b5c72fba3ede 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} @@ -34,15 +34,21 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] + * @param validationTol Useful when runWithValidation is used. If the error rate on the + * validation input between two iterations is less than the validationTol + * then stop. Ignored when + * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ +@Since("1.2.0") @Experimental -case class BoostingStrategy( +case class BoostingStrategy @Since("1.4.0") ( // Required boosting parameters - @BeanProperty var treeStrategy: Strategy, - @BeanProperty var loss: Loss, + @Since("1.2.0") @BeanProperty var treeStrategy: Strategy, + @Since("1.2.0") @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1) extends Serializable { + @Since("1.2.0") @BeanProperty var numIterations: Int = 100, + @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, + @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. @@ -65,6 +71,7 @@ case class BoostingStrategy( } } +@Since("1.2.0") @Experimental object BoostingStrategy { @@ -73,6 +80,7 @@ object BoostingStrategy { * @param algo Learning goal. Supported: "Classification" or "Regression" * @return Configuration for boosting algorithm */ + @Since("1.2.0") def defaultParams(algo: String): BoostingStrategy = { defaultParams(Algo.fromString(algo)) } @@ -84,15 +92,16 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ + @Since("1.3.0") def defaultParams(algo: Algo): BoostingStrategy = { - val treeStragtegy = Strategy.defaultStategy(algo) - treeStragtegy.maxDepth = 3 + val treeStrategy = Strategy.defaultStrategy(algo) + treeStrategy.maxDepth = 3 algo match { case Algo.Classification => - treeStragtegy.numClasses = 2 - new BoostingStrategy(treeStragtegy, LogLoss) + treeStrategy.numClasses = 2 + new BoostingStrategy(treeStrategy, LogLoss) case Algo.Regression => - new BoostingStrategy(treeStragtegy, SquaredError) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index f4c877232750..4e0cd473def0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -17,14 +17,17 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to describe whether a feature is "continuous" or "categorical" */ +@Since("1.0.0") @Experimental object FeatureType extends Enumeration { + @Since("1.0.0") type FeatureType = Value + @Since("1.0.0") val Continuous, Categorical = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 7da976e55a72..8262db8a4f11 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -17,14 +17,17 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum for selecting the quantile calculation strategy */ +@Since("1.0.0") @Experimental object QuantileStrategy extends Enumeration { + @Since("1.0.0") type QuantileStrategy = Value + @Since("1.0.0") val Sort, MinMax, ApproxHist = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3308adb6752f..89cc13b7c06c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -62,37 +62,46 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will * maintain a separate RDD of node Id cache for each row. - * @param checkpointDir If the node Id cache is used, it will help to checkpoint - * the node Id cache periodically. This is the checkpoint directory - * to be used for the node Id cache. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. - * E.g. 10 means that the cache will get checkpointed every 10 updates. + * E.g. 10 means that the cache will get checkpointed every 10 updates. If + * the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. */ +@Since("1.0.0") @Experimental -class Strategy ( - @BeanProperty var algo: Algo, - @BeanProperty var impurity: Impurity, - @BeanProperty var maxDepth: Int, - @BeanProperty var numClasses: Int = 2, - @BeanProperty var maxBins: Int = 32, - @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - @BeanProperty var minInstancesPerNode: Int = 1, - @BeanProperty var minInfoGain: Double = 0.0, - @BeanProperty var maxMemoryInMB: Int = 256, - @BeanProperty var subsamplingRate: Double = 1, - @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointDir: Option[String] = None, - @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - - def isMulticlassClassification = +class Strategy @Since("1.3.0") ( + @Since("1.0.0") @BeanProperty var algo: Algo, + @Since("1.0.0") @BeanProperty var impurity: Impurity, + @Since("1.0.0") @BeanProperty var maxDepth: Int, + @Since("1.2.0") @BeanProperty var numClasses: Int = 2, + @Since("1.0.0") @BeanProperty var maxBins: Int = 32, + @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, + @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, + @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, + @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, + @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { + + /** + */ + @Since("1.2.0") + def isMulticlassClassification: Boolean = { algo == Classification && numClasses > 2 - def isMulticlassWithCategoricalFeatures - = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + } + + /** + */ + @Since("1.2.0") + def isMulticlassWithCategoricalFeatures: Boolean = { + isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + } /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ + @Since("1.1.0") def this( algo: Algo, impurity: Impurity, @@ -107,6 +116,7 @@ class Strategy ( /** * Sets Algorithm using a String. */ + @Since("1.2.0") def setAlgo(algo: String): Unit = algo match { case "Classification" => setAlgo(Classification) case "Regression" => setAlgo(Regression) @@ -115,6 +125,7 @@ class Strategy ( /** * Sets categoricalFeaturesInfo using a Java Map. */ + @Since("1.2.0") def setCategoricalFeaturesInfo( categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { this.categoricalFeaturesInfo = @@ -147,11 +158,6 @@ class Strategy ( s" Valid values are integers >= 0.") require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + s" Valid values are integers >= 2.") - categoricalFeaturesInfo.foreach { case (feature, arity) => - require(arity >= 2, - s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + - s" feature $feature has $arity categories. The number of categories should be >= 2.") - } require(minInstancesPerNode >= 1, s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, @@ -161,14 +167,18 @@ class Strategy ( s"$subsamplingRate") } - /** Returns a shallow copy of this instance. */ + /** + * Returns a shallow copy of this instance. + */ + @Since("1.2.0") def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) } } +@Since("1.2.0") @Experimental object Strategy { @@ -176,15 +186,17 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ + @Since("1.2.0") def defaultStrategy(algo: String): Strategy = { - defaultStategy(Algo.fromString(algo)) + defaultStrategy(Algo.fromString(algo)) } /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo Algo.Classification or Algo.Regression */ - def defaultStategy(algo: Algo): Strategy = algo match { + @Since("1.3.0") + def defaultStrategy(algo: Algo): Strategy = algo match { case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) @@ -192,4 +204,9 @@ object Strategy { new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } + + @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0") + @Since("1.2.0") + def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo) + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala index 089010c81ffb..572815df0bc4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted * dataset support, update. (We store subsampleWeights as Double for this future extension.) */ -private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable -private[tree] object BaggedPoint { +private[spark] object BaggedPoint { /** * Convert an input dataset into its BaggedPoint representation, @@ -60,7 +60,7 @@ private[tree] object BaggedPoint { subsamplingRate: Double, numSubsamples: Int, withReplacement: Boolean, - seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { if (withReplacement) { convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) } else { @@ -76,7 +76,7 @@ private[tree] object BaggedPoint { input: RDD[Datum], subsamplingRate: Double, numSubsamples: Int, - seed: Int): RDD[BaggedPoint[Datum]] = { + seed: Long): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => // Use random seed = seed + partitionIndex + 1 to make generation reproducible. val rng = new XORShiftRandom @@ -100,7 +100,7 @@ private[tree] object BaggedPoint { input: RDD[Datum], subsample: Double, numSubsamples: Int, - seed: Int): RDD[BaggedPoint[Datum]] = { + seed: Long): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => // Use random seed = seed + partitionIndex + 1 to make generation reproducible. val poisson = new PoissonDistribution(subsample) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index ce8825cc0322..7985ed4b4c0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.impurity._ * and helps with indexing. * This class is abstract to support learning with and without feature subsampling. */ -private[tree] class DTStatsAggregator( +private[spark] class DTStatsAggregator( val metadata: DecisionTreeMetadata, featureSubset: Option[Array[Int]]) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index f1a6ed230186..21ee49c45788 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD * I.e., the feature takes values in {0, ..., arity - 1}. * @param numBins Number of bins for each feature. */ -private[tree] class DecisionTreeMetadata( +private[spark] class DecisionTreeMetadata( val numFeatures: Int, val numExamples: Long, val numClasses: Int, @@ -94,7 +94,7 @@ private[tree] class DecisionTreeMetadata( } -private[tree] object DecisionTreeMetadata extends Logging { +private[spark] object DecisionTreeMetadata extends Logging { /** * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. @@ -107,7 +107,10 @@ private[tree] object DecisionTreeMetadata extends Logging { numTrees: Int, featureSubsetStrategy: String): DecisionTreeMetadata = { - val numFeatures = input.take(1)(0).features.size + val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { + throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + + s"but was given by empty one.") + } val numExamples = input.count() val numClasses = strategy.algo match { case Classification => strategy.numClasses @@ -125,9 +128,13 @@ private[tree] object DecisionTreeMetadata extends Logging { // based on the number of training examples. if (strategy.categoricalFeaturesInfo.nonEmpty) { val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + - s"in categorical features (= $maxCategoriesPerFeature)") + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " + + "features with a large number of values, or add more training examples.") } val unorderedFeatures = new mutable.HashSet[Int]() @@ -137,21 +144,28 @@ private[tree] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - numBins(featureIndex) = numCategories + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 83011b48b7d9..8f9eb24b57b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater( * The nodeIdsForInstances RDD needs to be updated at each iteration. * @param nodeIdsForInstances The initial values in the cache * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointDir The checkpoint directory where - * the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). */ @DeveloperApi -private[tree] class NodeIdCache( +private[spark] class NodeIdCache( var nodeIdsForInstances: RDD[Array[Int]], - val checkpointDir: Option[String], val checkpointInterval: Int) { // Keep a reference to a previous node Ids for instances. @@ -91,12 +88,6 @@ private[tree] class NodeIdCache( private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() private var rddUpdateCount = 0 - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { - nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) - } - /** * Update the node index values in the cache. * This updates the RDD and its lineage. @@ -179,12 +170,11 @@ private[tree] class NodeIdCache( } @DeveloperApi -private[tree] object NodeIdCache { +private[spark] object NodeIdCache { /** * Initialize the node Id cache with initial node Id values. * @param data The RDD of training rows. * @param numTrees The number of trees that we want to create cache for. - * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). * @param initVal The initial values in the cache. @@ -193,12 +183,10 @@ private[tree] object NodeIdCache { def init( data: RDD[BaggedPoint[TreePoint]], numTrees: Int, - checkpointDir: Option[String], checkpointInterval: Int, initVal: Int = 1): NodeIdCache = { new NodeIdCache( data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointDir, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala index d215d68c4279..aac84243d5ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental * Time tracker implementation which holds labeled timers. */ @Experimental -private[tree] class TimeTracker extends Serializable { +private[spark] class TimeTracker extends Serializable { private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 35e361ae309c..21919d69a38a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -37,11 +37,11 @@ import org.apache.spark.rdd.RDD * @param binnedFeatures Binned feature values. * Same length as LabeledPoint.features, but values are bin indices. */ -private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) extends Serializable { } -private[tree] object TreePoint { +private[spark] object TreePoint { /** * Convert an input dataset into its TreePoint representation, @@ -55,17 +55,15 @@ private[tree] object TreePoint { input: RDD[LabeledPoint], bins: Array[Array[Bin]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity and isUnordered for efficiency in the inner loop. + // Construct arrays for featureArity for efficiency in the inner loop. val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures) var featureIndex = 0 while (featureIndex < metadata.numFeatures) { featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - isUnordered(featureIndex) = metadata.isUnordered(featureIndex) featureIndex += 1 } input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered) + TreePoint.labeledPointToTreePoint(x, bins, featureArity) } } @@ -74,19 +72,17 @@ private[tree] object TreePoint { * @param bins Bins for features, of size (numFeatures, numBins). * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories * for categorical features. - * @param isUnordered Array index by feature, with value true for unordered categorical features. */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], - featureArity: Array[Int], - isUnordered: Array[Boolean]): TreePoint = { + featureArity: Array[Int]): TreePoint = { val numFeatures = labeledPoint.features.size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - isUnordered(featureIndex), bins) + bins) featureIndex += 1 } new TreePoint(labeledPoint.label, arr) @@ -96,14 +92,12 @@ private[tree] object TreePoint { * Find bin for one (labeledPoint, feature). * * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param isUnorderedFeature (only applies if feature is categorical) * @param bins Bins for features, of size (numFeatures, numBins). */ private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, featureArity: Int, - isUnorderedFeature: Boolean, bins: Array[Array[Bin]]): Int = { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index b7950e00786a..73df6b054a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during * binary classification. */ +@Since("1.0.0") @Experimental object Entropy extends Impurity { @@ -36,6 +37,7 @@ object Entropy extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -63,6 +65,7 @@ object Entropy extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") @@ -71,7 +74,8 @@ object Entropy extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + @Since("1.1.0") + def instance: this.type = this } @@ -118,7 +122,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c946db9c0d1c..f21845b21a80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] * during binary classification. */ +@Since("1.0.0") @Experimental object Gini extends Impurity { @@ -35,6 +36,7 @@ object Gini extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -59,6 +61,7 @@ object Gini extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") @@ -67,7 +70,8 @@ object Gini extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + @Since("1.1.0") + def instance: this.type = this } @@ -114,7 +118,7 @@ private[tree] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 60e2ab2bb829..4637dcceea7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -26,6 +26,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] * (b) calculating impurity values from sufficient statistics. */ +@Since("1.0.0") @Experimental trait Impurity extends Serializable { @@ -36,6 +37,7 @@ trait Impurity extends Serializable { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -47,6 +49,7 @@ trait Impurity extends Serializable { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } @@ -57,7 +60,7 @@ trait Impurity extends Serializable { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param statsSize Length of the vector of sufficient statistics for one bin. */ -private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { +private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { /** * Merge the stats from one bin into another. @@ -95,7 +98,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { /** * Make a deep copy of this [[ImpurityCalculator]]. @@ -111,11 +114,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { * Add the stats from another calculator into this one, modifying and returning this calculator. */ def add(other: ImpurityCalculator): ImpurityCalculator = { - require(stats.size == other.stats.size, + require(stats.length == other.stats.length, s"Two ImpurityCalculator instances cannot be added with different counts sizes." + - s" Sizes are ${stats.size} and ${other.stats.size}.") + s" Sizes are ${stats.length} and ${other.stats.length}.") var i = 0 - while (i < other.stats.size) { + val len = other.stats.length + while (i < len) { stats(i) += other.stats(i) i += 1 } @@ -127,11 +131,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { * calculator. */ def subtract(other: ImpurityCalculator): ImpurityCalculator = { - require(stats.size == other.stats.size, + require(stats.length == other.stats.length, s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + - s" Sizes are ${stats.size} and ${other.stats.size}.") + s" Sizes are ${stats.length} and ${other.stats.length}.") var i = 0 - while (i < other.stats.size) { + val len = other.stats.length + while (i < len) { stats(i) -= other.stats(i) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index df9eafa5da16..a74197278d6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating variance during regression */ +@Since("1.0.0") @Experimental object Variance extends Impurity { @@ -33,6 +34,7 @@ object Variance extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") @@ -45,6 +47,7 @@ object Variance extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { if (count == 0) { @@ -58,7 +61,8 @@ object Variance extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + @Since("1.0.0") + def instance: this.type = this } @@ -98,7 +102,7 @@ private[tree] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { require(stats.size == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d1bde15e6b15..bab7b8c6cadf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * |y - F(x)| * where y is the label and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object AbsoluteError extends Loss { @@ -37,28 +38,17 @@ object AbsoluteError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * absolute error calculation. * The gradient with respect to F(x) is: sign(F(x) - y) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + @Since("1.2.0") + override def gradient(prediction: Double, label: Double): Double = { + if (label - prediction < 0) 1.0 else -1.0 } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean absolute error of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = model.predict(y.features) - y.label - math.abs(err) - }.mean() + override private[mllib] def computeError(prediction: Double, label: Double): Double = { + val err = label - prediction + math.abs(err) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 55213e695638..b2b4594712f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD * 2 log(1 + exp(-2 y F(x))) * where y is a label in {-1, 1} and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object LogLoss extends Loss { @@ -39,31 +40,18 @@ object LogLoss extends Loss { * Method to calculate the loss gradients for the gradient boosting calculation for binary * classification * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - val prediction = model.predict(point.features) - - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) + @Since("1.2.0") + override def gradient(prediction: Double, label: Double): Double = { + - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean log loss of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { case point => - val prediction = model.predict(point.features) - val margin = 2.0 * point.label * prediction - // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. - 2.0 * MLUtils.log1pExp(-margin) - }.mean() + override private[mllib] def computeError(prediction: Double, label: Double): Double = { + val margin = 2.0 * label * prediction + // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. + 2.0 * MLUtils.log1pExp(-margin) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index e1169d9f66ea..687cde325ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -17,27 +17,28 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. */ +@Since("1.2.0") @DeveloperApi trait Loss extends Serializable { /** * Method to calculate the gradients for the gradient boosting calculation. - * @param model Model of the weak learner. - * @param point Instance of the training dataset. + * @param prediction Predicted feature + * @param label true label. * @return Loss gradient. */ - def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double + @Since("1.2.0") + def gradient(prediction: Double, label: Double): Double /** * Method to calculate error of the base learner for the gradient boosting calculation. @@ -47,6 +48,18 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data */ - def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double + @Since("1.2.0") + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map(point => computeError(model.predict(point.features), point.label)).mean() + } + /** + * Method to calculate loss when the predictions are already known. + * Note: This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. + * @param prediction Predicted label. + * @param label True label. + * @return Measure of model error on datapoint. + */ + private[mllib] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala index 42c9ead9884b..2b112fbe1220 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.tree.loss +import org.apache.spark.annotation.Since + +@Since("1.2.0") object Losses { + @Since("1.2.0") def fromString(name: String): Loss = name match { case "leastSquaresError" => SquaredError case "leastAbsoluteError" => AbsoluteError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 50ecaa2f86f3..3f7d3d38be16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * (y - F(x))**2 * where y is the label and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object SquaredError extends Loss { @@ -37,28 +38,17 @@ object SquaredError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * squares error calculation. * The gradient with respect to F(x) is: - 2 (y - F(x)) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - 2.0 * (model.predict(point.features) - point.label) + @Since("1.2.0") + override def gradient(prediction: Double, label: Double): Double = { + - 2.0 * (label - prediction) } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean squared error of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = model.predict(y.features) - y.label - err * err - }.mean() + override private[mllib] def computeError(prediction: Double, label: Double): Double = { + val err = label - prediction + err * err } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a5760963068c..e1bf23f4c34b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,22 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.Experimental +import scala.collection.mutable + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -30,8 +41,11 @@ import org.apache.spark.rdd.RDD * @param topNode root node * @param algo algorithm type -- classification or regression */ +@Since("1.0.0") @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { +class DecisionTreeModel @Since("1.0.0") ( + @Since("1.0.0") val topNode: Node, + @Since("1.0.0") val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. @@ -39,6 +53,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features array representing a single data point * @return Double prediction from the trained model */ + @Since("1.0.0") def predict(features: Vector): Double = { topNode.predict(features) } @@ -49,17 +64,18 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features RDD representing data points to be predicted * @return RDD of predictions for each of the given data points */ + @Since("1.0.0") def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } - /** * Predict values for the given data set using the model trained. * * @param features JavaRDD representing data points to be predicted * @return JavaRDD of predictions for each of the given data points */ + @Since("1.2.0") def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { predict(features.rdd) } @@ -67,6 +83,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Get number of nodes in tree, including leaf nodes. */ + @Since("1.1.0") def numNodes: Int = { 1 + topNode.numDescendants } @@ -75,6 +92,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Get depth of tree. * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. */ + @Since("1.1.0") def depth: Int = { topNode.subtreeDepth } @@ -94,9 +112,227 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Print the full model to a string. */ + @Since("1.2.0") def toDebugString: String = { val header = toString + "\n" header + topNode.subtreeToString(2) } + /** + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) + } + + override protected def formatVersion: String = DecisionTreeModel.formatVersion +} + +@Since("1.3.0") +object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { + + private[spark] def formatVersion: String = "1.0" + + private[tree] object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + // Hard-code class name string in case it changes in the future + def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel" + + case class PredictData(predict: Double, prob: Double) { + def toPredict: Predict = new Predict(predict, prob) + } + + object PredictData { + def apply(p: Predict): PredictData = PredictData(p.predict, p.prob) + + def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1)) + } + + case class SplitData( + feature: Int, + threshold: Double, + featureType: Int, + categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + def toSplit: Split = { + new Split(feature, threshold, FeatureType(featureType), categories.toList) + } + } + + object SplitData { + def apply(s: Split): SplitData = { + SplitData(s.feature, s.threshold, s.featureType.id, s.categories) + } + + def apply(r: Row): SplitData = { + SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3)) + } + } + + /** Model data for model import/export */ + case class NodeData( + treeId: Int, + nodeId: Int, + predict: PredictData, + impurity: Double, + isLeaf: Boolean, + split: Option[SplitData], + leftNodeId: Option[Int], + rightNodeId: Option[Int], + infoGain: Option[Double]) + + object NodeData { + def apply(treeId: Int, n: Node): NodeData = { + NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf, + n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id), + n.stats.map(_.gain)) + } + + def apply(r: Row): NodeData = { + val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5))) + val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6)) + val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7)) + val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8)) + NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3), + r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain) + } + } + + def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // SPARK-6120: We do a hacky check here so users understand why save() is failing + // when they run the ML guide example. + // TODO: Fix this issue for real. + val memThreshold = 768 + if (sc.isLocal) { + val driverMemory = sc.getConf.getOption("spark.driver.memory") + .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) + .map(Utils.memoryStringToMb) + .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB) + if (driverMemory <= memThreshold) { + logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + + s" driver memory (${driverMemory}m)." + + s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") + } + } else { + if (sc.executorMemory <= memThreshold) { + logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + + s" executor memory (${sc.executorMemory}m)." + + s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") + } + } + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val nodes = model.topNode.subtreeIterator.toSeq + val dataRDD: DataFrame = sc.parallelize(nodes) + .map(NodeData.apply(0, _)) + .toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.read.parquet(datapath) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[NodeData](dataRDD.schema) + val nodes = dataRDD.map(NodeData.apply) + // Build node data into a tree. + val trees = constructTrees(nodes) + assert(trees.size == 1, + "Decision tree should contain exactly one tree but got ${trees.size} trees.") + val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) + assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." + + s" Expected $numNodes nodes but found ${model.numNodes}") + model + } + + def constructTrees(nodes: RDD[NodeData]): Array[Node] = { + val trees = nodes + .groupBy(_.treeId) + .mapValues(_.toArray) + .collect() + .map { case (treeId, data) => + (treeId, constructTree(data)) + }.sortBy(_._1) + val numTrees = trees.size + val treeIndices = trees.map(_._1).toSeq + assert(treeIndices == (0 until numTrees), + s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") + trees.map(_._2) + } + + /** + * Given a list of nodes from a tree, construct the tree. + * @param data array of all node data in a tree. + */ + def constructTree(data: Array[NodeData]): Node = { + val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap + assert(dataMap.contains(1), + s"DecisionTree missing root node (id = 1).") + constructNode(1, dataMap, mutable.Map.empty) + } + + /** + * Builds a node from the node data map and adds new nodes to the input nodes map. + */ + private def constructNode( + id: Int, + dataMap: Map[Int, NodeData], + nodes: mutable.Map[Int, Node]): Node = { + if (nodes.contains(id)) { + return nodes(id) + } + val data = dataMap(id) + val node = + if (data.isLeaf) { + Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf) + } else { + val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes) + val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes) + val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity, + rightNode.impurity, leftNode.predict, rightNode.predict) + new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf, + data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats)) + } + nodes += node.id -> node + node + } + } + + /** + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") + override def load(sc: SparkContext, path: String): DecisionTreeModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val algo = (metadata \ "algo").extract[String] + val numNodes = (metadata \ "numNodes").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path, algo, numNodes) + case _ => throw new Exception( + s"DecisionTreeModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 9a50ecb550c3..091a0462c204 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** * :: DeveloperApi :: @@ -29,6 +30,7 @@ import org.apache.spark.annotation.DeveloperApi * @param leftPredict left node predict * @param rightPredict right node predict */ +@Since("1.0.0") @DeveloperApi class InformationGainStats( val gain: Double, @@ -38,25 +40,35 @@ class InformationGainStats( val leftPredict: Predict, val rightPredict: Predict) extends Serializable { - override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" - .format(gain, impurity, leftImpurity, rightImpurity) + override def toString: String = { + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" } - override def equals(o: Any) = - o match { - case other: InformationGainStats => { - gain == other.gain && - impurity == other.impurity && - leftImpurity == other.leftImpurity && - rightImpurity == other.rightImpurity - } - case _ => false - } -} + override def equals(o: Any): Boolean = o match { + case other: InformationGainStats => + gain == other.gain && + impurity == other.impurity && + leftImpurity == other.leftImpurity && + rightImpurity == other.rightImpurity && + leftPredict == other.leftPredict && + rightPredict == other.rightPredict + + case _ => false + } + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + gain: java.lang.Double, + impurity: java.lang.Double, + leftImpurity: java.lang.Double, + rightImpurity: java.lang.Double, + leftPredict, + rightPredict) + } +} -private[tree] object InformationGainStats { +private[spark] object InformationGainStats { /** * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to * denote that current split doesn't satisfies minimum info gain or @@ -65,3 +77,62 @@ private[tree] object InformationGainStats { val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } + +/** + * :: DeveloperApi :: + * Impurity statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param impurityCalculator impurity statistics for current node + * @param leftImpurityCalculator impurity statistics for left child node + * @param rightImpurityCalculator impurity statistics for right child node + * @param valid whether the current split satisfies minimum info gain or + * minimum number of instances per node + */ +@DeveloperApi +private[spark] class ImpurityStats( + val gain: Double, + val impurity: Double, + val impurityCalculator: ImpurityCalculator, + val leftImpurityCalculator: ImpurityCalculator, + val rightImpurityCalculator: ImpurityCalculator, + val valid: Boolean = true) extends Serializable { + + override def toString: String = { + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" + } + + def leftImpurity: Double = if (leftImpurityCalculator != null) { + leftImpurityCalculator.calculate() + } else { + -1.0 + } + + def rightImpurity: Double = if (rightImpurityCalculator != null) { + rightImpurityCalculator.calculate() + } else { + -1.0 + } +} + +private[spark] object ImpurityStats { + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ + def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + impurityCalculator, null, null, false) + } + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object + * that only 'impurity' and 'impurityCalculator' are defined. + */ + def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 2179da8dbe03..ea6e5aa5d94e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vector @@ -39,24 +39,28 @@ import org.apache.spark.mllib.linalg.Vector * @param rightNode right child * @param stats information gain stats */ +@Since("1.0.0") @DeveloperApi -class Node ( - val id: Int, - var predict: Predict, - var impurity: Double, - var isLeaf: Boolean, - var split: Option[Split], - var leftNode: Option[Node], - var rightNode: Option[Node], - var stats: Option[InformationGainStats]) extends Serializable with Logging { +class Node @Since("1.2.0") ( + @Since("1.0.0") val id: Int, + @Since("1.0.0") var predict: Predict, + @Since("1.2.0") var impurity: Double, + @Since("1.0.0") var isLeaf: Boolean, + @Since("1.0.0") var split: Option[Split], + @Since("1.0.0") var leftNode: Option[Node], + @Since("1.0.0") var rightNode: Option[Node], + @Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging { - override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + "split = " + split + ", stats = " + stats + override def toString: String = { + s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + + s"split = $split, stats = $stats" + } /** * build the left node and right nodes if not leaf * @param nodes array of nodes */ + @Since("1.0.0") @deprecated("build should no longer be used since trees are constructed on-the-fly in training", "1.2.0") def build(nodes: Array[Node]): Unit = { @@ -78,10 +82,11 @@ class Node ( * @param features feature value * @return predicted value */ + @Since("1.1.0") def predict(features: Vector) : Double = { if (isLeaf) { predict.predict - } else{ + } else { if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { leftNode.get.predict(features) @@ -149,9 +154,9 @@ class Node ( s"(feature ${split.feature} > ${split.threshold})" } case Categorical => if (left) { - s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})" } else { - s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})" } } } @@ -159,16 +164,21 @@ class Node ( if (isLeaf) { prefix + s"Predict: ${predict.predict}\n" } else { - prefix + s"If ${splitToString(split.get, left=true)}\n" + + prefix + s"If ${splitToString(split.get, left = true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + - prefix + s"Else ${splitToString(split.get, left=false)}\n" + + prefix + s"Else ${splitToString(split.get, left = false)}\n" + rightNode.get.subtreeToString(indentFactor + 1) } } + /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */ + private[tree] def subtreeIterator: Iterator[Node] = { + Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++ + rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty) + } } -private[tree] object Node { +private[spark] object Node { /** * Return a node with the given node id (but nothing else set). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 004838ee5ba0..06ceff19d863 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,19 +17,29 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ +@Since("1.2.0") @DeveloperApi -class Predict( - val predict: Double, - val prob: Double = 0.0) extends Serializable { +class Predict @Since("1.2.0") ( + @Since("1.2.0") val predict: Double, + @Since("1.2.0") val prob: Double = 0.0) extends Serializable { - override def toString = { - "predict = %f, prob = %f".format(predict, prob) + override def toString: String = s"$predict (prob = $prob)" + + override def equals(other: Any): Boolean = { + other match { + case p: Predict => predict == p.predict && prob == p.prob + case _ => false + } + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b7a85f58544a..b85a66c05a81 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @@ -31,16 +31,18 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * @param featureType type of feature -- categorical or continuous * @param categories Split left if categorical feature value is in this set, else right. */ +@Since("1.0.0") @DeveloperApi case class Split( - feature: Int, - threshold: Double, - featureType: FeatureType, - categories: List[Double]) { + @Since("1.0.0") feature: Int, + @Since("1.0.0") threshold: Double, + @Since("1.0.0") featureType: FeatureType, + @Since("1.0.0") categories: List[Double]) { - override def toString = - "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + - ", categories = " + categories + override def toString: String = { + s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + + s"categories = $categories" + } } /** @@ -67,4 +69,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) */ private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) - diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 22997110de8d..df5b8feab5d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -20,13 +20,24 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.mllib.tree.loss.Loss +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + /** * :: Experimental :: @@ -35,12 +46,65 @@ import org.apache.spark.rdd.RDD * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles */ +@Since("1.2.0") @Experimental -class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) - extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), - combiningStrategy = if (algo == Classification) Vote else Average) { +class RandomForestModel @Since("1.2.0") ( + @Since("1.2.0") override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel]) + extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), + combiningStrategy = if (algo == Classification) Vote else Average) + with Saveable { require(trees.forall(_.algo == algo)) + + /** + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + RandomForestModel.SaveLoadV1_0.thisClassName) + } + + override protected def formatVersion: String = RandomForestModel.formatVersion +} + +@Since("1.3.0") +object RandomForestModel extends Loader[RandomForestModel] { + + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + + /** + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") + override def load(sc: SparkContext, path: String): RandomForestModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.treeWeights.forall(_ == 1.0)) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new RandomForestModel(Algo.fromString(metadata.algo), trees) + case _ => throw new Exception(s"RandomForestModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel" + } + } /** @@ -51,14 +115,162 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis * @param trees tree ensembles * @param treeWeights tree ensemble weights */ +@Since("1.2.0") @Experimental -class GradientBoostedTreesModel( - override val algo: Algo, - override val trees: Array[DecisionTreeModel], - override val treeWeights: Array[Double]) - extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { +class GradientBoostedTreesModel @Since("1.2.0") ( + @Since("1.2.0") override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel], + @Since("1.2.0") override val treeWeights: Array[Double]) + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) + with Saveable { + + require(trees.length == treeWeights.length) + + /** + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) + } + + /** + * Method to compute error or loss for every iteration of gradient boosting. + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param loss evaluation metric. + * @return an array with index i having the losses or errors for the ensemble + * containing the first i+1 trees + */ + @Since("1.4.0") + def evaluateEachIteration( + data: RDD[LabeledPoint], + loss: Loss): Array[Double] = { + + val sc = data.sparkContext + val remappedData = algo match { + case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + case _ => data + } + + val numIterations = trees.length + val evaluationArray = Array.fill(numIterations)(0.0) + val localTreeWeights = treeWeights + + var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError( + remappedData, localTreeWeights(0), trees(0), loss) + + evaluationArray(0) = predictionAndError.values.mean() + + val broadcastTrees = sc.broadcast(trees) + (1 until numIterations).foreach { nTree => + predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => + val currentTree = broadcastTrees.value(nTree) + val currentTreeWeight = localTreeWeights(nTree) + iter.map { case (point, (pred, error)) => + val newPred = pred + currentTree.predict(point.features) * currentTreeWeight + val newError = loss.computeError(newPred, point.label) + (newPred, newError) + } + } + evaluationArray(nTree) = predictionAndError.values.mean() + } + + broadcastTrees.unpersist() + evaluationArray + } + + override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion +} + +/** + */ +@Since("1.3.0") +object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { + + /** + * Compute the initial predictions and errors for a dataset for the first + * iteration of gradient boosting. + * @param data: training data. + * @param initTreeWeight: learning rate assigned to the first tree. + * @param initTree: first DecisionTreeModel. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to every sample. + */ + @Since("1.4.0") + def computeInitialPredictionAndError( + data: RDD[LabeledPoint], + initTreeWeight: Double, + initTree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + data.map { lp => + val pred = initTreeWeight * initTree.predict(lp.features) + val error = loss.computeError(pred, lp.label) + (pred, error) + } + } + + /** + * Update a zipped predictionError RDD + * (as obtained with computeInitialPredictionAndError) + * @param data: training data. + * @param predictionAndError: predictionError RDD + * @param treeWeight: Learning rate. + * @param tree: Tree using which the prediction and error should be updated. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to each sample. + */ + @Since("1.4.0") + def updatePredictionError( + data: RDD[LabeledPoint], + predictionAndError: RDD[(Double, Double)], + treeWeight: Double, + tree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + + val newPredError = data.zip(predictionAndError).mapPartitions { iter => + iter.map { case (lp, (pred, error)) => + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) + } + } + newPredError + } + + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + + /** + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.combiningStrategy == Sum.toString) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights) + case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + } - require(trees.size == treeWeights.size) } /** @@ -167,12 +379,105 @@ private[tree] sealed class TreeEnsembleModel( } /** - * Get number of trees in forest. + * Get number of trees in ensemble. */ - def numTrees: Int = trees.size + def numTrees: Int = trees.length /** - * Get total number of nodes, summed over all trees in the forest. + * Get total number of nodes, summed over all trees in the ensemble. */ def totalNumNodes: Int = trees.map(_.numNodes).sum } + +private[tree] object TreeEnsembleModel extends Logging { + + object SaveLoadV1_0 { + + import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} + + def thisFormatVersion: String = "1.0" + + case class Metadata( + algo: String, + treeAlgo: String, + combiningStrategy: String, + treeWeights: Array[Double]) + + /** + * Model data for model import/export. + * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields + * of nested fields; once that is possible, we can use something like: + * case class EnsembleNodeData(treeId: Int, node: NodeData), + * where NodeData is from DecisionTreeModel. + */ + case class EnsembleNodeData(treeId: Int, node: NodeData) + + def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // SPARK-6120: We do a hacky check here so users understand why save() is failing + // when they run the ML guide example. + // TODO: Fix this issue for real. + val memThreshold = 768 + if (sc.isLocal) { + val driverMemory = sc.getConf.getOption("spark.driver.memory") + .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) + .map(Utils.memoryStringToMb) + .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB) + if (driverMemory <= memThreshold) { + logWarning(s"$className.save() was called, but it may fail because of too little" + + s" driver memory (${driverMemory}m)." + + s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") + } + } else { + if (sc.executorMemory <= memThreshold) { + logWarning(s"$className.save() was called, but it may fail because of too little" + + s" executor memory (${sc.executorMemory}m)." + + s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") + } + } + + // Create JSON metadata. + implicit val format = DefaultFormats + val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, + model.combiningStrategy.toString, model.treeWeights) + val metadata = compact(render( + ("class" -> className) ~ ("version" -> thisFormatVersion) ~ + ("metadata" -> Extraction.decompose(ensembleMetadata)))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => + tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) + }.toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + /** + * Read metadata from the loaded JSON metadata. + */ + def readMetadata(metadata: JValue): Metadata = { + implicit val formats = DefaultFormats + (metadata \ "metadata").extract[Metadata] + } + + /** + * Load trees for an ensemble, and return them in order. + * @param path path to load the model from + * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's + * algorithm). + */ + def loadTrees( + sc: SparkContext, + path: String, + treeAlgo: String): Array[DecisionTreeModel] = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) + val trees = constructTrees(nodes) + trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index be335a1aca58..dffe6e78939e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: * A collection of methods used to validate data before applying ML algorithms. */ @DeveloperApi +@Since("0.8.0") object DataValidators extends Logging { /** @@ -34,6 +35,7 @@ object DataValidators extends Logging { * * @return True if labels are all zero or one, false otherwise. */ + @Since("1.0.0") val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { @@ -48,6 +50,7 @@ object DataValidators extends Logging { * * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise. */ + @Since("1.3.0") def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index 6eaebaf7dba9..00fd1606a369 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.rdd.RDD /** @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * cluster with scale 1 around each center. */ @DeveloperApi +@Since("0.8.0") object KMeansDataGenerator { /** @@ -42,6 +43,7 @@ object KMeansDataGenerator { * @param r Scaling factor for the distribution of the initial centers * @param numPartitions Number of partitions of the generated RDD; default 2 */ + @Since("0.8.0") def generateKMeansRDD( sc: SparkContext, numPoints: Int, @@ -62,10 +64,13 @@ object KMeansDataGenerator { } } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 6) { + // scalastyle:off println println("Usage: KMeansGenerator " + " []") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 97f54aa62d31..d0ba454f379a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.util -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random -import org.jblas.DoubleMatrix +import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -35,6 +35,7 @@ import org.apache.spark.mllib.regression.LabeledPoint * response variable `Y`. */ @DeveloperApi +@Since("0.8.0") object LinearDataGenerator { /** @@ -46,16 +47,21 @@ object LinearDataGenerator { * @param seed Random seed * @return Java List of input. */ + @Since("0.8.0") def generateLinearInputAsList( intercept: Double, weights: Array[Double], nPoints: Int, seed: Int, eps: Double): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) + generateLinearInput(intercept, weights, nPoints, seed, eps).asJava } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -64,19 +70,56 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. */ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.length)(0.0), + Array.fill[Double](weights.length)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + @Since("0.8.0") + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = x.map { xi => - new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) + intercept + eps * rnd.nextGaussian() + blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } @@ -93,6 +136,7 @@ object LinearDataGenerator { * * @return RDD of LabeledPoint containing sample data. */ + @Since("0.8.0") def generateLinearRDD( sc: SparkContext, nexamples: Int, @@ -100,9 +144,9 @@ object LinearDataGenerator { eps: Double, nparts: Int = 2, intercept: Double = 0.0) : RDD[LabeledPoint] = { - org.jblas.util.Random.seed(42) + val random = new Random(42) // Random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) + val w = Array.fill(nfeatures)(random.nextDouble() - 0.5) val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p => val seed = 42 + p @@ -112,10 +156,13 @@ object LinearDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: LinearDataGenerator " + " [num_examples] [num_features] [num_partitions]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 9d802678c4a7..33477ee20ebb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint @@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.Vectors * with probability `probOne` and scales features for positive examples by `eps`. */ @DeveloperApi +@Since("0.8.0") object LogisticRegressionDataGenerator { /** @@ -43,6 +44,7 @@ object LogisticRegressionDataGenerator { * @param nparts Number of partitions of the generated RDD. Default value is 2. * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. */ + @Since("0.8.0") def generateLogisticRDD( sc: SparkContext, nexamples: Int, @@ -62,10 +64,13 @@ object LogisticRegressionDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length != 5) { + // scalastyle:off println println("Usage: LogisticRegressionGenerator " + " ") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index b76fbe89c368..906bd30563bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.util +import java.{util => ju} + import scala.language.postfixOps import scala.util.Random -import org.jblas.DoubleMatrix - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD /** @@ -51,11 +52,15 @@ import org.apache.spark.rdd.RDD * testSampFact (Double) Percentage of training data to use as test data. */ @DeveloperApi +@Since("0.8.0") object MFDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: MFDataGenerator " + " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") + // scalastyle:on println System.exit(1) } @@ -72,24 +77,24 @@ object MFDataGenerator { val sc = new SparkContext(sparkMaster, "MFDataGenerator") - val A = DoubleMatrix.randn(m, rank) - val B = DoubleMatrix.randn(rank, n) - val z = 1 / scala.math.sqrt(scala.math.sqrt(rank)) - A.mmuli(z) - B.mmuli(z) - val fullData = A.mmul(B) + val random = new ju.Random(42L) + + val A = DenseMatrix.randn(m, rank, random) + val B = DenseMatrix.randn(rank, n, random) + val z = 1 / math.sqrt(rank) + val fullData = DenseMatrix.zeros(m, n) + BLAS.gemm(z, A, B, 1.0, fullData) val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt + val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n - val shuffled = rand.shuffle(1 to mn toList) + val shuffled = rand.shuffle((0 until mn).toList) val omega = shuffled.slice(0, sampSize) val ordered = omega.sortWith(_ < _).toArray val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + .map(x => (x % m, x / m, fullData.values(x))) // optionally add gaussian noise if (noise) { @@ -100,12 +105,12 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testSampSize = math.min( + math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + .map(x => (x % m, x / m, fullData.values(x))) testData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 5d6ddd47f67d..81c2f0ce6e12 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD @@ -36,6 +36,7 @@ import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. */ +@Since("0.8.0") object MLUtils { private[mllib] lazy val EPSILON = { @@ -65,6 +66,7 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, @@ -82,6 +84,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } @@ -102,6 +116,7 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -115,12 +130,14 @@ object MLUtils { * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -129,6 +146,7 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -140,6 +158,7 @@ object MLUtils { * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. */ + @Since("1.0.0") def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -150,6 +169,7 @@ object MLUtils { * * @see [[org.apache.spark.mllib.util.MLUtils#loadLibSVMFile]] */ + @Since("1.0.0") def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { // TODO: allow to specify label precision and feature precision. val dataStr = data.map { case LabeledPoint(label, features) => @@ -170,12 +190,14 @@ object MLUtils { * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -186,6 +208,7 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) @@ -193,6 +216,7 @@ object MLUtils { * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -209,6 +233,7 @@ object MLUtils { * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ + @Since("1.0.0") @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { sc.textFile(dir).map { line => @@ -230,6 +255,7 @@ object MLUtils { * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ + @Since("1.0.0") @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) @@ -242,6 +268,7 @@ object MLUtils { * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. */ + @Since("1.0.0") @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat @@ -257,15 +284,32 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. */ + @Since("1.0.0") def appendBias(vector: Vector): Vector = { - val vector1 = vector.toBreeze match { - case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(1.0))) - case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(1.0), 1)) - case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + vector match { + case dv: DenseVector => + val inputValues = dv.values + val inputLength = inputValues.length + val outputValues = Array.ofDim[Double](inputLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputLength) + outputValues(inputLength) = 1.0 + Vectors.dense(outputValues) + case sv: SparseVector => + val inputValues = sv.values + val inputIndices = sv.indices + val inputValuesLength = inputValues.length + val dim = sv.size + val outputValues = Array.ofDim[Double](inputValuesLength + 1) + val outputIndices = Array.ofDim[Int](inputValuesLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength) + System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength) + outputValues(inputValuesLength) = 1.0 + outputIndices(inputValuesLength) = dim + Vectors.sparse(dim + 1, outputIndices, outputValues) + case _ => throw new IllegalArgumentException(s"Do not support vector type ${vector.getClass}") } - Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: @@ -331,7 +375,7 @@ object MLUtils { * @param x a floating-point value as input. * @return the result of `math.log(1 + math.exp(x))`. */ - private[mllib] def log1pExp(x: Double): Double = { + private[spark] def log1pExp(x: Double): Double = { if (x > 0) { x + math.log1p(math.exp(-x)) } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index f7cba6c6cb62..a841c5caf014 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import java.util.StringTokenizer -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.{ArrayBuilder, ListBuffer} import org.apache.spark.SparkException @@ -51,7 +51,7 @@ private[mllib] object NumericParser { } private def parseArray(tokenizer: StringTokenizer): Array[Double] = { - val values = ArrayBuffer.empty[Double] + val values = ArrayBuilder.make[Double] var parsing = true var allowComma = false var token: String = null @@ -67,14 +67,14 @@ private[mllib] object NumericParser { } } else { // expecting a number - values.append(parseDouble(token)) + values += parseDouble(token) allowComma = true } } if (parsing) { throw new SparkException(s"An array must end with ']'.") } - values.toArray + values.result() } private def parseTuple(tokenizer: StringTokenizer): Seq[_] = { @@ -98,6 +98,8 @@ private[mllib] object NumericParser { } } else if (token == ")") { parsing = false + } else if (token.trim.isEmpty){ + // ignore whitespaces between delim chars, e.g. ", [" } else { // expecting a number items.append(parseDouble(token)) @@ -114,7 +116,7 @@ private[mllib] object NumericParser { try { java.lang.Double.parseDouble(s) } catch { - case e: Throwable => + case e: NumberFormatException => throw new SparkException(s"Cannot parse a double from: $s", e) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index 7db97e6bac68..cde597939617 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.jblas.DoubleMatrix +import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -33,12 +33,16 @@ import org.apache.spark.mllib.regression.LabeledPoint * for the features and adds Gaussian noise with weight 0.1 to generate labels. */ @DeveloperApi +@Since("0.8.0") object SVMDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: SVMGenerator " + " [num_examples] [num_features] [num_partitions]") + // scalastyle:on println System.exit(1) } @@ -51,8 +55,7 @@ object SVMDataGenerator { val sc = new SparkContext(sparkMaster, "SVMGenerator") val globalRnd = new Random(94720) - val trueWeights = new DoubleMatrix(1, nfeatures + 1, - Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) + val trueWeights = Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()) val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => val rnd = new Random(42 + idx) @@ -60,7 +63,7 @@ object SVMDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextDouble() * 2.0 - 1.0 } - val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1 + val yD = blas.ddot(trueWeights.length, x, 1, trueWeights, 1) + rnd.nextGaussian() * 0.1 val y = if (yD < 0) 0.0 else 1.0 LabeledPoint(y, Vectors.dense(x)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala new file mode 100644 index 000000000000..4d71d534a077 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + */ +@DeveloperApi +@Since("1.3.0") +trait Saveable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") + def save(sc: SparkContext, path: String): Unit + + /** Current version of model save/load format. */ + protected def formatVersion: String + +} + +/** + * :: DeveloperApi :: + * + * Trait for classes which can load models and transformers from files. + * This should be inherited by an object paired with the model class. + */ +@DeveloperApi +@Since("1.3.0") +trait Loader[M <: Saveable] { + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") + def load(sc: SparkContext, path: String): M + +} + +/** + * Helper methods for loading models from files. + */ +private[mllib] object Loader { + + /** Returns URI for path/data using the Hadoop filesystem */ + def dataPath(path: String): String = new Path(path, "data").toUri.toString + + /** Returns URI for path/metadata using the Hadoop filesystem */ + def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString + + /** + * Check the schema of loaded model data. + * + * This checks every field in the expected schema to make sure that a field with the same + * name and DataType appears in the loaded schema. Note that this does NOT check metadata + * or containsNull. + * + * @param loadedSchema Schema for model data loaded from file. + * @tparam Data Expected data type from which an expected schema can be derived. + */ + def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { + // Check schema explicitly since erasure makes it hard to use match-case for checking. + val expectedFields: Array[StructField] = + ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields + val loadedFields: Map[String, DataType] = + loadedSchema.map(field => field.name -> field.dataType).toMap + expectedFields.foreach { field => + assert(loadedFields.contains(field.name), s"Unable to parse model data." + + s" Expected field with name ${field.name} was missing in loaded schema:" + + s" ${loadedFields.mkString(", ")}") + assert(loadedFields(field.name).sameType(field.dataType), + s"Unable to parse model data. Expected field $field but found field" + + s" with different type: ${loadedFields(field.name)}") + } + } + + /** + * Load metadata from the given path. + * @return (class name, version, metadata) + */ + def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { + implicit val formats = DefaultFormats + val metadata = parse(sc.textFile(metadataPath(path)).first()) + val clazz = (metadata \ "class").extract[String] + val version = (metadata \ "version").extract[String] + (clazz, version, metadata) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 56a9dbdd58b6..0a8c9e595467 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -45,7 +45,7 @@ public void setUp() { jsql = new SQLContext(jsc); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.applySchema(points, LabeledPoint.class); + dataset = jsql.createDataFrame(points, LabeledPoint.class); } @After @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java new file mode 100644 index 000000000000..38eb58673ad5 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaAttributeGroupSuite { + + @Test + public void testAttributeGroup() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr(), + NominalAttribute.defaultAttr(), + BinaryAttribute.defaultAttr().withIndex(0), + NumericAttribute.defaultAttr().withName("age").withSparsity(0.8), + NominalAttribute.defaultAttr().withName("size").withValues("small", "medium", "large"), + BinaryAttribute.defaultAttr().withName("clicked").withValues("no", "yes"), + NumericAttribute.defaultAttr(), + NumericAttribute.defaultAttr() + }; + AttributeGroup group = new AttributeGroup("user", attrs); + Assert.assertEquals(8, group.size()); + Assert.assertEquals("user", group.name()); + Assert.assertEquals(NumericAttribute.defaultAttr().withIndex(0), group.getAttr(0)); + Assert.assertEquals(3, group.indexOf("age")); + Assert.assertFalse(group.hasAttr("abc")); + Assert.assertEquals(group, AttributeGroup.fromStructField(group.toStructField())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java new file mode 100644 index 000000000000..b74bbed23143 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute; + +import org.junit.Test; +import org.junit.Assert; + +public class JavaAttributeSuite { + + @Test + public void testAttributeType() { + AttributeType numericType = AttributeType.Numeric(); + AttributeType nominalType = AttributeType.Nominal(); + AttributeType binaryType = AttributeType.Binary(); + Assert.assertEquals(numericType, NumericAttribute.defaultAttr().attrType()); + Assert.assertEquals(nominalType, NominalAttribute.defaultAttr().attrType()); + Assert.assertEquals(binaryType, BinaryAttribute.defaultAttr().attrType()); + } + + @Test + public void testNumericAttribute() { + NumericAttribute attr = NumericAttribute.defaultAttr() + .withName("age").withIndex(0).withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.4); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } + + @Test + public void testNominalAttribute() { + NominalAttribute attr = NominalAttribute.defaultAttr() + .withName("size").withIndex(1).withValues("small", "medium", "large"); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } + + @Test + public void testBinaryAttribute() { + BinaryAttribute attr = BinaryAttribute.defaultAttr() + .withName("clicked").withIndex(2).withValues("no", "yes"); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java new file mode 100644 index 000000000000..60f25e5cce43 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaDecisionTreeClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + dt.setImpurity(impurity); + } + DecisionTreeClassificationModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + DecisionTreeClassificationModel sameModel = + DecisionTreeClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java new file mode 100644 index 000000000000..3c69467fa119 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + GBTClassifier rf = new GBTClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTClassifier.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index f4ba23c44563..618b95b9bd12 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -18,17 +18,22 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.lang.Math; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.sql.Row; + public class JavaLogisticRegressionSuite implements Serializable { @@ -36,12 +41,17 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient SQLContext jsql; private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + private double eps = 1e-5; + @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); } @After @@ -51,29 +61,101 @@ public void tearDown() { } @Test - public void logisticRegression() { + public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); + assert(lr.getLabelCol().equals("label")); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); + // Check defaults + assert(model.getThreshold() == 0.5); + assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + assert(model.getProbabilityCol().equals("probability")); } @Test public void logisticRegressionWithSetters() { + // Set params, train, and check as many params as we can. LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(1.0); + .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold - .registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collectAsList(); + LogisticRegression parent = (LogisticRegression) model.parent(); + assert(parent.getMaxIter() == 10); + assert(parent.getRegParam() == 1.0); + assert(parent.getThresholds()[0] == 0.4); + assert(parent.getThresholds()[1] == 0.6); + assert(parent.getThreshold() == 0.6); + assert(model.getThreshold() == 0.6); + + // Modify model params, and check that the params worked. + model.setThreshold(1.0); + model.transform(dataset).registerTempTable("predAllZero"); + DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r: predAllZero.collectAsList()) { + assert(r.getDouble(0) == 0.0); + } + // Call transform with params, and check that the params worked. + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) + .registerTempTable("predNotAllZero"); + DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + boolean foundNonZero = false; + for (Row r: predNotAllZero.collectAsList()) { + if (r.getDouble(0) != 0.0) foundNonZero = true; + } + assert(foundNonZero); + + // Call fit() with new params, and check as many params as we can. + LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + LogisticRegression parent2 = (LogisticRegression) model2.parent(); + assert(parent2.getMaxIter() == 5); + assert(parent2.getRegParam() == 0.1); + assert(parent2.getThreshold() == 0.4); + assert(model2.getThreshold() == 0.4); + assert(model2.getProbabilityCol().equals("theProb")); + } + + @SuppressWarnings("unchecked") + @Test + public void logisticRegressionPredictorClassifierMethods() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + assert(model.numClasses() == 2); + + model.transform(dataset).registerTempTable("transformed"); + DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collect()) { + Vector raw = (Vector)row.get(0); + Vector prob = (Vector)row.get(1); + assert(raw.size() == 2); + assert(prob.size() == 2); + double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); + assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); + assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + } + + DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collect()) { + double pred = row.getDouble(0); + Vector prob = (Vector)row.get(1); + double probOfPred = prob.apply((int)pred); + for (int i = 0; i < prob.size(); ++i) { + assert(probOfPred >= prob.apply(i)); + } + } } @Test - public void logisticRegressionFitWithVarargs() { + public void logisticRegressionTrainingSummary() { LogisticRegression lr = new LogisticRegression(); - lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + LogisticRegressionModel model = lr.fit(dataset); + + LogisticRegressionTrainingSummary summary = model.summary(); + assert(summary.totalIterations() == summary.objectiveHistory().length); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java new file mode 100644 index 000000000000..ec6b4bf3c0f8 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaMultilayerPerceptronClassifierSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + sqlContext = null; + } + + @Test + public void testMLPC() { + DataFrame dataFrame = sqlContext.createDataFrame( + jsc.parallelize(Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), + LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() + .setLayers(new int[] {2, 5, 2}) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100); + MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); + DataFrame result = model.transform(dataFrame); + Row[] predictionAndLabels = result.select("prediction", "label").collect(); + for (Row r: predictionAndLabels) { + Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java new file mode 100644 index 000000000000..a700c9cddb20 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaNaiveBayesSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + public void validatePrediction(DataFrame predictionAndLabels) { + for (Row r : predictionAndLabels.collect()) { + double prediction = r.getAs(0); + double label = r.getAs(1); + assert(prediction == label); + } + } + + @Test + public void naiveBayesDefaultParams() { + NaiveBayes nb = new NaiveBayes(); + assert(nb.getLabelCol() == "label"); + assert(nb.getFeaturesCol() == "features"); + assert(nb.getPredictionCol() == "prediction"); + assert(nb.getSmoothing() == 1.0); + assert(nb.getModelType() == "multinomial"); + } + + @Test + public void testNaiveBayes() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), + RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)), + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)), + RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(jrdd, schema); + NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); + NaiveBayesModel model = nb.fit(dataset); + + DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + validatePrediction(predictionAndLabels); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java new file mode 100644 index 000000000000..2744e020e9e4 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import scala.collection.JavaConverters; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaOneVsRestSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); + jsql = new SQLContext(jsc); + int nPoints = 3; + + // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. + double[] weights = { + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; + + double[] xMean = {5.843, 3.057, 3.758, 1.199}; + double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; + List points = JavaConverters.asJavaListConverter( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + ).asJava(); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void oneVsRestDefaultParams() { + OneVsRest ova = new OneVsRest(); + ova.setClassifier(new LogisticRegression()); + Assert.assertEquals(ova.getLabelCol() , "label"); + Assert.assertEquals(ova.getPredictionCol() , "prediction"); + OneVsRestModel ovaModel = ova.fit(dataset); + DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction"); + predictions.collectAsList(); + Assert.assertEquals(ovaModel.getLabelCol(), "label"); + Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java new file mode 100644 index 000000000000..a66a1e12927b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + RandomForestClassifier rf = new RandomForestClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestClassifier.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + Vector importances = model.featureImportances(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + RandomForestClassificationModel sameModel = + RandomForestClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java new file mode 100644 index 000000000000..d09fa7fd5637 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaKMeansSuite implements Serializable { + + private transient int k = 5; + private transient JavaSparkContext sc; + private transient DataFrame dataset; + private transient SQLContext sql; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaKMeansSuite"); + sql = new SQLContext(sc); + + dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void fitAndTransform() { + KMeans kmeans = new KMeans().setK(k).setSeed(1); + KMeansModel model = kmeans.fit(dataset); + + Vector[] centers = model.clusterCenters(); + assertEquals(k, centers.length); + + DataFrame transformed = model.transform(dataset); + List columns = Arrays.asList(transformed.columns()); + List expectedColumns = Arrays.asList("features", "prediction"); + for (String column: expectedColumns) { + assertTrue(columns.contains(column)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java new file mode 100644 index 000000000000..d5bd230a957a --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaBucketizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void bucketizerTest() { + double[] splits = {-0.5, 0.0, 0.5}; + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[] { + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits); + + Row[] result = bucketizer.transform(dataset).select("result").collect(); + + for (Row r : result) { + double index = r.getDouble(0); + Assert.assertTrue((index >= 0) && (index <= 1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java new file mode 100644 index 000000000000..845eed61c45c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaDCTSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaDCTSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + double[] input = new double[] {1D, 2D, 3D, 4D}; + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.dense(input)) + )); + DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ + new StructField("vec", (new VectorUDT()), false, Metadata.empty()) + })); + + double[] expectedResult = input.clone(); + (new DoubleDCT_1D(input.length)).forward(expectedResult, true); + + DCT dct = new DCT() + .setInputCol("vec") + .setOutputCol("resultVec"); + + Row[] result = dct.transform(dataset).select("resultVec").collect(); + Vector resultVec = result[0].getAs("resultVec"); + + Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java new file mode 100644 index 000000000000..599e9cfd23ad --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +public class JavaHashingTFSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaHashingTFSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void hashingTF() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer() + .setInputCol("sentence") + .setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); + int numFeatures = 20; + HashingTF hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("rawFeatures") + .setNumFeatures(numFeatures); + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { + Vector features = r.getAs(0); + Assert.assertEquals(features.size(), numFeatures); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java new file mode 100644 index 000000000000..d82f3b7e8c07 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaNormalizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void normalizer() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures"); + + // Normalize each Vector using $L^2$ norm. + DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + l2NormData.count(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java new file mode 100644 index 000000000000..5cf43fec6f29 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaPCASuite implements Serializable { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPCASuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + public static class VectorPair implements Serializable { + private Vector features = Vectors.dense(0.0); + private Vector expected = Vectors.dense(0.0); + + public void setFeatures(Vector features) { + this.features = features; + } + + public Vector getFeatures() { + return this.features; + } + + public void setExpected(Vector expected) { + this.expected = expected; + } + + public Vector getExpected() { + return this.expected; + } + } + + @Test + public void testPCA() { + List points = Lists.newArrayList( + Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + JavaRDD dataRDD = jsc.parallelize(points, 2); + + RowMatrix mat = new RowMatrix(dataRDD.rdd()); + Matrix pc = mat.computePrincipalComponents(3); + JavaRDD expected = mat.multiply(pc).rows().toJavaRDD(); + + JavaRDD featuresExpected = dataRDD.zip(expected).map( + new Function, VectorPair>() { + public VectorPair call(Tuple2 pair) { + VectorPair featuresExpected = new VectorPair(); + featuresExpected.setFeatures(pair._1()); + featuresExpected.setExpected(pair._2()); + return featuresExpected; + } + } + ); + + DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pca_features") + .setK(3) + .fit(df); + List result = pca.transform(df).select("pca_features", "expected").toJavaRDD().collect(); + for (Row r : result) { + Assert.assertEquals(r.get(1), r.get(0)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java new file mode 100644 index 000000000000..5e8211c2c511 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaPolynomialExpansionSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void polynomialExpansionTest() { + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create( + Vectors.dense(-2.0, 2.3), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) + ), + RowFactory.create(Vectors.dense(0.0, 0.0), Vectors.dense(new double[9])), + RowFactory.create( + Vectors.dense(0.6, -1.1), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331) + ) + )); + + StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("expected", new VectorUDT(), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + Row[] pairs = polyExpansion.transform(dataset) + .select("polyFeatures", "expected") + .collect(); + + for (Row r : pairs) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + double[] expected = ((Vector)r.get(1)).toArray(); + Assert.assertArrayEquals(polyFeatures, expected, 1e-1); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java new file mode 100644 index 000000000000..74eb2733f06e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaStandardScalerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void standardScaler() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 000000000000..35b18c5308f6 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java new file mode 100644 index 000000000000..3806f650025b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaTokenizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void regexTokenizer() { + RegexTokenizer myRegExTokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(3); + + JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + )); + DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + + Row[] pairs = myRegExTokenizer.transform(dataset) + .select("tokens", "wantedTokens") + .collect(); + + for (Row r : pairs) { + Assert.assertEquals(r.get(0), r.get(1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java new file mode 100644 index 000000000000..b7c564caad3b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaVectorAssemblerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testVectorAssembler() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("x", DoubleType, false), + createStructField("y", new VectorUDT(), false), + createStructField("name", StringType, false), + createStructField("z", new VectorUDT(), false), + createStructField("n", LongType, false) + }); + Row row = RowFactory.create( + 0, 0.0, Vectors.dense(1.0, 2.0), "a", + Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"x", "y", "z", "n"}) + .setOutputCol("features"); + DataFrame output = assembler.transform(dataset); + Assert.assertEquals( + Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + output.select("features").first().getAs(0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 000000000000..c7ae5468b942 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Map> categoryMaps = model.javaCategoryMaps(); + Assert.assertEquals(categoryMaps.size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java new file mode 100644 index 000000000000..56988b9fb29c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructType; + + +public class JavaVectorSlicerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void vectorSlice() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + + DataFrame output = vectorSlicer.transform(dataset); + + for (Row r : output.select("userFeatures", "features").take(2)) { + Vector features = r.getAs(1); + Assert.assertEquals(features.size(), 2); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java new file mode 100644 index 000000000000..39c70157f83c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; + +public class JavaWord2VecSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testJavaWord2Vec() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), + RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), + RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + + Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); + Word2VecModel model = word2Vec.fit(documentDF); + DataFrame result = model.transform(documentDF); + + for (Row r: result.select("result").collect()) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + Assert.assertEquals(polyFeatures.length, 3); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java new file mode 100644 index 000000000000..9890155e9f86 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; + +/** + * Test Param and related classes in Java + */ +public class JavaParamsSuite { + + private transient JavaSparkContext jsc; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaParamsSuite"); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testParams() { + JavaTestParams testParams = new JavaTestParams(); + Assert.assertEquals(testParams.getMyIntParam(), 1); + testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); + Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); + Assert.assertEquals(testParams.getMyStringParam(), "a"); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); + } + + @Test + public void testParamValidate() { + ParamValidators.gt(1.0); + ParamValidators.gtEq(1.0); + ParamValidators.lt(1.0); + ParamValidators.ltEq(1.0); + ParamValidators.inRange(0, 1, true, false); + ParamValidators.inRange(0, 1); + ParamValidators.inArray(Lists.newArrayList(0, 1, 3)); + ParamValidators.inArray(Lists.newArrayList("a", "b")); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java new file mode 100644 index 000000000000..dc6ce8061f62 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.ml.util.Identifiable$; + +/** + * A subclass of Params for testing. + */ +public class JavaTestParams extends JavaParams { + + public JavaTestParams() { + this.uid_ = Identifiable$.MODULE$.randomUID("javaTestParams"); + init(); + } + + public JavaTestParams(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_; + + @Override + public String uid() { + return uid_; + } + + private IntParam myIntParam_; + public IntParam myIntParam() { return myIntParam_; } + + public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } + + public JavaTestParams setMyIntParam(int value) { + set(myIntParam_, value); + return this; + } + + private DoubleParam myDoubleParam_; + public DoubleParam myDoubleParam() { return myDoubleParam_; } + + public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } + + public JavaTestParams setMyDoubleParam(double value) { + set(myDoubleParam_, value); + return this; + } + + private Param myStringParam_; + public Param myStringParam() { return myStringParam_; } + + public String getMyStringParam() { return getOrDefault(myStringParam_); } + + public JavaTestParams setMyStringParam(String value) { + set(myStringParam_, value); + return this; + } + + private DoubleArrayParam myDoubleArrayParam_; + public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } + + public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + + public JavaTestParams setMyDoubleArrayParam(double[] value) { + set(myDoubleArrayParam_, value); + return this; + } + + private void init() { + myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); + myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", + ParamValidators.inRange(0.0, 1.0)); + List validStrings = Lists.newArrayList("a", "b"); + myStringParam_ = new Param(this, "myStringParam", "this is a string param", + ParamValidators.inArray(validStrings)); + myDoubleArrayParam_ = + new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); + + setDefault(myIntParam(), 1); + setDefault(myDoubleParam(), 0.5); + setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + } + + @Override + public JavaTestParams copy(ParamMap extra) { + return defaultCopy(extra); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java new file mode 100644 index 000000000000..ebe800e749e0 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaDecisionTreeRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + dt.setImpurity(impurity); + } + DecisionTreeRegressionModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java new file mode 100644 index 000000000000..fc8c13db07e6 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + GBTRegressor rf = new GBTRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTRegressor.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java new file mode 100644 index 000000000000..d591a456864e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + + +public class JavaLinearRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + jsql = new SQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void linearRegressionDefaultParams() { + LinearRegression lr = new LinearRegression(); + assert(lr.getLabelCol().equals("label")); + LinearRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + predictions.collect(); + // Check defaults + assertEquals("features", model.getFeaturesCol()); + assertEquals("prediction", model.getPredictionCol()); + } + + @Test + public void linearRegressionWithSetters() { + // Set params, train, and check as many params as we can. + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0); + LinearRegressionModel model = lr.fit(dataset); + LinearRegression parent = (LinearRegression) model.parent(); + assertEquals(10, parent.getMaxIter()); + assertEquals(1.0, parent.getRegParam(), 0.0); + + // Call fit() with new params, and check as many params as we can. + LinearRegressionModel model2 = + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + LinearRegression parent2 = (LinearRegression) model2.parent(); + assertEquals(5, parent2.getMaxIter()); + assertEquals(0.1, parent2.getRegParam(), 0.0); + assertEquals("thePred", model2.getPredictionCol()); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java new file mode 100644 index 000000000000..a00ce5e249c3 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + // This tests setters. Training with various options is tested in Scala. + RandomForestRegressor rf = new RandomForestRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestRegressor.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + Vector importances = model.featureImportances(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 074b58c07df7..08eeca53f072 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -45,7 +45,7 @@ public void setUp() { jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After @@ -68,8 +68,8 @@ public void crossValidationWithLogisticRegression() { .setEvaluator(eval) .setNumFolds(3); CrossValidatorModel cvModel = cv.fit(dataset); - ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); - Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); - Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + LogisticRegression parent = (LogisticRegression) cvModel.bestModel().parent(); + Assert.assertEquals(0.001, parent.getRegParam(), 0.0); + Assert.assertEquals(10, parent.getMaxIter()); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala new file mode 100644 index 000000000000..928301523fba --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite + +class IdentifiableSuite extends SparkFunSuite { + + import IdentifiableSuite.Test + + test("Identifiable") { + val test0 = new Test("test_0") + assert(test0.uid === "test_0") + + val test1 = new Test + assert(test1.uid.startsWith("test_")) + } +} + +object IdentifiableSuite { + + class Test(override val uid: String) extends Identifiable { + def this() = this(Identifiable.randomUID("test")) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714..3771c0ea7ad8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; public class JavaNaiveBayesSuite implements Serializable { private transient JavaSparkContext sc; @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception { // Should be able to get the first prediction. predictions.first(); } + + @Test + public void testModelTypeSetters() { + NaiveBayes nb = new NaiveBayes() + .setModelType("bernoulli") + .setModelType("multinomial"); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java new file mode 100644 index 000000000000..55787f8606d4 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLogisticRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java new file mode 100644 index 000000000000..467a7a69e8f3 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaGaussianMixtureSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGaussianMixture"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runGaussianMixture() { + List points = Lists.newArrayList( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0) + ); + + JavaRDD data = sc.parallelize(points, 2); + GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) + .run(data); + assertEquals(model.gaussians().length, 2); + JavaRDD predictions = model.predict(data); + predictions.first(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index dc10aa67c7c1..3fea359a3b46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -19,21 +19,25 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; -import org.apache.spark.api.java.JavaRDD; import scala.Tuple2; +import scala.Tuple3; import org.junit.After; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertArrayEquals; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; - +import org.apache.spark.mllib.linalg.Vectors; public class JavaLDASuite implements Serializable { private transient JavaSparkContext sc; @@ -42,9 +46,9 @@ public class JavaLDASuite implements Serializable { public void setUp() { sc = new JavaSparkContext("local", "JavaLDA"); ArrayList> tinyCorpus = new ArrayList>(); - for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), - LDASuite$.MODULE$.tinyCorpus()[i]._2())); + for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); @@ -58,7 +62,10 @@ public void tearDown() { @Test public void localLDAModel() { - LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + Matrix topics = LDASuite.tinyTopics(); + double[] topicConcentration = new double[topics.numRows()]; + Arrays.fill(topicConcentration, 1.0D / topics.numRows()); + LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); // Check: basic parameters assertEquals(model.k(), tinyK); @@ -88,7 +95,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); @@ -105,15 +112,96 @@ public void distributedLDAModel() { assertEquals(roundedLocalTopicSummary.length, k); // Check: log probabilities - assert(model.logLikelihood() < 0.0); - assert(model.logPrior() < 0.0); + assertTrue(model.logLikelihood() < 0.0); + assertTrue(model.logPrior() < 0.0); + + // Check: topic distributions + JavaPairRDD topicDistributions = model.javaTopicDistributions(); + // SPARK-5562. since the topicDistribution returns the distribution of the non empty docs + // over topics. Compare it against nonEmptyCorpus instead of corpus + JavaPairRDD nonEmptyCorpus = corpus.filter( + new Function, Boolean>() { + public Boolean call(Tuple2 tuple2) { + return Vectors.norm(tuple2._2(), 1.0) != 0.0; + } + }); + assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + Tuple3 topTopics = model.javaTopTopicsPerDocument(3).first(); + Long docId = topTopics._1(); // confirm doc ID type + int[] topicIndices = topTopics._2(); + double[] topicWeights = topTopics._3(); + assertEquals(3, topicIndices.length); + assertEquals(3, topicWeights.length); + + // Check: topTopicAssignments + Tuple3 topicAssignment = model.javaTopicAssignments().first(); + Long docId2 = topicAssignment._1(); + int[] termIndices2 = topicAssignment._2(); + int[] topicIndices2 = topicAssignment._3(); + assertEquals(termIndices2.length, topicIndices2.length); + } + + @Test + public void OnlineOptimizerCompatibility() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + OnlineLDAOptimizer op = new OnlineLDAOptimizer() + .setTau0(1024) + .setKappa(0.51) + .setGammaShape(1e40) + .setMiniBatchFraction(0.5); + + LDA lda = new LDA(); + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op); + + LDAModel model = lda.run(corpus); + + // Check: basic parameters + assertEquals(model.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = model.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + } + + @Test + public void localLdaMethods() { + JavaRDD> docs = sc.parallelize(toyData, 2); + JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. + ArrayList> docsSingleWord = new ArrayList>(); + docsSingleWord.add(new Tuple2(0L, Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); } - private static int tinyK = LDASuite$.MODULE$.tinyK(); - private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); - private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static int tinyK = LDASuite.tinyK(); + private static int tinyVocabSize = LDASuite.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite$.MODULE$.tinyTopicDescription(); - JavaPairRDD corpus; + LDASuite.tinyTopicDescription(); + private JavaPairRDD corpus; + private LocalLDAModel toyModel = LDASuite.toyModel(); + private ArrayList> toyData = LDASuite.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java new file mode 100644 index 000000000000..3b0e879eec77 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.apache.spark.streaming.JavaTestUtils.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaStreamingKMeansSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + Vectors.dense(1.0), + Vectors.dense(0.0)); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingKMeans skmeans = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0}); + skmeans.trainOn(training); + JavaPairDStream prediction = skmeans.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java new file mode 100644 index 000000000000..fa4d334801ce --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; +import scala.Tuple2$; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaRankingMetricsSuite implements Serializable { + private transient JavaSparkContext sc; + private transient JavaRDD, List>> predictionAndLabels; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRankingMetricsSuite"); + predictionAndLabels = sc.parallelize(Arrays.asList( + Tuple2$.MODULE$.apply( + Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)), + Tuple2$.MODULE$.apply( + Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), + Tuple2$.MODULE$.apply( + Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void rankingMetrics() { + @SuppressWarnings("unchecked") + RankingMetrics metrics = RankingMetrics.of(predictionAndLabels); + Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5); + Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java new file mode 100644 index 000000000000..d7c2cb3ae206 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + + +public class JavaAssociationRulesSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runAssociationRules() { + + @SuppressWarnings("unchecked") + JavaRDD> freqItemsets = sc.parallelize(Lists.newArrayList( + new FreqItemset(new String[] {"a"}, 15L), + new FreqItemset(new String[] {"b"}, 35L), + new FreqItemset(new String[] {"a", "b"}, 12L) + )); + + JavaRDD> results = (new AssociationRules()).run(freqItemsets); + } +} + diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java new file mode 100644 index 000000000000..9ce2c52dca8b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; +import static org.junit.Assert.*; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaFPGrowthSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runFPGrowth() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("r z h k p".split(" ")), + Lists.newArrayList("z y x w v u t s".split(" ")), + Lists.newArrayList("s x o n r".split(" ")), + Lists.newArrayList("x z y m t s q e".split(" ")), + Lists.newArrayList("z".split(" ")), + Lists.newArrayList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + List> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + assertEquals(18, freqItemsets.size()); + + for (FPGrowth.FreqItemset itemset: freqItemsets) { + // Test return types. + List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java new file mode 100644 index 000000000000..34daf5fbde80 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm; + +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; + +public class JavaPrefixSpanSuite { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaPrefixSpan"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runPrefixSpan() { + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + JavaRDD> freqSeqs = model.freqSequences().toJavaRDD(); + List> localFreqSeqs = freqSeqs.collect(); + Assert.assertEquals(5, localFreqSeqs.size()); + // Check that each frequent sequence could be materialized. + for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + List> seq = freqSeq.javaSequence(); + long freq = freqSeq.freq(); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 704d484d0b58..3349c5022423 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -71,8 +71,8 @@ public void diagonalMatrixConstruction() { Matrix sm = Matrices.diag(sv); DenseMatrix d = DenseMatrix.diag(v); DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.diag(v); - SparseMatrix ss = SparseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); assertArrayEquals(m.toArray(), sm.toArray(), 0.0); assertArrayEquals(d.toArray(), sm.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java new file mode 100644 index 000000000000..899c4ea60786 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLinearRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java new file mode 100644 index 000000000000..eb4e3698624b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; + +public class JavaStatisticsSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaStatistics"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testCorr() { + JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + + Double corr1 = Statistics.corr(x, y); + Double corr2 = Statistics.corr(x, y, "pearson"); + // Check default method + assertEquals(corr1, corr2); + } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD data = sc.parallelize(Lists.newArrayList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } +} diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2f175fb11794..1f2c9b75b617 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.ml +import scala.collection.JavaConverters._ + import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.sql.DataFrame -class PipelineSuite extends FunSuite { +class PipelineSuite extends SparkFunSuite { abstract class MyModel extends Model[MyModel] @@ -42,30 +46,34 @@ class PipelineSuite extends FunSuite { val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] - when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) - when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) + when(model0.copy(any[ParamMap])).thenReturn(model0) + when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) + when(estimator2.copy(any[ParamMap])).thenReturn(estimator2) + when(model2.copy(any[ParamMap])).thenReturn(model2) + when(transformer3.copy(any[ParamMap])).thenReturn(transformer3) + + when(estimator0.fit(meq(dataset0))).thenReturn(model0) + when(model0.transform(meq(dataset0))).thenReturn(dataset1) when(model0.parent).thenReturn(estimator0) - when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) - when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) - when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(transformer1.transform(meq(dataset1))).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2))).thenReturn(model2) + when(model2.transform(meq(dataset2))).thenReturn(dataset3) when(model2.parent).thenReturn(estimator2) - when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + when(transformer3.transform(meq(dataset3))).thenReturn(dataset4) val pipeline = new Pipeline() .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) - assert(pipelineModel.stages.size === 4) + MLTestingUtils.checkCopy(pipelineModel) + + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) assert(pipelineModel.stages(2).eq(model2)) assert(pipelineModel.stages(3).eq(transformer3)) - assert(pipelineModel.getModel(estimator0).eq(model0)) - assert(pipelineModel.getModel(estimator2).eq(model2)) - intercept[NoSuchElementException] { - pipelineModel.getModel(mock[Estimator[MyModel]]) - } val output = pipelineModel.transform(dataset0) assert(output.eq(dataset4)) } @@ -79,4 +87,28 @@ class PipelineSuite extends FunSuite { pipeline.fit(dataset) } } + + test("PipelineModel.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) + require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + + test("pipeline model constructors") { + val transform0 = mock[Transformer] + val model1 = mock[MyModel] + + val stages = Array(transform0, model1) + val pipelineModel0 = new PipelineModel("pipeline0", stages) + assert(pipelineModel0.uid === "pipeline0") + assert(pipelineModel0.stages === stages) + + val stagesAsList = stages.toList.asJava + val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList) + assert(pipelineModel1.uid === "pipeline1") + assert(pipelineModel1.stages === stages) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala new file mode 100644 index 000000000000..1292e57d7c01 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + + +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { + + // TODO: test for weights comparison with Weka MLP + test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array(0.0, 1.0, 1.0, 0.0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 1) + trainer.setWeights(initialWeights) + trainer.LBFGSOptimizer.setNumIterations(20) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input)(0), label(0)) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(math.round(p) === l) + } + } + + test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array( + Array(1.0, 0.0), + Array(0.0, 1.0), + Array(0.0, 1.0), + Array(1.0, 0.0) + ) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 2) + trainer.SGDOptimizer.setNumIterations(2000) + trainer.setWeights(initialWeights) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input), label) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(p ~== l absTol 0.5) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala new file mode 100644 index 000000000000..512cffb1acb6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import org.apache.spark.SparkFunSuite + +class AttributeGroupSuite extends SparkFunSuite { + + test("attribute group") { + val attrs = Array( + NumericAttribute.defaultAttr, + NominalAttribute.defaultAttr, + BinaryAttribute.defaultAttr.withIndex(0), + NumericAttribute.defaultAttr.withName("age").withSparsity(0.8), + NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large"), + BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"), + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr) + val group = new AttributeGroup("user", attrs) + assert(group.size === 8) + assert(group.name === "user") + assert(group(0) === NumericAttribute.defaultAttr.withIndex(0)) + assert(group(2) === BinaryAttribute.defaultAttr.withIndex(2)) + assert(group.indexOf("age") === 3) + assert(group.indexOf("size") === 4) + assert(group.indexOf("clicked") === 5) + assert(!group.hasAttr("abc")) + intercept[NoSuchElementException] { + group("abc") + } + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField())) + } + + test("attribute group without attributes") { + val group0 = new AttributeGroup("user", 10) + assert(group0.name === "user") + assert(group0.numAttributes === Some(10)) + assert(group0.size === 10) + assert(group0.attributes.isEmpty) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + + val group1 = new AttributeGroup("item") + assert(group1.name === "item") + assert(group1.numAttributes.isEmpty) + assert(group1.attributes.isEmpty) + assert(group1.size === -1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala new file mode 100644 index 000000000000..6355e0f17949 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class AttributeSuite extends SparkFunSuite { + + test("default numeric attribute") { + val attr: NumericAttribute = NumericAttribute.defaultAttr + val metadata = Metadata.fromJson("{}") + val metadataWithType = Metadata.fromJson("""{"type":"numeric"}""") + assert(attr.attrType === AttributeType.Numeric) + assert(attr.isNumeric) + assert(!attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.min.isEmpty) + assert(attr.max.isEmpty) + assert(attr.std.isEmpty) + assert(attr.sparsity.isEmpty) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === Attribute.fromMetadata(metadataWithType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized numeric attribute") { + val name = "age" + val index = 0 + val metadata = Metadata.fromJson("""{"name":"age","idx":0}""") + val metadataWithType = Metadata.fromJson("""{"type":"numeric","name":"age","idx":0}""") + val attr: NumericAttribute = NumericAttribute.defaultAttr + .withName(name) + .withIndex(index) + assert(attr.attrType == AttributeType.Numeric) + assert(attr.isNumeric) + assert(!attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === Attribute.fromMetadata(metadataWithType)) + val field = attr.toStructField() + assert(field.dataType === DoubleType) + assert(!field.nullable) + assert(attr.withoutIndex === Attribute.fromStructField(field)) + val existingMetadata = new MetadataBuilder() + .putString("name", "test") + .build() + assert(attr.toStructField(existingMetadata).metadata.getString("name") === "test") + + val attr2 = + attr.withoutName.withoutIndex.withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.3) + assert(attr2.name.isEmpty) + assert(attr2.index.isEmpty) + assert(attr2.min === Some(0.0)) + assert(attr2.max === Some(1.0)) + assert(attr2.std === Some(0.5)) + assert(attr2.sparsity === Some(0.3)) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + } + + test("bad numeric attributes") { + val attr = NumericAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + intercept[IllegalArgumentException](attr.withStd(-0.1)) + intercept[IllegalArgumentException](attr.withSparsity(-0.5)) + intercept[IllegalArgumentException](attr.withSparsity(1.5)) + } + + test("default nominal attribute") { + val attr: NominalAttribute = NominalAttribute.defaultAttr + val metadata = Metadata.fromJson("""{"type":"nominal"}""") + val metadataWithoutType = Metadata.fromJson("{}") + assert(attr.attrType === AttributeType.Nominal) + assert(!attr.isNumeric) + assert(attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.values.isEmpty) + assert(attr.numValues.isEmpty) + assert(attr.isOrdinal.isEmpty) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized nominal attribute") { + val name = "size" + val index = 1 + val values = Array("small", "medium", "large") + val metadata = Metadata.fromJson( + """{"type":"nominal","name":"size","idx":1,"vals":["small","medium","large"]}""") + val metadataWithoutType = Metadata.fromJson( + """{"name":"size","idx":1,"vals":["small","medium","large"]}""") + val attr: NominalAttribute = NominalAttribute.defaultAttr + .withName(name) + .withIndex(index) + .withValues(values) + assert(attr.attrType === AttributeType.Nominal) + assert(!attr.isNumeric) + assert(attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.values === Some(values)) + assert(attr.indexOf("medium") === 1) + assert(attr.getValue(1) === "medium") + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) + assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) + + val attr2 = attr.withoutName.withoutIndex.withValues(attr.values.get :+ "x-large") + assert(attr2.name.isEmpty) + assert(attr2.index.isEmpty) + assert(attr2.values.get === Array("small", "medium", "large", "x-large")) + assert(attr2.indexOf("x-large") === 3) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false))) + } + + test("bad nominal attributes") { + val attr = NominalAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + intercept[IllegalArgumentException](attr.withNumValues(-1)) + } + + test("default binary attribute") { + val attr = BinaryAttribute.defaultAttr + val metadata = Metadata.fromJson("""{"type":"binary"}""") + val metadataWithoutType = Metadata.fromJson("{}") + assert(attr.attrType === AttributeType.Binary) + assert(attr.isNumeric) + assert(attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.values.isEmpty) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized binary attribute") { + val name = "clicked" + val index = 2 + val values = Array("no", "yes") + val metadata = Metadata.fromJson( + """{"type":"binary","name":"clicked","idx":2,"vals":["no","yes"]}""") + val metadataWithoutType = Metadata.fromJson( + """{"name":"clicked","idx":2,"vals":["no","yes"]}""") + val attr = BinaryAttribute.defaultAttr + .withName(name) + .withIndex(index) + .withValues(values(0), values(1)) + assert(attr.attrType === AttributeType.Binary) + assert(attr.isNumeric) + assert(attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.values.get === values) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) + assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) + } + + test("bad binary attributes") { + val attr = BinaryAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + } + + test("attribute from struct field") { + val metadata = NumericAttribute.defaultAttr.withName("label").toMetadata() + val fldWithoutMeta = new StructField("x", DoubleType, false, Metadata.empty) + assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute) + val fldWithMeta = new StructField("x", DoubleType, false, metadata) + assert(Attribute.fromStructField(fldWithMeta).isNumeric) + // Attribute.fromStructField should accept any NumericType, not just DoubleType + val longFldWithMeta = new StructField("x", LongType, false, metadata) + assert(Attribute.fromStructField(longFldWithMeta).isNumeric) + val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) + assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala new file mode 100644 index 000000000000..f680d8d3c4cc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row + +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + import DecisionTreeClassifierSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + orderedLabeledPointsWithLabel0RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()) + orderedLabeledPointsWithLabel1RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()) + categoricalDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()) + continuousDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()) + categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + } + + test("params") { + ParamsSuite.checkParams(new DecisionTreeClassifier) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) + ParamsSuite.checkParams(model) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification stump with ordered categorical features") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { + val dt = new DecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val numClasses = 2 + Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd => + DecisionTreeClassifier.supportedImpurities.foreach { impurity => + dt.setImpurity(impurity) + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + } + } + + test("Multiclass classification stump with 3-ary (unordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 3 + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Binary classification stump with 2 continuous features") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with continuous features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with continuous + unordered categorical features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with 10-ary (ordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min instances per node requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min info gain requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(newTree) + + val predictions = newTree.transform(newData) + .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + } + + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new DecisionTreeClassifier().setMaxDepth(3) + val model = dt.fit(df) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) + val newModel = DecisionTreeClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = DecisionTreeClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + // Use parent from newTree since this is not checked anyways. + val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( + oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala new file mode 100644 index 000000000000..e3909bccaa5c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils + + +/** + * Test suite for [[GBTClassifier]]. + */ +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + import GBTClassifierSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("params") { + ParamsSuite.checkParams(new GBTClassifier) + val model = new GBTClassificationModel("gbtc", + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), + Array(1.0)) + ParamsSuite.checkParams(model) + } + + test("Binary classification with continuous features: Log Loss") { + val categoricalFeatures = Map.empty[Int, Int] + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType("logistic") + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTClassifier.supportedLossTypes.foreach { loss => + val gbt = new GBTClassifier() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) + val newModel = GBTClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTClassifier, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = + gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val newModel = gbt.fit(newData) + // Use parent from newTree since this is not checked anyways. + val oldModelAsNew = GBTClassificationModel.fromOld( + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 33e40dc7410c..cce39f382f73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,47 +17,768 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + @transient var binaryDataset: DataFrame = _ + private val eps: Double = 1e-5 override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) - dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + + dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) + + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + import org.apache.spark.mllib.classification.LogisticRegressionSuite + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + */ + binaryDataset = { + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + + sqlContext.createDataFrame( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) + } + } + + test("params") { + ParamsSuite.checkParams(new LogisticRegression) + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) } - test("logistic regression") { + test("logistic regression: default params") { val lr = new LogisticRegression + assert(lr.getLabelCol === "label") + assert(lr.getFeaturesCol === "features") + assert(lr.getPredictionCol === "prediction") + assert(lr.getRawPredictionCol === "rawPrediction") + assert(lr.getProbabilityCol === "probability") + assert(lr.getFitIntercept) + assert(lr.getStandardization) val model = lr.fit(dataset) model.transform(dataset) - .select("label", "prediction") + .select("label", "probability", "prediction", "rawPrediction") .collect() + assert(model.getThreshold === 0.5) + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.getProbabilityCol === "probability") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + test("setThreshold, getThreshold") { + val lr = new LogisticRegression + // default + assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5") + withClue("LogisticRegression should not have thresholds set by default.") { + intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future + lr.getThresholds + } + } + // Set via threshold. + // Intuition: Large threshold or large thresholds(1) makes class 0 more likely. + lr.setThreshold(1.0) + assert(lr.getThresholds === Array(0.0, 1.0)) + lr.setThreshold(0.0) + assert(lr.getThresholds === Array(1.0, 0.0)) + lr.setThreshold(0.5) + assert(lr.getThresholds === Array(0.5, 0.5)) + // Set via thresholds + val lr2 = new LogisticRegression + lr2.setThresholds(Array(0.3, 0.7)) + val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) + assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7) + // thresholds and threshold must be consistent + lr2.setThresholds(Array(0.1, 0.2, 0.3)) + withClue("getThreshold should throw error if thresholds has length != 2.") { + intercept[IllegalArgumentException] { + lr2.getThreshold + } + } + // thresholds and threshold must be consistent: values + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { + val lr2model = lr2.fit(dataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + lr2model.getThreshold + } + } + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { + // Set params, train, and check as many params as we can. val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability") val model = lr.fit(dataset) - model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select("label", "score", "prediction") + val parent = model.parent.asInstanceOf[LogisticRegression] + assert(parent.getMaxIter === 10) + assert(parent.getRegParam === 1.0) + assert(parent.getThreshold === 0.6) + assert(model.getThreshold === 0.6) + + // Modify model params, and check that the params worked. + model.setThreshold(1.0) + val predAllZero = model.transform(dataset) + .select("prediction", "myProbability") .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") + // Call transform with params, and check that the params worked. + val predNotAllZero = + model.transform(dataset, model.threshold -> 0.0, + model.probabilityCol -> "myProb") + .select("prediction", "myProb") + .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predNotAllZero.exists(_ !== 0.0)) + + // Call fit() with new params, and check as many params as we can. + lr.setThresholds(Array(0.6, 0.4)) + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, + lr.probabilityCol -> "theProb") + val parent2 = model2.parent.asInstanceOf[LogisticRegression] + assert(parent2.getMaxIter === 5) + assert(parent2.getRegParam === 0.1) + assert(parent2.getThreshold === 0.4) + assert(model2.getThreshold === 0.4) + assert(model2.getProbabilityCol === "theProb") } - test("logistic regression fit and transform with varargs") { + test("logistic regression: Predictor, Classifier methods") { + val sqlContext = this.sqlContext val lr = new LogisticRegression - val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) - model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select("label", "probability", "prediction") - .collect() + + val model = lr.fit(dataset) + assert(model.numClasses === 2) + + val threshold = model.getThreshold + val results = model.transform(dataset) + + // Compare rawPrediction with probability + results.select("rawPrediction", "probability").collect().foreach { + case Row(raw: Vector, prob: Vector) => + assert(raw.size === 2) + assert(prob.size === 2) + val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) + assert(prob(1) ~== probFromRaw1 relTol eps) + assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) + } + + // Compare prediction with probability + results.select("prediction", "probability").collect().foreach { + case Row(pred: Double, prob: Vector) => + val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 + assert(pred == predFromProb) + } + } + + test("MultiClassSummarizer") { + val summarizer1 = (new MultiClassSummarizer) + .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0) + assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2)) + assert(summarizer1.countInvalid === 0) + assert(summarizer1.numClasses === 7) + + val summarizer2 = (new MultiClassSummarizer) + .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0) + assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer2.countInvalid === 0) + assert(summarizer2.numClasses === 6) + + val summarizer3 = (new MultiClassSummarizer) + .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0) + assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2)) + assert(summarizer3.countInvalid === 3) + assert(summarizer3.numClasses === 5) + + val summarizer4 = (new MultiClassSummarizer) + .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0) + assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer4.countInvalid === 2) + assert(summarizer4.numClasses === 4) + + // small map merges large one + val summarizerA = summarizer1.merge(summarizer2) + assert(summarizerA.hashCode() === summarizer2.hashCode()) + assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizerA.countInvalid === 0) + assert(summarizerA.numClasses === 7) + + // large map merges small one + val summarizerB = summarizer3.merge(summarizer4) + assert(summarizerB.hashCode() === summarizer3.hashCode()) + assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2)) + assert(summarizerB.countInvalid === 5) + assert(summarizerB.numClasses === 5) + } + + test("binary logistic regression with intercept without regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 + */ + val interceptR = 2.8366423 + val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) + } + + test("binary logistic regression without intercept without regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 + */ + val interceptR = 0.0 + val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-2) + + // Without regularization, with or without standardization should converge to the same solution. + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-2) + } + + test("binary logistic regression with intercept with L1 regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 + */ + val interceptR1 = -0.05627428 + val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.weights ~= weightsR1 absTol 2E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + standardize=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.3722152 + data.V2 . + data.V3 . + data.V4 -0.1665453 + data.V5 . + */ + val interceptR2 = 0.3722152 + val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.weights ~= weightsR2 absTol 1E-3) + } + + test("binary logistic regression without intercept with L1 regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(false) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(false) + .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 absTol 1E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE, standardize=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.08420782 + data.V5 . + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) + } + + test("binary logistic regression with intercept with L2 regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 + */ + val interceptR1 = 0.15021751 + val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + standardize=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.48657516 + data.V2 -0.05155371 + data.V3 0.02301057 + data.V4 -0.11482896 + data.V5 -0.06266838 + */ + val interceptR2 = 0.48657516 + val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + } + + test("binary logistic regression without intercept with L2 regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(false) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(false) + .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE, standardize=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.005679651 + data.V3 0.048967094 + data.V4 -0.093714016 + data.V5 -0.053314311 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-2) + } + + test("binary logistic regression with intercept with ElasticNet regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.57734851 + data.V2 -0.05310287 + data.V3 . + data.V4 -0.08849250 + data.V5 -0.15458796 + */ + val interceptR1 = 0.57734851 + val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) + + assert(model1.intercept ~== interceptR1 relTol 6E-3) + assert(model1.weights ~== weightsR1 absTol 5E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + standardize=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.51555993 + data.V2 . + data.V3 . + data.V4 -0.18807395 + data.V5 -0.05350074 + */ + val interceptR2 = 0.51555993 + val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) + + assert(model2.intercept ~== interceptR2 relTol 6E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) + } + + test("binary logistic regression without intercept with ElasticNet regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.001005743 + data.V3 0.072577857 + data.V4 -0.081203769 + data.V5 -0.142534158 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 absTol 1E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE, standardize=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 0.03345223 + data.V4 -0.11304532 + data.V5 . + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) + } + + test("binary logistic regression with intercept with strong L1 regularization") { + val trainer1 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true) + .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false) + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label } + .treeAggregate(new MultiClassSummarizer)( + seqOp = (c, v) => (c, v) match { + case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) + }, + combOp = (c1, c2) => (c1, c2) match { + case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => + classSummarizer1.merge(classSummarizer2) + }).histogram + + /* + For binary logistic regression with strong L1 regularization, all the weights will be zeros. + As a result, + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} + */ + val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) + val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) + + assert(model1.intercept ~== interceptTheory relTol 1E-5) + assert(model1.weights ~= weightsTheory absTol 1E-6) + + assert(model2.intercept ~== interceptTheory relTol 1E-5) + assert(model2.weights ~= weightsTheory absTol 1E-6) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.2480643 + data.V2 0.0000000 + data.V3 . + data.V4 . + data.V5 . + */ + val interceptR = -0.248065 + val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) + + assert(model1.intercept ~== interceptR relTol 1E-5) + assert(model1.weights ~== weightsR absTol 1E-6) + } + + test("evaluate on test set") { + // Evaluate on test set should be same as that of the transformed training data. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary] + assert(summary.areaUnderROC === sameSummary.areaUnderROC) + assert(summary.roc.collect() === sameSummary.roc.collect()) + assert(summary.pr.collect === sameSummary.pr.collect()) + assert( + summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect()) + assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect()) + assert( + summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) + } + + test("statistics on training data") { + // Test that loss is monotonically decreasing. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala new file mode 100644 index 000000000000..ddc948f65df4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning as binary classification problem with two outputs.") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100) + val model = trainer.fit(dataFrame) + val result = model.transform(dataFrame) + val predictionAndLabels = result.select("prediction", "label").collect() + predictionAndLabels.foreach { case Row(p: Double, l: Double) => + assert(p == l) + } + } + + // TODO: implement a more rigorous test + test("3 class classification with 2 hidden layers") { + val nPoints = 1000 + + // The following weights are taken from OneVsRestSuite.scala + // they represent 3-class iris dataset + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val numClasses = 3 + val numIterations = 100 + val layers = Array[Int](4, 5, 4, numClasses) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(numIterations) + val model = trainer.fit(dataFrame) + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } + // train multinomial logistic regression + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(true) + .setNumClasses(numClasses) + lr.optimizer.setRegParam(0.0) + .setNumIterations(numIterations) + val lrModel = lr.run(rdd) + val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // MLP's predictions should not differ a lot from LR's. + val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) + val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) + assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala new file mode 100644 index 000000000000..98bc9511163e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import breeze.linalg.{Vector => BV} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.classification.NaiveBayesSuite._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + + def validatePrediction(predictionAndLabels: DataFrame): Unit = { + val numOfErrorPredictions = predictionAndLabels.collect().count { + case Row(prediction: Double, label: Double) => + prediction != label + } + // At least 80% of the predictions should be on. + assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + } + + def validateModelFit( + piData: Vector, + thetaData: Matrix, + model: NaiveBayesModel): Unit = { + assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~== + Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch") + assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") + } + + def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) + val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) + val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def validateProbabilities( + featureAndProbabilities: DataFrame, + model: NaiveBayesModel, + modelType: String): Unit = { + featureAndProbabilities.collect().foreach { + case Row(features: Vector, probability: Vector) => { + assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) + val expected = modelType match { + case Multinomial => + expectedMultinomialProbabilities(model, features) + case Bernoulli => + expectedBernoulliProbabilities(model, features) + case _ => + throw new UnknownError(s"Invalid modelType: $modelType.") + } + assert(probability ~== expected relTol 1.0e-10) + } + } + } + + test("params") { + ParamsSuite.checkParams(new NaiveBayes) + val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), + theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4))) + ParamsSuite.checkParams(model) + } + + test("naive bayes: default params") { + val nb = new NaiveBayes + assert(nb.getLabelCol === "label") + assert(nb.getFeaturesCol === "features") + assert(nb.getPredictionCol === "prediction") + assert(nb.getSmoothing === 1.0) + assert(nb.getModelType === "multinomial") + } + + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 42, "multinomial")) + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") + val model = nb.fit(testDataset) + + validateModelFit(pi, theta, model) + assert(model.hasParent) + + val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 17, "multinomial")) + + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "multinomial") + } + + test("Naive Bayes Bernoulli") { + val nPoints = 10000 + val piArray = Array(0.5, 0.3, 0.2).map(math.log) + val thetaArray = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) + + val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 45, "bernoulli")) + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") + val model = nb.fit(testDataset) + + validateModelFit(pi, theta, model) + assert(model.hasParent) + + val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + piArray, thetaArray, nPoints, 20, "bernoulli")) + + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "bernoulli") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala new file mode 100644 index 000000000000..977f0e0b70c1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.Metadata + +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var dataset: DataFrame = _ + @transient var rdd: RDD[LabeledPoint] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val nPoints = 1000 + + // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + dataset = sqlContext.createDataFrame(rdd) + } + + test("params") { + ParamsSuite.checkParams(new OneVsRest) + val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0) + val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel)) + ParamsSuite.checkParams(model) + } + + test("one-vs-rest: default params") { + val numClasses = 3 + val ova = new OneVsRest() + .setClassifier(new LogisticRegression) + assert(ova.getLabelCol === "label") + assert(ova.getPredictionCol === "prediction") + val ovaModel = ova.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(ovaModel) + + assert(ovaModel.models.size === numClasses) + val transformedDataset = ovaModel.transform(dataset) + + // check for label metadata in prediction col + val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) + assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) + + val ovaResults = transformedDataset + .select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) + lr.optimizer.setRegParam(0.1).setNumIterations(100) + + val model = lr.run(rdd) + val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // determine the #confusion matrix in each class. + // bound how much error we allow compared to multinomial logistic regression. + val expectedMetrics = new MulticlassMetrics(results) + val ovaMetrics = new MulticlassMetrics(ovaResults) + assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400) + } + + test("one-vs-rest: pass label metadata correctly during train") { + val numClasses = 3 + val ova = new OneVsRest() + ova.setClassifier(new MockLogisticRegression) + + val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata()) + val features = dataset("features").as("features") + val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) + ova.fit(datasetWithLabelMetadata) + } + + test("SPARK-8092: ensure label features and prediction cols are configurable") { + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexed") + + val indexedDataset = labelIndexer + .fit(dataset) + .transform(dataset) + .drop("label") + .withColumnRenamed("features", "f") + + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression()) + .setLabelCol(labelIndexer.getOutputCol) + .setFeaturesCol("f") + .setPredictionCol("p") + + val ovaModel = ova.fit(indexedDataset) + val transformedDataset = ovaModel.transform(indexedDataset) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields.contains("p")) + } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } + + test("OneVsRest.copy and OneVsRestModel.copy") { + val lr = new LogisticRegression() + .setMaxIter(1) + + val ovr = new OneVsRest() + withClue("copy with classifier unset should work") { + ovr.copy(ParamMap(lr.maxIter -> 10)) + } + ovr.setClassifier(lr) + val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10)) + require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects") + require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, + "copy should handle extra classifier params") + + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1))) + ovrModel.models.foreach { case m: LogisticRegressionModel => + require(m.getThreshold === 0.1, "copy should handle extra model params") + } + } +} + +private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { + + def this() = this("mockLogReg") + + setMaxIter(1) + + override protected def train(dataset: DataFrame): LogisticRegressionModel = { + val labelSchema = dataset.schema($(labelCol)) + // check for label attribute propagation. + assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) + super.train(dataset) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala new file mode 100644 index 000000000000..8f50cb924e64 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +final class TestProbabilisticClassificationModel( + override val uid: String, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] { + + override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra) + + override protected def predictRaw(input: Vector): Vector = { + input + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction + } + + def friendlyPredict(input: Vector): Double = { + predict(input) + } +} + + +class ProbabilisticClassifierSuite extends SparkFunSuite { + + test("test thresholding") { + val thresholds = Array(0.5, 0.2) + val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + } + + test("test thresholding not required") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2) + assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala new file mode 100644 index 000000000000..b4403ec30049 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Test suite for [[RandomForestClassifier]]. + */ +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestClassifierSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + orderedLabeledPoints5_20 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) { + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + val newRF = rf + .setImpurity("Gini") + .setMaxDepth(2) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) + } + + test("params") { + ParamsSuite.checkParams(new RandomForestClassifier) + val model = new RandomForestClassificationModel("rfc", + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2) + ParamsSuite.checkParams(model) + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("Binary classification with continuous features and node Id cache:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + .setCacheNodeIds(true) + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + val rdd = sc.parallelize(arr) + val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) + val numClasses = 3 + + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(5) + .setNumTrees(2) + .setFeatureSubsetStrategy("sqrt") + .setSeed(12345) + compareAPIs(rdd, rf, categoricalFeatures, numClasses) + } + + test("subsampling rate in RandomForest"){ + val rdd = orderedLabeledPoints5_20 + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val rf1 = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setCacheNodeIds(true) + .setNumTrees(3) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(rdd, rf1, categoricalFeatures, numClasses) + + val rf2 = rf1.setSubsamplingRate(0.5) + compareAPIs(rdd, rf2, categoricalFeatures, numClasses) + } + + test("predictRaw and predictProbability") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + val predictions = model.transform(df) + .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = + Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) + val newModel = RandomForestClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newModel = rf.fit(newData) + // Use parent from newTree since this is not checked anyways. + val oldModelAsNew = RandomForestClassificationModel.fromOld( + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, + numClasses) + TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.hasParent) + assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) + assert(newModel.numClasses == numClasses) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala new file mode 100644 index 000000000000..688b0e31f91d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +private[clustering] case class TestRow(features: Vector) + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } +} + +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + final val k = 5 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val kmeans = new KMeans() + + assert(kmeans.getK === 2) + assert(kmeans.getFeaturesCol === "features") + assert(kmeans.getPredictionCol === "prediction") + assert(kmeans.getMaxIter === 20) + assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) + assert(kmeans.getInitSteps === 5) + assert(kmeans.getTol === 1e-4) + } + + test("set parameters") { + val kmeans = new KMeans() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setInitMode(MLlibKMeans.RANDOM) + .setInitSteps(3) + .setSeed(123) + .setTol(1e-3) + + assert(kmeans.getK === 9) + assert(kmeans.getFeaturesCol === "test_feature") + assert(kmeans.getPredictionCol === "test_prediction") + assert(kmeans.getMaxIter === 33) + assert(kmeans.getInitMode === MLlibKMeans.RANDOM) + assert(kmeans.getInitSteps === 3) + assert(kmeans.getSeed === 123) + assert(kmeans.getTol === 1e-3) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new KMeans().setK(1) + } + intercept[IllegalArgumentException] { + new KMeans().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new KMeans().setInitSteps(0) + } + } + + test("fit & transform") { + val predictionColName = "kmeans_prediction" + val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = kmeans.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala new file mode 100644 index 000000000000..def869fe6677 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class BinaryClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new BinaryClassificationEvaluator) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala new file mode 100644 index 000000000000..6d8412b0b370 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new MulticlassClassificationEvaluator) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala new file mode 100644 index 000000000000..aa722da32393 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new RegressionEvaluator) + } + + test("Regression Evaluator: default params") { + /** + * Here is the instruction describing how to export the test data into CSV format + * so we can validate the metrics compared with R's mmetric package. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, + * Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1)) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) + * .saveAsTextFile("path") + */ + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + /** + * Using the following R code to load the data, train the model and evaluate metrics. + * + * > library("glmnet") + * > library("rminer") + * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * > label <- as.numeric(data$V1) + * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0) + * > rmse <- mmetric(label, predict(model, features), metric='RMSE') + * > mae <- mmetric(label, predict(model, features), metric='MAE') + * > r2 <- mmetric(label, predict(model, features), metric='R2') + */ + val trainer = new LinearRegression + val model = trainer.fit(dataset) + val predictions = model.transform(dataset) + + // default = rmse + val evaluator = new RegressionEvaluator() + assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + + // r2 score + evaluator.setMetricName("r2") + assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + + // mae + evaluator.setMetricName("mae") + assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala new file mode 100644 index 000000000000..208604398366 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var data: Array[Double] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) + } + + test("params") { + ParamsSuite.checkParams(new Binarizer) + } + + test("Binarize continuous features with default parameter") { + val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame( + data.zip(defaultBinarized)).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The feature value is not correct after binarization.") + } + } + + test("Binarize continuous features with setter") { + val threshold: Double = 0.2 + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame( + data.zip(thresholdBinarized)).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(threshold) + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The feature value is not correct after binarization.") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala new file mode 100644 index 000000000000..0eba34fda622 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new Bucketizer) + } + + test("Bucket continuous features, without -inf,inf") { + // Check a set of valid feature values. + val splits = Array(-0.5, 0.0, 0.5) + val validData = Array(-0.5, -0.3, 0.0, 0.2) + val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0) + val dataFrame: DataFrame = + sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + + bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + + // Check for exceptions when using a set of invalid feature values. + val invalidData1: Array[Double] = Array(-0.9) ++ validData + val invalidData2 = Array(0.51) ++ validData + val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") + withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer.transform(badDF1).collect() + } + } + val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") + withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer.transform(badDF2).collect() + } + } + } + + test("Bucket continuous features, with -inf,inf") { + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) + val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) + val dataFrame: DataFrame = + sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + + bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } + + test("Binary search correctness on hand-picked examples") { + import BucketizerSuite.checkBinarySearch + // length 3, with -inf + checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0)) + // length 4 + checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0)) + // length 5 + checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5)) + // length 3, with inf + checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity)) + // length 3, with -inf and inf + checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity)) + // length 4, with -inf and inf + checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)) + } + + test("Binary search correctness in contrast with linear search, on random data") { + val data = Array.fill(100)(Random.nextDouble()) + val splits: Array[Double] = Double.NegativeInfinity +: + Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity + val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x))) + val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) + assert(bsResult ~== lsResult absTol 1e-5) + } +} + +private object BucketizerSuite extends SparkFunSuite { + /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */ + def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { + require(feature >= splits.head) + var i = 0 + val n = splits.length - 1 + while (i < n) { + if (feature < splits(i + 1)) return i + i += 1 + } + throw new RuntimeException( + s"linearSearchForBuckets failed to find bucket for feature value $feature") + } + + /** Check all values in splits, plus values between all splits. */ + def checkBinarySearch(splits: Array[Double]): Unit = { + def testFeature(feature: Double, expectedBucket: Double): Unit = { + assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket, + s"Expected feature value $feature to be in bucket $expectedBucket with splits:" + + s" ${splits.mkString(", ")}") + } + var i = 0 + val n = splits.length - 1 + while (i < n) { + // Split i should fall in bucket i. + testFeature(splits(i), i) + // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf. + testFeature((splits(i) + splits(i + 1)) / 2, i) + i += 1 + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala new file mode 100644 index 000000000000..e192fa4850af --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) + } + + private def split(s: String): Seq[String] = s.split("\\s+") + + test("CountVectorizerModel common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d"), + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, split("a b b c d a"), + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split(""), Vectors.sparse(4, Seq())), // empty string + (4, split("a notInDict d"), + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d e"), + Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), + (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), + (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), + (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .fit(df) + assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer vocabSize and minDF") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq((0, 1.0))))) + ).toDF("id", "words", "expected") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .fit(df) + assert(cvModel.vocabulary === Array("a", "b", "c")) + + // minDF: ignore terms with count less than 3 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3) + .fit(df) + assert(cvModel2.vocabulary === Array("a", "b")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // minDF: ignore terms with freq < 0.75 + val cvModel3 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3.0 / df.count()) + .fit(df) + assert(cvModel3.vocabulary === Array("a", "b")) + + cvModel3.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer throws exception when vocab is empty") { + intercept[IllegalArgumentException] { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a b b c c")), + (1, split("aa bb cc"))) + ).toDF("id", "words") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .setMinDF(3) + .fit(df) + } + } + + test("CountVectorizerModel with minTF count") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), + (2, split("a"), Vectors.sparse(4, Seq())), + (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + + // minTF: count + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTF(3) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizerModel with minTF freq") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + + // minTF: count + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTF(0.3) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala new file mode 100644 index 000000000000..37ed2367c33f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class DCTTestData(vec: Vector, wantedVec: Vector) + +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("forward transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = false + + testDCT(data, inverse) + } + + test("inverse transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = true + + testDCT(data, inverse) + } + + private def testDCT(data: Vector, inverse: Boolean): Unit = { + val expectedResultBuffer = data.toArray.clone() + if (inverse) { + (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true) + } else { + (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true) + } + val expectedResult = Vectors.dense(expectedResultBuffer) + + val dataset = sqlContext.createDataFrame(Seq( + DCTTestData(data, expectedResult) + )) + + val transformer = new DCT() + .setInputCol("vec") + .setOutputCol("resultVec") + .setInverse(inverse) + + transformer.transform(dataset) + .select("resultVec", "wantedVec") + .collect() + .foreach { case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala new file mode 100644 index 000000000000..4157b84b29d0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new HashingTF) + } + + test("hashingTF") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a b b c d".split(" ").toSeq) + )).toDF("id", "words") + val n = 100 + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + val output = hashingTF.transform(df) + val attrGroup = AttributeGroup.fromStructField(output.schema("features")) + require(attrGroup.numAttributes === Some(n)) + val features = output.select("features").first().getAs[Vector](0) + // Assume perfect hash on "a", "b", "c", and "d". + def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n) + val expected = Vectors.sparse(n, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) + assert(features ~== expected absTol 1e-14) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala new file mode 100644 index 000000000000..08f80af03429 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { + + def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { + dataSet.map { + case data: DenseVector => + val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y } + Vectors.dense(res) + case data: SparseVector => + val res = data.indices.zip(data.values).map { case (id, value) => + (id, value * model(id)) + } + Vectors.sparse(data.size, res) + } + } + + test("params") { + ParamsSuite.checkParams(new IDF) + val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0))) + ParamsSuite.checkParams(model) + } + + test("compute IDF with default parameter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((numOfData + 1.0) / (x + 1.0)) + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } + + test("compute IDF with setter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .setMinDocFreq(1) + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala new file mode 100644 index 000000000000..c04dda41eea3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("MinMaxScaler fit basic case") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.dense(1, 0, Long.MinValue), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)), + Vectors.sparse(3, Array(0), Array(1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(-5, 0, -5), + Vectors.dense(0, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(5, 5)), + Vectors.sparse(3, Array(0), Array(-2.5))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaled") + .setMin(-5) + .setMax(5) + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), "Transformed vector is different with expected.") + } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + } + + test("MinMaxScaler arguments max must be larger than min") { + withClue("arguments max must be larger than min") { + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(10).setMax(0) + scaler.validateParams() + } + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(0).setMax(0) + scaler.validateParams() + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala new file mode 100644 index 000000000000..ab97e3dbc6ee --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) + +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.NGramSuite._ + + test("default behavior yields bigram features") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("Test", "for", "ngram", "."), + Array("Test for", "for ngram", "ngram .") + ))) + testNGram(nGram, dataset) + } + + test("NGramLength=4 yields length 4 n-grams") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array("a b c d", "b c d e") + ))) + testNGram(nGram, dataset) + } + + test("empty input yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array(), + Array() + ))) + testNGram(nGram, dataset) + } + + test("input array < n yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(6) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array() + ))) + testNGram(nGram, dataset) + } +} + +object NGramSuite extends SparkFunSuite { + + def testNGram(t: NGram, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("nGrams", "wantedNGrams") + .collect() + .foreach { case Row(actualNGrams, wantedNGrams) => + assert(actualNGrams === wantedNGrams) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala new file mode 100644 index 000000000000..9f03470b7f32 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var data: Array[Vector] = _ + @transient var dataFrame: DataFrame = _ + @transient var normalizer: Normalizer = _ + @transient var l1Normalized: Array[Vector] = _ + @transient var l2Normalized: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + l1Normalized = Array( + Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.12765957, -0.23404255, -0.63829787), + Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))), + Vectors.dense(0.625, 0.07894737, 0.29605263), + Vectors.sparse(3, Seq()) + ) + l2Normalized = Array( + Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.184549876, -0.3383414, -0.922749378), + Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))), + Vectors.dense(0.897906166, 0.113419726, 0.42532397), + Vectors.sparse(3, Seq()) + ) + + val sqlContext = new SQLContext(sc) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normalized_features") + } + + def collectResult(result: DataFrame): Array[Vector] = { + result.select("normalized_features").collect().map { + case Row(features: Vector) => features + } + } + + def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + } + + def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { (vector1, vector2) => + vector1 ~== vector2 absTol 1E-5 + }, "The vector value is not correct after normalization.") + } + + test("Normalization with default parameter") { + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l2Normalized) + } + + test("Normalization with setter") { + normalizer.setP(1) + + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l1Normalized) + } +} + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala new file mode 100644 index 000000000000..321eeb843941 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col + +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { + + def stringIndexed(): DataFrame = { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + indexer.transform(df) + } + + test("params") { + ParamsSuite.checkParams(new OneHotEncoder) + } + + test("OneHotEncoder dropLast = false") { + val transformed = stringIndexed() + val encoder = new OneHotEncoder() + .setInputCol("labelIndex") + .setOutputCol("labelVec") + .setDropLast(false) + val encoded = encoder.transform(transformed) + + val output = encoded.select("id", "labelVec").map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), + (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) + assert(output === expected) + } + + test("OneHotEncoder dropLast = true") { + val transformed = stringIndexed() + val encoder = new OneHotEncoder() + .setInputCol("labelIndex") + .setOutputCol("labelVec") + val encoded = encoder.transform(transformed) + + val output = encoded.select("id", "labelVec").map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) + assert(output === expected) + } + + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala new file mode 100644 index 000000000000..30c500f87a76 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} +import org.apache.spark.sql.Row + +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new PCA) + val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] + val model = new PCAModel("pca", new OldPCAModel(2, mat)) + ParamsSuite.checkParams(model) + } + + test("pca") { + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val mat = new RowMatrix(dataRDD) + val pc = mat.computePrincipalComponents(3) + val expected = mat.multiply(pc).rows + + val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pca_features") + .setK(3) + .fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(pca) + + pca.transform(df).select("pca_features", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala new file mode 100644 index 000000000000..29eebd8960eb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.ml.param.ParamsSuite +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new PolynomialExpansion) + } + + test("Polynomial expansion with default parameter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val twoDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(new Array[Double](9)), + Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + Vectors.sparse(9, Array.empty, Array.empty)) + + val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } + + test("Polynomial expansion with setter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val threeDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(new Array[Double](19)), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), + Vectors.sparse(19, Array.empty, Array.empty)) + + val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala new file mode 100644 index 000000000000..436e66bab09b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class RFormulaParserSuite extends SparkFunSuite { + private def checkParse( + formula: String, + label: String, + terms: Seq[String], + schema: StructType = null) { + val resolved = RFormulaParser.parse(formula).resolve(schema) + assert(resolved.label == label) + assert(resolved.terms == terms) + } + + test("parse simple formulas") { + checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ x + x", "y", Seq("x")) + checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) + } + + test("parse dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ .", "a", Seq("b", "c"), schema) + } + + test("parse deletion") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ c - b", "a", Seq("c"), schema) + } + + test("parse additions and deletions in order") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ . - b + . - c", "a", Seq("b"), schema) + } + + test("dot ignores complex column types") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "tinyint", false) + .add("c", "map", true) + checkParse("a ~ .", "a", Seq("b"), schema) + } + + test("parse intercept") { + assert(RFormulaParser.parse("a ~ b").hasIntercept) + assert(RFormulaParser.parse("a ~ b + 1").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 0").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept) + assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala new file mode 100644 index 000000000000..6aed3243afce --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RFormula()) + } + + test("transform numeric data") { + val formula = new RFormula().setFormula("id ~ v1 + v2") + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), + (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) + ).toDF("id", "v1", "v2", "features", "label") + // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect() === expected.collect()) + } + + test("features column already exists") { + val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.fit(original) + } + intercept[IllegalArgumentException] { + formula.fit(original) + } + } + + test("label column already exists") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(resultSchema.toString == model.transform(original).schema.toString) + } + + test("label column already exists but is not double type") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val model = formula.fit(original) + intercept[IllegalArgumentException] { + model.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + model.transform(original) + } + } + + test("allow missing label column for test datasets") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(!resultSchema.exists(_.name == "label")) + assert(resultSchema.toString == model.transform(original).schema.toString) + } + + test("encodes string terms") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 000000000000..d19052881ae4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala new file mode 100644 index 000000000000..f01306f89cb5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +object StopWordsRemoverSuite extends SparkFunSuite { + def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("filtered", "expected") + .collect() + .foreach { case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} + +class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { + import StopWordsRemoverSuite._ + + test("StopWordsRemover default") { + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("test", "test"), Seq("test", "test")), + (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), + (Seq("a", "the", "an"), Seq()), + (Seq("A", "The", "AN"), Seq()), + (Seq(null), Seq(null)), + (Seq(), Seq()) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover case sensitive") { + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setCaseSensitive(true) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("A"), Seq("A")), + (Seq("The", "the"), Seq("The")) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with additional words") { + val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("python", "scala", "a"), Seq()), + (Seq("Python", "Scala", "swift"), Seq("swift")) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala new file mode 100644 index 000000000000..05e05bdc64bb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col + +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new StringIndexer) + val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) + ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) + } + + test("StringIndexer") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(indexer) + + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } + + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) + } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } + + test("StringIndexerModel should keep silent if the input column does not exist.") { + val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + .setInputCol("label") + .setOutputCol("labelIndex") + val df = sqlContext.range(0L, 10L) + assert(indexerModel.transform(df).eq(df)) + } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } + + test("StringIndexer, IndexToString are inverses") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val idx2str = new IndexToString() + .setInputCol("labelIndex") + .setOutputCol("sameLabel") + .setLabels(indexer.labels) + idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { + case Row(a: String, b: String) => + assert(a === b) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala new file mode 100644 index 000000000000..e5fd21c3f6fc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) + +class TokenizerSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new Tokenizer) + } +} + +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ + + test("params") { + ParamsSuite.checkParams(new RegexTokenizer) + } + + test("RegexTokenizer") { + val tokenizer0 = new RegexTokenizer() + .setGaps(false) + .setPattern("\\w+|\\p{Punct}") + .setInputCol("rawText") + .setOutputCol("tokens") + val dataset0 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + )) + testRegexTokenizer(tokenizer0, dataset0) + + val dataset1 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Array("punct")) + )) + tokenizer0.setMinTokenLength(3) + testRegexTokenizer(tokenizer0, dataset1) + + val tokenizer2 = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + val dataset2 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) + )) + testRegexTokenizer(tokenizer2, dataset2) + } +} + +object RegexTokenizerSuite extends SparkFunSuite { + + def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("tokens", "wantedTokens") + .collect() + .foreach { case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala new file mode 100644 index 000000000000..bb4d5b983e0d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col + +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new VectorAssembler) + } + + test("assemble") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + val dv = Vectors.dense(2.0, 0.0) + assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) + assert(assemble(0.0, dv, 1.0, sv) === + Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) + for (v <- Seq(1, "a", null)) { + intercept[SparkException](assemble(v)) + intercept[SparkException](assemble(1.0, v)) + } + } + + test("assemble should compress vectors") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + assert(v1.isInstanceOf[SparseVector]) + val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + assert(v2.isInstanceOf[DenseVector]) + } + + test("VectorAssembler") { + val df = sqlContext.createDataFrame(Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) + )).toDF("id", "x", "y", "name", "z", "n") + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + assembler.transform(df).select("features").collect().foreach { + case Row(v: Vector) => + assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) + } + } + + test("ML attributes") { + val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") + val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) + val user = new AttributeGroup("user", Array( + NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), + NumericAttribute.defaultAttr.withName("salary"))) + val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) + val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + .select( + col("browser").as("browser", browser.toMetadata()), + col("hour").as("hour", hour.toMetadata()), + col("count"), // "count" is an integer column without ML attribute + col("user").as("user", user.toMetadata()), + col("ad")) // "ad" is a vector column without ML attribute + val assembler = new VectorAssembler() + .setInputCols(Array("browser", "hour", "count", "user", "ad")) + .setOutputCol("features") + val output = assembler.transform(df) + val schema = output.schema + val features = AttributeGroup.fromStructField(schema("features")) + assert(features.size === 7) + val browserOut = features.getAttr(0) + assert(browserOut === browser.withIndex(0).withName("browser")) + val hourOut = features.getAttr(1) + assert(hourOut === hour.withIndex(1).withName("hour")) + val countOut = features.getAttr(2) + assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2)) + val userGenderOut = features.getAttr(3) + assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) + val userSalaryOut = features.getAttr(4) + assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 000000000000..8cb0a2cf14d3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.apache.spark.{Logging, SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + import VectorIndexerSuite.FeatureData + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("params") { + ParamsSuite.checkParams(new VectorIndexer) + val model = new VectorIndexerModel("indexer", 1, Map.empty) + ParamsSuite.checkParams(model) + } + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[SparkException] { + model.transform(densePoints2).collect() + logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + logInfo("Did not throw error when fitting vectors of different lengths in same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transform correctly, check metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + // Chose correct categorical features + assert(categoryMaps.keys.toSet === categoricalFeatures) + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + logError(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata())) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala new file mode 100644 index 000000000000..a6c2fba8360d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + val slicer = new VectorSlicer + ParamsSuite.checkParams(slicer) + assert(slicer.getIndices.length === 0) + assert(slicer.getNames.length === 0) + withClue("VectorSlicer should not have any features selected by default") { + intercept[IllegalArgumentException] { + slicer.validateParams() + } + } + } + + test("feature validity checks") { + import VectorSlicer._ + assert(validIndices(Array(0, 1, 8, 2))) + assert(validIndices(Array.empty[Int])) + assert(!validIndices(Array(-1))) + assert(!validIndices(Array(1, 2, 1))) + + assert(validNames(Array("a", "b"))) + assert(validNames(Array.empty[String])) + assert(!validNames(Array("", "b"))) + assert(!validNames(Array("a", "b", "a"))) + } + + test("Test vector slicer") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), + Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3), + Vectors.sparse(5, Seq()) + ) + + // Expected after selecting indices 1, 4 + val expected = Array( + Vectors.sparse(2, Seq((0, 2.3))), + Vectors.dense(2.3, 1.0), + Vectors.dense(0.0, 0.0), + Vectors.dense(-1.1, 3.3), + Vectors.sparse(2, Seq()) + ) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]]) + + val resultAttrs = Array("f1", "f4").map(defaultAttr.withName) + val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) + + val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } + val df = sqlContext.createDataFrame(rdd, + StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) + + val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") + + def validateResults(df: DataFrame): Unit = { + df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 === vec2) + } + val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) + resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => + assert(a === b) + } + } + + vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) + validateResults(vectorSlicer.transform(df)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala new file mode 100644 index 000000000000..a2e46f202995 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} + +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new Word2Vec) + val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f)))) + ParamsSuite.checkParams(model) + } + + test("Word2Vec") { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val numOfWords = sentence.split(" ").size + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + + val expected = doc.map { sentence => + Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) => + word1.zip(word2).map { case (v1, v2) => v1 + v2 } + ).map(_ / numOfWords)) + } + + val docDF = doc.zip(expected).toDF("text", "expected") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(docDF).select("result", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + } + } + + test("getVectors") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } + + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val realVectors = model.getVectors.sort("word").select("vector").map { + case Row(v: Vector) => v + }.collect() + + realVectors.zip(expectedVectors).foreach { + case (real, expected) => + assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") + } + } + + test("findSynonyms") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val (synonyms, similarity) = model.findSynonyms("a", 2).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + assert(synonyms.toArray === Array("b", "c")) + expectedSimilarity.zip(similarity).map { + case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + } + + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala new file mode 100644 index 000000000000..460849c79f04 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame} + + +private[ml] object TreeTests extends SparkFunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val sqlContext = new SQLContext(data.sparkContext) + import sqlContext.implicits._ + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** Java-friendly version of [[setMetadata()]] */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + (a, b) match { + case (aye: InternalNode, bee: InternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: LeafNode) => // do nothing + case _ => + throw new AssertionError("Found mismatched nodes") + } + } + + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. + */ + def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { + try { + a.trees.zip(b.trees).foreach { case (treeA, treeB) => + TreeTests.checkEqual(treeA, treeB) + } + assert(a.treeWeights === b.treeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + + /** + * Helper method for constructing a tree for testing. + * Given left, right children, construct a parent node. + * @param split Split for parent node + * @return Parent node with children attached + */ + def buildParentNode(left: Node, right: Node, split: Split): Node = { + val leftImp = left.impurityStats + val rightImp = right.impurityStats + val parentImp = leftImp.copy.add(rightImp) + val leftWeight = leftImp.count / parentImp.count.toDouble + val rightWeight = rightImp.count / parentImp.count.toDouble + val gain = parentImp.calculate() - + (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) + val pred = parentImp.predict + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce298761237..2c878f8372a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,23 +17,38 @@ package org.apache.spark.ml.param -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ParamsSuite extends FunSuite { - - val solver = new TestParams() - import solver.{inputCol, maxIter} +class ParamsSuite extends SparkFunSuite { test("param") { + val solver = new TestParams() + val uid = solver.uid + import solver.{maxIter, inputCol} + assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) - assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.doc === "maximum number of iterations (>= 0)") + assert(maxIter.parent === uid) + assert(maxIter.toString === s"${uid}__maxIter") + assert(!maxIter.isValid(-1)) + assert(maxIter.isValid(0)) + assert(maxIter.isValid(1)) + + solver.setMaxIter(5) + assert(solver.explainParam(maxIter) === + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)") + + assert(inputCol.toString === s"${uid}__inputCol") + + intercept[IllegalArgumentException] { + solver.setMaxIter(-1) + } } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -41,16 +56,24 @@ class ParamsSuite extends FunSuite { assert(pair.param.eq(maxIter)) assert(pair.value === 5) } + intercept[IllegalArgumentException] { + val pair = maxIter -> -1 + } } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) + intercept[IllegalArgumentException] { + map0.put(maxIter, -1) + } assert(!map0.contains(inputCol)) intercept[NoSuchElementException] { @@ -78,31 +101,142 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) - assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + + assert(solver.explainParam(maxIter) === + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") + assert(solver.explainParams() === + Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + assert(solver.hasParam("inputCol")) + assert(!solver.hasParam("abc")) + intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { - solver.validate() + solver.validateParams() } - solver.validate(ParamMap(inputCol -> "input")) + solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") - solver.validate() + solver.validateParams() intercept[IllegalArgumentException] { - solver.validate(ParamMap(maxIter -> -10)) + ParamMap(maxIter -> -10) } - solver.setMaxIter(-10) intercept[IllegalArgumentException] { - solver.validate() + solver.setMaxIter(-10) } + + solver.clearMaxIter() + assert(!solver.isSet(maxIter)) + + val copied = solver.copy(ParamMap(solver.maxIter -> 50)) + assert(copied.uid === solver.uid) + assert(copied.getInputCol === solver.getInputCol) + assert(copied.getMaxIter === 50) + } + + test("ParamValidate") { + val alwaysTrue = ParamValidators.alwaysTrue[Int] + assert(alwaysTrue(1)) + + val gt1Int = ParamValidators.gt[Int](1) + assert(!gt1Int(1) && gt1Int(2)) + val gt1Double = ParamValidators.gt[Double](1) + assert(!gt1Double(1.0) && gt1Double(1.1)) + + val gtEq1Int = ParamValidators.gtEq[Int](1) + assert(!gtEq1Int(0) && gtEq1Int(1)) + val gtEq1Double = ParamValidators.gtEq[Double](1) + assert(!gtEq1Double(0.9) && gtEq1Double(1.0)) + + val lt1Int = ParamValidators.lt[Int](1) + assert(lt1Int(0) && !lt1Int(1)) + val lt1Double = ParamValidators.lt[Double](1) + assert(lt1Double(0.9) && !lt1Double(1.0)) + + val ltEq1Int = ParamValidators.ltEq[Int](1) + assert(ltEq1Int(1) && !ltEq1Int(2)) + val ltEq1Double = ParamValidators.ltEq[Double](1) + assert(ltEq1Double(1.0) && !ltEq1Double(1.1)) + + val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2) + assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) && + !inRange02IntInclusive(-1) && !inRange02IntInclusive(3)) + val inRange02IntExclusive = + ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false) + assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2)) + + val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2) + assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) && + inRange02DoubleInclusive(2) && + !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1)) + val inRange02DoubleExclusive = + ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false) + assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) && + !inRange02DoubleExclusive(2)) + + val inArray = ParamValidators.inArray[Int](Array(1, 2)) + assert(inArray(1) && inArray(2) && !inArray(0)) + + val arrayLengthGt = ParamValidators.arrayLengthGt[Int](2.0) + assert(arrayLengthGt(Array(0, 1, 2)) && !arrayLengthGt(Array(0, 1))) + } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) + } +} + +object ParamsSuite extends SparkFunSuite { + + /** + * Checks common requirements for [[Params.params]]: + * - params are ordered by names + * - param parent has the same UID as the object's UID + * - param name is the same as the param method name + * - obj.copy should return the same type as the obj + */ + def checkParams(obj: Params): Unit = { + val clazz = obj.getClass + + val params = obj.params + val paramNames = params.map(_.name) + require(paramNames === paramNames.sorted, "params must be ordered by names") + params.foreach { p => + assert(p.parent === obj.uid) + assert(obj.getParam(p.name) === p) + // TODO: Check that setters return self, which needs special handling for generic types. + } + + val copyMethod = clazz.getMethod("copy", classOf[ParamMap]) + val copyReturnType = copyMethod.getReturnType + require(copyReturnType === obj.getClass, + s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 1a65883d78a7..275924834453 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,20 +17,26 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.util.Identifiable + /** A subclass of Params for testing. */ -class TestParams extends Params { +class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { + + def this() = this(Identifiable.randomUID("testParams")) - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) - override def validate(paramMap: ParamMap) = { - val m = this.paramMap ++ paramMap - require(m(maxIter) >= 0) - require(m.contains(inputCol)) + setDefault(maxIter -> 10) + + def clearMaxIter(): this.type = clear(maxIter) + + override def validateParams(): Unit = { + super.validateParams() + require(isDefined(inputCol)) } + + override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala new file mode 100644 index 000000000000..b3af81a3c60b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.{ParamMap, Params} + +class SharedParamsSuite extends SparkFunSuite { + + test("outputCol") { + + class Obj(override val uid: String) extends Params with HasOutputCol { + override def copy(extra: ParamMap): Obj = defaultCopy(extra) + } + + val obj = new Obj("obj") + + assert(obj.hasDefault(obj.outputCol)) + assert(obj.getOrDefault(obj.outputCol) === "obj__output") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index acc447742bad..eadc80e0e62b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,29 +17,38 @@ package org.apache.spark.ml.recommendation +import java.io.File import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.Utils -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var sqlContext: SQLContext = _ + private var tempDir: File = _ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) + tempDir = Utils.createTempDir() + sc.setCheckpointDir(tempDir.getAbsolutePath) + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() } test("LocalIndexEncoder") { @@ -58,39 +67,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -98,41 +110,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. - // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -144,13 +131,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } @@ -350,16 +336,17 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { val sqlContext = this.sqlContext - import sqlContext.createDataFrame + import sqlContext.implicits._ val als = new ALS() .setRank(rank) .setRegParam(regParam) .setImplicitPrefs(implicitPrefs) .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) + .setSeed(0) val alpha = als.getAlpha - val model = als.fit(training) - val predictions = model.transform(test) + val model = als.fit(training.toDF()) + val predictions = model.transform(test.toDF()) .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) @@ -388,6 +375,9 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("exact rank-1 matrix") { @@ -437,17 +427,18 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating)) - val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4) + val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4, seed = 0) assert(longUserFactors.first()._1.getClass === classOf[Long]) val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating)) - val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) + val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4, seed = 0) assert(strUserFactors.first()._1.getClass === classOf[String]) } test("nonnegative constraint") { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) - val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true) + val (userFactors, itemFactors) = + ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true, seed = 0) def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = { factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _) } @@ -455,4 +446,41 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { assert(isNonnegative(itemFactors)) // TODO: Validate the solution. } + + test("als partitioner is a projection") { + for (p <- Seq(1, 10, 100, 1000)) { + val part = new ALSPartitioner(p) + var k = 0 + while (k < p) { + assert(k === part.getPartition(k)) + assert(k === part.getPartition(k.toLong)) + k += 1 + } + } + } + + test("partitioner in returned factors") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train( + ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4, seed = 0) + for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) { + assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.") + val part = userFactors.partitioner.get + userFactors.mapPartitionsWithIndex { (idx, items) => + items.foreach { case (id, _) => + if (part.getPartition(id) != idx) { + throw new SparkException(s"$tpe with ID $id should not be in partition $idx.") + } + } + Iterator.empty + }.count() + } + } + + test("als with large number of iterations") { + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, seed = 0) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, + implicitPrefs = true, seed = 0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala new file mode 100644 index 000000000000..b092bcd6a7e8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import DecisionTreeRegressorSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Regression stump with 3-ary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + test("Regression stump with binary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8).fit(df) + MLTestingUtils.checkCopy(model) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: test("model save/load") SPARK-6725 +} + +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newTree = dt.fit(newData) + // Use parent from newTree since this is not checked anyways. + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( + oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala new file mode 100644 index 000000000000..a68197b59193 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils + + +/** + * Test suite for [[GBTRegressor]]. + */ +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import GBTRegressorSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Regression with continuous features: SquaredError") { + val categoricalFeatures = Map.empty[Int, Int] + GBTRegressor.supportedLossTypes.foreach { loss => + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType(loss) + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + } + + test("GBTRegressor behaves reasonably on toy data") { + val df = sqlContext.createDataFrame(Seq( + LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), + LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), + LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), + LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), + LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), + LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) + )) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + val preds = model.transform(df) + val predictions = preds.select("prediction").map(_.getDouble(0)) + // Checks based on SPARK-8736 (to ensure it is not doing classification) + assert(predictions.max() > 2) + assert(predictions.min() < -1) + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTRegressor.supportedLossTypes.foreach { loss => + val gbt = new GBTRegressor() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) + val newModel = GBTRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTRegressorSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = gbt.fit(newData) + // Use parent from newTree since this is not checked anyways. + val oldModelAsNew = GBTRegressionModel.fromOld( + oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala new file mode 100644 index 000000000000..c0ab00b68a2f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { + sqlContext.createDataFrame( + labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } + ).toDF("label", "features", "weight") + } + + private def generatePredictionInput(features: Seq[Double]): DataFrame = { + sqlContext.createDataFrame(features.map(Tuple1.apply)) + .toDF("features") + } + + test("isotonic regression predictions") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + val ir = new IsotonicRegression().setIsotonic(true) + + val model = ir.fit(dataset) + + val predictions = model + .transform(dataset) + .select("prediction").map { case Row(pred) => + pred + }.collect() + + assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.getIsotonic) + } + + test("antitonic regression predictions") { + val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) + val ir = new IsotonicRegression().setIsotonic(false) + + val model = ir.fit(dataset) + val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1)) + } + + test("params validation") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression + ParamsSuite.checkParams(ir) + val model = ir.fit(dataset) + ParamsSuite.checkParams(model) + } + + test("default params") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + assert(ir.getLabelCol === "label") + assert(ir.getFeaturesCol === "features") + assert(ir.getPredictionCol === "prediction") + assert(!ir.isDefined(ir.weightCol)) + assert(ir.getIsotonic) + assert(ir.getFeatureIndex === 0) + + val model = ir.fit(dataset) + model.transform(dataset) + .select("label", "features", "prediction", "weight") + .collect() + + assert(model.getLabelCol === "label") + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(!model.isDefined(model.weightCol)) + assert(model.getIsotonic) + assert(model.getFeatureIndex === 0) + assert(model.hasParent) + } + + test("set parameters") { + val isotonicRegression = new IsotonicRegression() + .setIsotonic(false) + .setWeightCol("w") + .setFeaturesCol("f") + .setLabelCol("l") + .setPredictionCol("p") + + assert(!isotonicRegression.getIsotonic) + assert(isotonicRegression.getWeightCol === "w") + assert(isotonicRegression.getFeaturesCol === "f") + assert(isotonicRegression.getLabelCol === "l") + assert(isotonicRegression.getPredictionCol === "p") + } + + test("missing column") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + + intercept[IllegalArgumentException] { + new IsotonicRegression().setWeightCol("w").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setFeaturesCol("f").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setLabelCol("l").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) + } + } + + test("vector features column with feature index") { + val dataset = sqlContext.createDataFrame(Seq( + (4.0, Vectors.dense(0.0, 1.0)), + (3.0, Vectors.dense(0.0, 2.0)), + (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) + ).toDF("label", "features") + + val ir = new IsotonicRegression() + .setFeatureIndex(1) + + val model = ir.fit(dataset) + + val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala new file mode 100644 index 000000000000..2aaee71ecc73 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -0,0 +1,513 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var dataset: DataFrame = _ + @transient var datasetWithoutIntercept: DataFrame = _ + + /* + In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + is the same as the one trained by R's glmnet package. The following instruction + describes how to reproduce the data in R. + + import org.apache.spark.mllib.util.LinearDataGenerator + val data = + sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + .saveAsTextFile("path") + */ + override def beforeAll(): Unit = { + super.beforeAll() + dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + /* + datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + training model without intercept + */ + datasetWithoutIntercept = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + + } + + test("params") { + ParamsSuite.checkParams(new LinearRegression) + val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("linear regression: default params") { + val lir = new LinearRegression + assert(lir.getLabelCol === "label") + assert(lir.getFeaturesCol === "features") + assert(lir.getPredictionCol === "prediction") + assert(lir.getRegParam === 0.0) + assert(lir.getElasticNetParam === 0.0) + assert(lir.getFitIntercept) + assert(lir.getStandardization) + val model = lir.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(dataset) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + test("linear regression with intercept without regularization") { + val trainer1 = new LinearRegression + // The result should be the same regardless of standardization without regularization + val trainer2 = (new LinearRegression).setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.298698 + as.numeric.data.V2. 4.700706 + as.numeric.data.V3. 7.199082 + */ + val interceptR = 6.298698 + val weightsR = Vectors.dense(4.700706, 7.199082) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) + + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression without intercept without regularization") { + val trainer1 = (new LinearRegression).setFitIntercept(false) + // Without regularization the results should be the same + val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false) + val model1 = trainer1.fit(dataset) + val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept) + val model2 = trainer2.fit(dataset) + val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept) + + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 + */ + val weightsR = Vectors.dense(6.995908, 5.275131) + + assert(model1.intercept ~== 0 absTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) + assert(model2.intercept ~== 0 absTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) + + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 + */ + val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) + + assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3) + assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3) + } + + test("linear regression with intercept with L1 regularization") { + val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 6.679841 + */ + val interceptR1 = 6.24300 + val weightsR1 = Vectors.dense(4.024821, 6.679841) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.416948 + as.numeric.data.V2. 3.893869 + as.numeric.data.V3. 6.724286 + */ + val interceptR2 = 6.416948 + val weightsR2 = Vectors.dense(3.893869, 6.724286) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression without intercept with L1 regularization") { + val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false) + val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false).setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(6.299752, 4.772913) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE, standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.232193 + as.numeric.data.V3. 4.764229 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(6.232193, 4.764229) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L2 regularization") { + val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.269376 + as.numeric.data.V2. 3.736216 + as.numeric.data.V3. 5.712356) + */ + val interceptR1 = 5.269376 + val weightsR1 = Vectors.dense(3.736216, 5.712356) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 5.791109 + as.numeric.data.V2. 3.435466 + as.numeric.data.V3. 5.910406 + */ + val interceptR2 = 5.791109 + val weightsR2 = Vectors.dense(3.435466, 5.910406) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression without intercept with L2 regularization") { + val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false) + val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false).setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(5.522875, 4.214502) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE, standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.263704 + as.numeric.data.V3. 4.187419 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(5.263704, 4.187419) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 + */ + val interceptR1 = 5.696056 + val weightsR1 = Vectors.dense(3.670489, 6.001122) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 + standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.114723 + as.numeric.data.V2. 3.409937 + as.numeric.data.V3. 6.146531 + */ + val interceptR2 = 6.114723 + val weightsR2 = Vectors.dense(3.409937, 6.146531) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression without intercept with ElasticNet regularization") { + val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false) + val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false).setStandardization(false) + val model1 = trainer1.fit(dataset) + val model2 = trainer2.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 + */ + val interceptR1 = 0.0 + val weightsR1 = Vectors.dense(5.673348, 4.322251) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE, standardize=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.477988 + as.numeric.data.V3. 4.297622 + */ + val interceptR2 = 0.0 + val weightsR2 = Vectors.dense(5.477988, 4.297622) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) + + model1.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression model training summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Training results for the model should be available + assert(model.hasSummary) + + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = dataset.select("features", "label") + .map { case Row(features: DenseVector, label: Double) => + val prediction = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + label - prediction + } + .zip(model.summary.residuals.map(_.getDouble(0))) + .collect() + .foreach { case (manualResidual: Double, resultResidual: Double) => + assert(manualResidual ~== resultResidual relTol 1E-5) + } + + /* + Use the following R code to generate model training results. + + predictions <- predict(fit, newx=features) + residuals <- label - predictions + > mean(residuals^2) # MSE + [1] 0.009720325 + > mean(abs(residuals)) # MAD + [1] 0.07863206 + > cor(predictions, label)^2# r^2 + [,1] + s0 0.9998749 + */ + assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) + assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) + assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + + // Objective function should be monotonically decreasing for linear regression + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + + test("linear regression model testset evaluation summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Evaluating on training dataset should yield results summary equal to training summary + val testSummary = model.evaluate(dataset) + assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) + assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) + model.summary.residuals.select("residuals").collect() + .zip(testSummary.residuals.select("residuals").collect()) + .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala new file mode 100644 index 000000000000..7b1b3f11481d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * Test suite for [[RandomForestRegressor]]. + */ +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestRegressorSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val newRF = rf + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + regressionTestWithContinuousFeatures(rf) + } + + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + .setCacheNodeIds(true) + regressionTestWithContinuousFeatures(rf) + } + + test("Feature importance with toy data") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = rf.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + val importances = model.featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) + val newModel = RandomForestRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestRegressorSuite extends SparkFunSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = rf.fit(newData) + // Use parent from newTree since this is not checked anyways. + val oldModelAsNew = RandomForestRegressionModel.fromOld( + oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala new file mode 100644 index 000000000000..dc852795c7f6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestSuite.mapToVec + + test("computeFeatureImportance, featureImportances") { + /* Build tree for testing, with this structure: + grandParent + left2 parent + left right + */ + val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + + val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parentImp = parent.impurityStats + + val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandImp = grandParent.impurityStats + + // Test feature importance computed at different subtrees. + def testNode(node: Node, expected: Map[Int, Double]): Unit = { + val map = new OpenHashMap[Int, Double]() + RandomForest.computeFeatureImportance(node, map) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + + // Leaf node + testNode(left, Map.empty[Int, Double]) + + // Internal node with 2 leaf children + val feature0importance = parentImp.calculate() * parentImp.count - + (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count) + testNode(parent, Map(0 -> feature0importance)) + + // Full tree + val feature1importance = grandImp.calculate() * grandImp.count - + (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count) + testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance)) + + // Forest consisting of (full tree) + (internal node with 2 leafs) + val trees = Array(parent, grandParent).map { root => + new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel] + } + val importances: Vector = RandomForest.featureImportances(trees, 2) + val tree2norm = feature0importance + feature1importance + val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + (feature1importance / tree2norm) / 2.0) + assert(importances ~== expected relTol 0.01) + } + + test("normalizeMapValues") { + val map = new OpenHashMap[Int, Double]() + map(0) = 1.0 + map(2) = 2.0 + RandomForest.normalizeMapValues(map) + val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + +} + +private object RandomForestSuite { + + def mapToVec(map: Map[Int, Double]): Vector = { + val size = (map.keys.toSeq :+ 0).max + 1 + val (indices, values) = map.toSeq.sortBy(_._1).unzip + Vectors.sparse(size, indices.toArray, values.toArray) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 761ea821ef7c..fde02e0c84bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.ml.tuning -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol +import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @@ -49,8 +54,99 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) - val bestParamMap = cvModel.bestModel.fittingParamMap - assert(bestParamMap(lr.regParam) === 0.001) - assert(bestParamMap(lr.maxIter) === 10) + + // copied model must have the same paren. + MLTestingUtils.checkCopy(cvModel) + + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + } + + test("cross validation with linear regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + val trainer = new LinearRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(trainer.regParam, Array(1000.0, 0.001)) + .addGrid(trainer.maxIter, Array(0, 10)) + .build() + val eval = new RegressionEvaluator() + val cv = new CrossValidator() + .setEstimator(trainer) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + + eval.setMetricName("r2") + val cvModel2 = cv.fit(dataset) + val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent2.getRegParam === 0.001) + assert(parent2.getMaxIter === 10) + assert(cvModel2.avgMetrics.length === lrParamMaps.length) + } + + test("validateParams should check estimatorParamMaps") { + import CrossValidatorSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new CrossValidator() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + + cv.validateParams() // This should pass. + + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object CrossValidatorSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } + + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } + + override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } + + override def isLargerBetter: Boolean = true + + override val uid: String = "eval" + + override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala index 20aa100112bf..810b70049ec1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.{ParamMap, TestParams} -class ParamGridBuilderSuite extends FunSuite { +class ParamGridBuilderSuite extends SparkFunSuite { val solver = new TestParams() import solver.{inputCol, maxIter} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala new file mode 100644 index 000000000000..ef24e6fb6b80 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType + +class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext { + test("train validation with logistic regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new TrainValidationSplit() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(cv.getTrainRatio === 0.5) + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.validationMetrics.length === lrParamMaps.length) + } + + test("train validation with linear regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + val trainer = new LinearRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(trainer.regParam, Array(1000.0, 0.001)) + .addGrid(trainer.maxIter, Array(0, 10)) + .build() + val eval = new RegressionEvaluator() + val cv = new TrainValidationSplit() + .setEstimator(trainer) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.validationMetrics.length === lrParamMaps.length) + + eval.setMetricName("r2") + val cvModel2 = cv.fit(dataset) + val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent2.getRegParam === 0.001) + assert(parent2.getMaxIter === 10) + assert(cvModel2.validationMetrics.length === lrParamMaps.length) + } + + test("validateParams should check estimatorParamMaps") { + import TrainValidationSplitSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new TrainValidationSplit() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + cv.validateParams() // This should pass. + + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object TrainValidationSplitSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } + + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } + + override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } + + override def isLargerBetter: Boolean = true + + override val uid: String = "eval" + + override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala new file mode 100644 index 000000000000..d290cc9b06e7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap + +object MLTestingUtils { + def checkCopy(model: Model[_]): Unit = { + val copied = model.copy(ParamMap.empty) + .asInstanceOf[Model[_]] + assert(copied.parent.uid == model.parent.uid) + assert(copied.parent == model.parent) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala new file mode 100644 index 000000000000..9e6bc7193c13 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + + import StopwatchSuite._ + + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { + assert(sw.name === "sw") + assert(sw.elapsed() === 0L) + assert(!sw.isRunning) + intercept[AssertionError] { + sw.stop() + } + val duration = checkStopwatch(sw) + val elapsed = sw.elapsed() + assert(elapsed === duration) + val duration2 = checkStopwatch(sw) + val elapsed2 = sw.elapsed() + assert(elapsed2 === duration + duration2) + assert(sw.toString === s"sw: ${elapsed2}ms") + sw.start() + assert(sw.isRunning) + intercept[AssertionError] { + sw.start() + } + } + + test("LocalStopwatch") { + val sw = new LocalStopwatch("sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on driver") { + val sw = new DistributedStopwatch(sc, "sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on executors") { + val sw = new DistributedStopwatch(sc, "sw") + val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) + rdd.foreach { i => + acc += checkStopwatch(sw) + } + assert(!sw.isRunning) + val elapsed = sw.elapsed() + assert(elapsed === acc.value) + } + + test("MultiStopwatch") { + val sw = new MultiStopwatch(sc) + .addLocal("local") + .addDistributed("spark") + assert(sw("local").name === "local") + assert(sw("spark").name === "spark") + intercept[NoSuchElementException] { + sw("some") + } + assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") + val localDuration = checkStopwatch(sw("local")) + val sparkDuration = checkStopwatch(sw("spark")) + val localElapsed = sw("local").elapsed() + val sparkElapsed = sw("spark").elapsed() + assert(localElapsed === localDuration) + assert(sparkElapsed === sparkDuration) + assert(sw.toString === + s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") + val rdd = sc.parallelize(0 until 4, 4) + val acc = sc.accumulator(0L) + rdd.foreach { i => + sw("local").start() + val duration = checkStopwatch(sw("spark")) + sw("local").stop() + acc += duration + } + val localElapsed2 = sw("local").elapsed() + assert(localElapsed2 === localElapsed) + val sparkElapsed2 = sw("spark").elapsed() + assert(sparkElapsed2 === sparkElapsed + acc.value) + } +} + +private object StopwatchSuite extends SparkFunSuite { + + /** + * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and + * returns the duration reported by the stopwatch. + */ + def checkStopwatch(sw: Stopwatch): Long = { + val ubStart = now + sw.start() + val lbStart = now + Thread.sleep(new Random().nextInt(10)) + val lb = now - lbStart + val duration = sw.stop() + val ub = now - ubStart + assert(duration >= lb && duration <= ub) + duration + } + + /** The current time in milliseconds. */ + private def now: Long = System.currentTimeMillis() +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index db8ed62fa46c..59944416d96a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.api.python -import org.scalatest.FunSuite - -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors} +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating -class PythonMLLibAPISuite extends FunSuite { +class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() @@ -77,6 +76,16 @@ class PythonMLLibAPISuite extends FunSuite { val emptyMatrix = Matrices.dense(0, 0, empty) val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] assert(emptyMatrix == ne) + + val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4)) + val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix] + assert(sm.toArray === nsm.toArray) + + val smt = new SparseMatrix( + 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), + isTransposed = true) + val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] + assert(smt.toArray === nsmt.toArray) } test("pickle rating") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3fb45938f75d..8d14bb657215 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,17 +17,19 @@ package org.apache.spark.mllib.classification -import scala.util.control.Breaks._ +import scala.collection.JavaConverters._ import scala.util.Random -import scala.collection.JavaConversions._ +import scala.util.control.Breaks._ -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + object LogisticRegressionSuite { @@ -36,7 +38,7 @@ object LogisticRegressionSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) + generateLogisticInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale*X) @@ -89,21 +91,22 @@ object LogisticRegressionSuite { seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val xDim = xMean.size + val xDim = xMean.length val xWithInterceptsDim = if (addIntercept) xDim + 1 else xDim - val nClasses = weights.size / xWithInterceptsDim + 1 + val nClasses = weights.length / xWithInterceptsDim + 1 val x = Array.fill[Vector](nPoints)(Vectors.dense(Array.fill[Double](xDim)(rnd.nextGaussian()))) - x.map(vector => { + x.foreach { vector => // This doesn't work if `vector` is a sparse vector. val vectorArray = vector.toArray var i = 0 - while (i < vectorArray.size) { + val len = vectorArray.length + while (i < len) { vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i) i += 1 } - }) + } val y = (0 until nPoints).map { idx => val xArray = x(idx).toArray @@ -116,7 +119,7 @@ object LogisticRegressionSuite { } // Preventing the overflow when we compute the probability val maxMargin = margins.max - if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin // Computing the probabilities for each class from the margins. val norm = { @@ -127,7 +130,7 @@ object LogisticRegressionSuite { } temp } - for (i <-0 until nClasses) probs(i) /= norm + for (i <- 0 until nClasses) probs(i) /= norm // Compute the cumulative probability so we can generate a random number and assign a label. for (i <- 1 until nClasses) probs(i) += probs(i - 1) @@ -147,9 +150,26 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** Binary labels, 3 features */ + private val binaryModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2) + + /** 3 classes, 2 features */ + private val multiclassModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) + + private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = { + assert(a.weights == b.weights) + assert(a.intercept == b.intercept) + assert(a.numClasses == b.numClasses) + assert(a.numFeatures == b.numFeatures) + assert(a.getThreshold == b.getThreshold) + } } -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { + +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -176,6 +196,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M .setStepSize(10.0) .setRegParam(0.0) .setNumIterations(20) + .setConvergenceTol(0.0005) val model = lr.run(testRDD) @@ -353,8 +374,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M testRDD2.cache() testRDD3.cache() + val numIteration = 10 + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + lrA.optimizer.setNumIterations(numIteration) val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + lrB.optimizer.setNumIterations(numIteration) val modelA1 = lrA.run(testRDD1, initialWeights) val modelA2 = lrA.run(testRDD2, initialWeights) @@ -402,6 +427,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD) + val numFeatures = testRDD.map(_.features.size).first() + val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2)) + val model2 = lr.run(testRDD, initialWeights) + + LogisticRegressionSuite.checkModelsEqual(model, model2) + /** * The following is the instruction to reproduce the model using R's glmnet package. * @@ -462,9 +493,56 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } + test("model save/load: binary classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load: multiclass classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.multiclassModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index e68fe89d6cce..cffa1ab700f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,15 +19,20 @@ package org.apache.spark.mllib.classification import scala.util.Random -import org.scalatest.FunSuite +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.apache.spark.SparkException -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils object NaiveBayesSuite { + import NaiveBayes.{Multinomial, Bernoulli} + private def calcLabel(p: Double, pi: Array[Double]): Int = { var sum = 0.0 for (j <- 0 until pi.length) { @@ -39,28 +44,51 @@ object NaiveBayesSuite { // Generate input of the form Y = (theta * x).argmax() def generateNaiveBayesInput( - pi: Array[Double], // 1XC - theta: Array[Array[Double]], // CXD - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int, + modelType: String = Multinomial, + sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) - val _pi = pi.map(math.pow(math.E, _)) val _theta = theta.map(row => row.map(math.pow(math.E, _))) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) - val xi = Array.tabulate[Double](D) { j => - if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 + val xi = modelType match { + case Bernoulli => Array.tabulate[Double] (D) { j => + if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 + } + case Multinomial => + val mult = BrzMultinomial(BDV(_theta(y))) + val emptyMap = (0 until D).map(x => (x, 0.0)).toMap + val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { + case (index, reps) => (index, reps.size.toDouble) + } + counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) } } + + /** Bernoulli NaiveBayes with binary labels, 3 features */ + private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli) + + /** Multinomial NaiveBayes with binary labels, 3 features */ + private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + + import NaiveBayes.{Multinomial, Bernoulli} def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { @@ -71,23 +99,106 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } - test("Naive Bayes") { - val nPoints = 10000 + def validateModelFit( + piData: Array[Double], + thetaData: Array[Array[Double]], + model: NaiveBayesModel): Unit = { + def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { + (d1 - d2).abs <= precision + } + val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + for (i <- modelIndex) { + assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + } + for (i <- modelIndex) { + for (j <- 0 until thetaData(i._2).length) { + assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + } + } + } + + test("model types") { + assert(Multinomial === "multinomial") + assert(Bernoulli === "bernoulli") + } + + test("get, set params") { + val nb = new NaiveBayes() + nb.setLambda(2.0) + assert(nb.getLambda === 2.0) + nb.setLambda(3.0) + assert(nb.getLambda === 3.0) + } + + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial) + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(testRDD, 1.0, Multinomial) + validateModelFit(pi, theta, model) + + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 17, Multinomial) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + + // Test posteriors + validationData.map(_.features).foreach { features => + val predicted = model.predictProbabilities(features).toArray + assert(predicted.sum ~== 1.0 relTol 1.0e-10) + val expected = expectedMultinomialProbabilities(model, features) + expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) } + } + } + /** + * @param model Multinomial Naive Bayes model + * @param testData input to compute posterior probabilities for + * @return posterior class probabilities (in order of labels) for input + */ + private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = { + val piVector = new BDV(model.pi) + // model.theta is row-major; treat it as col-major representation of transpose, and transpose: + val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t + val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze) + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + classProbs.map(_ / classProbsSum) + } + + test("Naive Bayes Bernoulli") { + val nPoints = 10000 val pi = Array(0.5, 0.3, 0.2).map(math.log) val theta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 45, Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD) + val model = NaiveBayes.train(testRDD, 1.0, Bernoulli) + validateModelFit(pi, theta, model) - val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17) + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 20, Bernoulli) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -95,6 +206,33 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + + // Test posteriors + validationData.map(_.features).foreach { features => + val predicted = model.predictProbabilities(features).toArray + assert(predicted.sum ~== 1.0 relTol 1.0e-10) + val expected = expectedBernoulliProbabilities(model, features) + expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) } + } + } + + /** + * @param model Bernoulli Naive Bayes model + * @param testData input to compute posterior probabilities for + * @return posterior class probabilities (in order of labels) for input + */ + private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = { + val piVector = new BDV(model.pi) + val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t + val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length, + model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t + val testBreeze = testData.toBreeze + val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze + val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze) + val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze) + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + classProbs.map(_ / classProbsSum) } test("detect negative values") { @@ -123,9 +261,82 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { NaiveBayes.train(sc.makeRDD(nan, 2)) } } + + test("detect non zero or one values in Bernoulli") { + val badTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli) + } + + val okTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)) + ) + + val badPredict = Seq( + Vectors.dense(1.0), + Vectors.dense(2.0), + Vectors.dense(1.0), + Vectors.dense(0.0)) + + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli) + intercept[SparkException] { + model.predict(sc.makeRDD(badPredict, 2)).collect() + } + } + + test("model save/load: 2.0 to 2.0") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map { + model => + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === sameModel.modelType) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("model save/load: 1.0 to 2.0") { + val model = NaiveBayesSuite.binaryMultinomialModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model as version 1.0, load it back, and compare. + try { + val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === Multinomial) + } finally { + Utils.deleteRecursively(tempDir) + } + } } -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 10 @@ -136,8 +347,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble()))) } } - // If we serialize data directly in the task closure, the size of the serialized task would be - // greater than 1MB and hence Spark would throw an error. + // If we serialize data directly in the task closure, the size of the serialized task + // would be greater than 1MB and hence Spark would throw an error. val model = NaiveBayes.train(examples) val predictions = model.predict(examples.map(_.features)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index a2de7fbd4138..ee3c85d09a46 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils object SVMSuite { @@ -35,7 +35,7 @@ object SVMSuite { weights: Array[Double], nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) + generateSVMInput(intercept, weights, nPoints, seed).asJava } // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) @@ -45,7 +45,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => @@ -56,9 +56,12 @@ object SVMSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + /** Binary labels, 3 features */ + private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) + } -class SVMSuite extends FunSuite with MLlibTestSparkContext { +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -87,7 +90,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -113,7 +116,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -123,8 +126,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -141,7 +144,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val initialB = -1.0 val initialC = -1.0 @@ -155,8 +158,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD, initialWeights) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -173,7 +176,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) val testRDDInvalid = testRDD.map { lp => @@ -191,9 +194,41 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { // Turning off data validation should not throw an exception new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } + + test("model save/load") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = SVMSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = SVMModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = SVMModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + } finally { + Utils.deleteRecursively(tempDir) + } + } } -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 8b3e6e5ce924..d7b291d5a633 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -19,18 +19,26 @@ package org.apache.spark.mllib.classification import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 + + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -51,7 +59,7 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -85,7 +93,7 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) inputDStream.count() @@ -119,7 +127,7 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -132,4 +140,48 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { assert(errors.forall(x => x <= 0.4)) } + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + val numBatches = 10 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index c2cd56ea40ad..a72723eb00da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - -import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), Vectors.dense(5.0, 10.0), Vectors.dense(4.0, 11.0) )) - + // expectations val Ew = 1.0 val Emu = Vectors.dense(5.0, 10.0) @@ -44,17 +44,12 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) } + } - + test("two clusters") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) - + val data = sc.parallelize(GaussianTestData.data) + // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( Array(0.5, 0.5), @@ -63,16 +58,16 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) ) ) - + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) .run(data) - + assert(gmm.weights(0) ~== Ew(0) absTol 1E-3) assert(gmm.weights(1) ~== Ew(1) absTol 1E-3) assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3) @@ -80,4 +75,116 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + + test("two clusters with distributed decompositions") { + val data = sc.parallelize(GaussianTestData.data2, 2) + + val k = 5 + val d = data.first().size + assert(GaussianMixture.shouldDistributeGaussians(k, d)) + + val gmm = new GaussianMixture() + .setK(k) + .run(data) + + assert(gmm.k === k) + } + + test("single cluster with sparse data") { + val data = sc.parallelize(Array( + Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)), + Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)), + Vectors.sparse(3, Array(1), Array(6.0)) + )) + + val Ew = 1.0 + val Emu = Vectors.dense(2.0, 2.0, 2.0) + val Esigma = Matrices.dense(3, 3, + Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0) + ) + + val seeds = Array(42, 1994, 27, 11, 0) + seeds.foreach { seed => + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) + assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) + } + } + + test("two clusters with sparse data") { + val data = sc.parallelize(GaussianTestData.data) + val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) + ) + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val sparseGMM = new GaussianMixture() + .setK(2) + .setInitialModel(initialGmm) + .run(sparseData) + + assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3) + assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) + } + + test("model save / load") { + val data = sc.parallelize(GaussianTestData.data) + + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + gmm.save(sc, path) + + // TODO: GaussianMixtureModel should implement equals/hashcode directly. + val sameModel = GaussianMixtureModel.load(sc, path) + assert(sameModel.k === gmm.k) + (0 until sameModel.k).foreach { i => + assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu) + assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model prediction, parallel and local") { + val data = sc.parallelize(GaussianTestData.data) + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + + val batchPredictions = gmm.predict(data) + batchPredictions.zip(data).collect().foreach { case (batchPred, datum) => + assert(batchPred === gmm.predict(datum)) + } + } + + object GaussianTestData { + + val data = Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + ) + + val data2: Array[Vector] = Array.tabulate(25){ i: Int => + Vectors.dense(Array.tabulate(50)(i + _.toDouble)) + } + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index caee5917000a..3003c62d9876 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils -class KMeansSuite extends FunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} @@ -74,7 +74,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 2.0, 3.0) // Make sure code runs. - var model = KMeans.train(data, k=2, maxIterations=1) + var model = KMeans.train(data, k = 2, maxIterations = 1) assert(model.clusterCenters.size === 2) } @@ -86,7 +86,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { 2) // Make sure code runs. - var model = KMeans.train(data, k=3, maxIterations=1) + var model = KMeans.train(data, k = 3, maxIterations = 1) assert(model.clusterCenters.size === 3) } @@ -198,9 +198,13 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { test("k-means|| initialization") { case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { - @Override def compare(that: VectorWithCompare): Int = { - if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > - that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + override def compare(that: VectorWithCompare): Int = { + if (this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) { + -1 + } else { + 1 + } } } @@ -257,9 +261,72 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(predicts(0) != predicts(3)) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(true, false).foreach { case selector => + val model = KMeansSuite.createModel(10, 3, selector) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = KMeansModel.load(sc, path) + KMeansSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("Initialize using given cluster centers") { + val points = Seq( + Vectors.dense(0.0, 0.0), + Vectors.dense(1.0, 0.0), + Vectors.dense(0.0, 1.0), + Vectors.dense(1.0, 1.0) + ) + val rdd = sc.parallelize(points, 3) + // creating an initial model + val initialModel = new KMeansModel(Array(points(0), points(2))) + + val returnModel = new KMeans() + .setK(2) + .setMaxIterations(0) + .setInitialModel(initialModel) + .run(rdd) + // comparing the returned model and the initial model + assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0)) + assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1)) + } + +} + +object KMeansSuite extends SparkFunSuite { + def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { + val singlePoint = isSparse match { + case true => + Vectors.sparse(dim, Array.empty[Int], Array.empty[Double]) + case _ => + Vectors.dense(Array.fill[Double](dim)(0.0)) + } + new KMeansModel(Array.fill[Vector](k)(singlePoint)) + } + + def checkEqual(a: KMeansModel, b: KMeansModel): Unit = { + assert(a.k === b.k) + a.clusterCenters.zip(b.clusterCenters).foreach { + case (ca: SparseVector, cb: SparseVector) => + assert(ca === cb) + case (ca: DenseVector, cb: DenseVector) => + assert(ca === cb) + case _ => + throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") + } + } } -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 302d751eb8a9..37fb69d68f6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,18 +17,24 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite +import java.util.{ArrayList => JArrayList} -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.graphx.Edge +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils -class LDASuite extends FunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ test("LocalLDAModel") { - val model = new LocalLDAModel(tinyTopics) + val model = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) // Check: basic parameters assert(model.k === tinyK) @@ -37,7 +43,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Check: describeTopics() with all terms val fullTopicSummary = model.describeTopics() - assert(fullTopicSummary.size === tinyK) + assert(fullTopicSummary.length === tinyK) fullTopicSummary.zip(tinyTopicDescription).foreach { case ((algTerms, algTermWeights), (terms, termWeights)) => assert(algTerms === terms) @@ -54,7 +60,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { } } - test("running and DistributedLDAModel") { + test("running and DistributedLDAModel with default Optimizer (EM)") { val k = 3 val topicSmoothing = 1.2 val termSmoothing = 1.2 @@ -62,13 +68,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { // Train a model val lda = new LDA() lda.setK(k) + .setOptimizer(new EMLDAOptimizer) .setDocConcentration(topicSmoothing) .setTopicConcentration(termSmoothing) .setMaxIterations(5) .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) - val model: DistributedLDAModel = lda.run(corpus) + val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] // Check: basic parameters val localModel = model.toLocal @@ -79,37 +86,84 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { assert(model.topicsMatrix === localModel.topicsMatrix) // Check: topic summaries - // The odd decimal formatting and sorting is a hack to do a robust comparison. - val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => - assert(t1 === t2) + val topicSummary = model.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) => + assert(topics ~== topicsLocal absTol 0.01) } // Check: per-doc topic distributions val topicDistributions = model.topicDistributions.collect() + // Ensure all documents are covered. - assert(topicDistributions.size === tinyCorpus.size) - assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) + // SPARK-5562. since the topicDistribution returns the distribution of the non empty docs + // over topics. Compare it against nonEmptyTinyCorpus instead of tinyCorpus + val nonEmptyTinyCorpus = getNonEmptyDoc(tinyCorpus) + assert(topicDistributions.length === nonEmptyTinyCorpus.length) + assert(nonEmptyTinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) // Ensure we have proper distributions topicDistributions.foreach { case (docId, topicDistribution) => assert(topicDistribution.size === tinyK) assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) } + val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3))) + model.topicDistributions.join(top2TopicsPerDoc).collect().foreach { + case (docId, (topicDistribution, (indices, weights))) => + assert(indices.length == 2) + assert(weights.length == 2) + val bdvTopicDist = topicDistribution.toBreeze + val top2Indices = argtopk(bdvTopicDist, 2) + assert(top2Indices.toArray === indices) + assert(bdvTopicDist(top2Indices).toArray === weights) + } + // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // Check: topTopicAssignments + // Make sure it assigns a topic to each term appearing in each doc. + val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] = + model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap + assert(topTopicAssignments.keys.max < tinyCorpus.length) + tinyCorpus.foreach { case (docID: Long, doc: Vector) => + if (topTopicAssignments.contains(docID)) { + val (inds, vals) = topTopicAssignments(docID) + assert(inds.length === doc.numNonzeros) + // For "term" in actual doc, + // check that it has a topic assigned. + doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term))) + } else { + assert(doc.numNonzeros === 0) + } + } } test("vertex indexing") { @@ -123,6 +177,382 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) } + + test("setter alias") { + val lda = new LDA().setAlpha(2.0).setBeta(3.0) + assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0)) + assert(lda.getBeta === 3.0) + assert(lda.getTopicConcentration === 3.0) + } + + test("initializing with alpha length != k or 1 fails") { + intercept[IllegalArgumentException] { + val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4)) + val corpus = sc.parallelize(tinyCorpus, 2) + lda.run(corpus) + } + } + + test("initializing with elements in alpha < 0 fails") { + intercept[IllegalArgumentException] { + val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4)) + val corpus = sc.parallelize(tinyCorpus, 2) + lda.run(corpus) + } + } + + test("OnlineLDAOptimizer initialization") { + val lda = new LDA().setK(2) + val corpus = sc.parallelize(tinyCorpus, 2) + val op = new OnlineLDAOptimizer().initialize(corpus, lda) + op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567) + assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k + assert(op.getEta === 0.5) // default 1.0 / k + assert(op.getKappa === 0.9876) + assert(op.getMiniBatchFraction === 0.123) + assert(op.getTau0 === 567) + } + + test("OnlineLDAOptimizer one iteration") { + // run OnlineLDAOptimizer for 1 iteration to verify it's consistency with Blei-lab, + // [[https://github.com/Blei-Lab/onlineldavb]] + val k = 2 + val vocabSize = 6 + + def docs: Array[(Long, Vector)] = Array( + Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana + Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val corpus = sc.parallelize(docs, 2) + + // Set GammaShape large to avoid the stochastic impact. + val op = new OnlineLDAOptimizer().setTau0(1024).setKappa(0.51).setGammaShape(1e40) + .setMiniBatchFraction(1) + val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345) + + val state = op.initialize(corpus, lda) + // override lambda to simulate an intermediate state + // [[ 1.1 1.2 1.3 0.9 0.8 0.7] + // [ 0.9 0.8 0.7 1.1 1.2 1.3]] + op.setLambda(new BDM[Double](k, vocabSize, + Array(1.1, 0.9, 1.2, 0.8, 1.3, 0.7, 0.9, 1.1, 0.8, 1.2, 0.7, 1.3))) + + // run for one iteration + state.submitMiniBatch(corpus) + + // verify the result, Note this generate the identical result as + // [[https://github.com/Blei-Lab/onlineldavb]] + val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t) + val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t) + val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950) + val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050) + assert(topic1 ~== expectedTopic1 absTol 0.01) + assert(topic2 ~== expectedTopic2 absTol 0.01) + } + + test("OnlineLDAOptimizer with toy data") { + val docs = sc.parallelize(toyData) + val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) + .setGammaShape(1e10) + val lda = new LDA().setK(2) + .setDocConcentration(0.01) + .setTopicConcentration(0.01) + .setMaxIterations(100) + .setOptimizer(op) + .setSeed(12345) + + val ldaModel = lda.run(docs) + val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights) + } + + // check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02) + topics.foreach { topic => + val smalls = topic.filter(t => t._2 < 0.1).map(_._2) + assert(smalls.length == 3 && smalls.sum < 0.2) + } + } + + test("LocalLDAModel logLikelihood") { + val ldaModel: LocalLDAModel = toyModel + + val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1))) + .zipWithIndex + .map { case (wordCounts, docId) => (docId.toLong, wordCounts) }) + val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5))) + .zipWithIndex + .map { case (wordCounts, docId) => (docId.toLong, wordCounts) }) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + docsSingleWord = [[(0, 1.0)]] + docsRepeatedWord = [[(0, 5.0)]] + print(lda.bound(docsSingleWord)) + > -25.9706969833 + print(lda.bound(docsRepeatedWord)) + > -31.4413908227 + */ + + assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D) + assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D) + } + + test("LocalLDAModel logPerplexity") { + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(lda.log_perplexity(corpus)) + > -3.69051285096 + */ + + // Gensim's definition of perplexity is negative our (and Stanford NLP's) definition + assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D) + } + + test("LocalLDAModel predict") { + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(list(lda.get_document_topics(corpus))) + > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)], + > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)], + > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]] + */ + + val expectedPredictions = List( + (0, 0.99504), (0, 0.99504), + (0, 0.99504), (1, 0.99504), + (1, 0.99504), (1, 0.99504)) + + val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) => + // convert results to expectedPredictions format, which only has highest probability topic + val topicsBz = topics.toBreeze.toDenseVector + (id, (argmax(topicsBz), max(topicsBz))) + }.sortByKey() + .values + .collect() + + expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => + expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + } + } + + test("OnlineLDAOptimizer with asymmetric prior") { + val docs = sc.parallelize(toyData) + val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) + .setGammaShape(1e10) + val lda = new LDA().setK(2) + .setDocConcentration(Vectors.dense(0.00001, 0.1)) + .setTopicConcentration(0.01) + .setMaxIterations(100) + .setOptimizer(op) + .setSeed(12345) + + val ldaModel = lda.run(docs) + val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights) + } + + /* Verify results with Python: + + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(10) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100) + lda.print_topics() + + > ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5', + '0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5'] + */ + topics.foreach { topic => + assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 }) + } + } + + test("OnlineLDAOptimizer alpha hyperparameter optimization") { + val k = 2 + val docs = sc.parallelize(toyData) + val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) + .setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false) + val lda = new LDA().setK(k) + .setDocConcentration(1D / k) + .setTopicConcentration(0.01) + .setMaxIterations(100) + .setOptimizer(op) + .setSeed(12345) + val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel] + + /* Verify the results with gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(lda.alpha) + > [ 0.42582646 0.43511073] + */ + + assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05) + } + + test("model save/load") { + // Test for LocalLDAModel. + val localModel = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D) + val tempDir1 = Utils.createTempDir() + val path1 = tempDir1.toURI.toString + + // Test for DistributedLDAModel. + val k = 3 + val docConcentration = 1.2 + val topicConcentration = 1.5 + val lda = new LDA() + lda.setK(k) + .setDocConcentration(docConcentration) + .setTopicConcentration(topicConcentration) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] + val tempDir2 = Utils.createTempDir() + val path2 = tempDir2.toURI.toString + + try { + localModel.save(sc, path1) + distributedModel.save(sc, path2) + val samelocalModel = LocalLDAModel.load(sc, path1) + assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) + assert(samelocalModel.k === localModel.k) + assert(samelocalModel.vocabSize === localModel.vocabSize) + assert(samelocalModel.docConcentration === localModel.docConcentration) + assert(samelocalModel.topicConcentration === localModel.topicConcentration) + assert(samelocalModel.gammaShape === localModel.gammaShape) + + val sameDistributedModel = DistributedLDAModel.load(sc, path2) + assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) + assert(distributedModel.k === sameDistributedModel.k) + assert(distributedModel.vocabSize === sameDistributedModel.vocabSize) + assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) + assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) + assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.gammaShape === sameDistributedModel.gammaShape) + assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) + + val graph = distributedModel.graph + val sameGraph = sameDistributedModel.graph + assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect()) + val edge = graph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + val sameEdge = sameGraph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + assert(edge === sameEdge) + } finally { + Utils.deleteRecursively(tempDir1) + Utils.deleteRecursively(tempDir2) + } + } + + test("EMLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new EMLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + + test("OnlineLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new OnlineLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + } private[clustering] object LDASuite { @@ -141,13 +571,52 @@ private[clustering] object LDASuite { (terms.toArray, termWeights.toArray) } - def tinyCorpus = Array( + def tinyCorpus: Array[(Long, Vector)] = Array( + Vectors.dense(0, 0, 0, 0, 0), // empty doc Vectors.dense(1, 3, 0, 2, 8), Vectors.dense(0, 2, 1, 0, 4), Vectors.dense(2, 3, 12, 3, 1), + Vectors.dense(0, 0, 0, 0, 0), // empty doc Vectors.dense(0, 3, 1, 9, 8), Vectors.dense(1, 1, 4, 2, 6) ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data + def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter { + case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0 + } + + def toyData: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + + /** Used in the Java Test Suite */ + def javaToyData: JArrayList[(java.lang.Long, Vector)] = { + val javaData = new JArrayList[(java.lang.Long, Vector)] + var i = 0 + while (i < toyData.length) { + javaData.add((toyData(i)._1, toyData(i)._2)) + i += 1 + } + javaData + } + + def toyModel: LocalLDAModel = { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + ldaModel + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 03ecd9ca730b..189000512155 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -18,14 +18,15 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable +import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.PowerIterationClustering._ @@ -51,18 +52,66 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext .setK(2) .run(sc.parallelize(similarities, 2)) val predictions = Array.fill(2)(mutable.Set.empty[Long]) - model.assignments.collect().foreach { case (i, c) => - predictions(c) += i + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + + val model2 = new PowerIterationClustering() + .setK(2) + .setInitializationMode("degree") + .run(sc.parallelize(similarities, 2)) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id + } + assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + + test("power iteration clustering on graph") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + + val edges = similarities.flatMap { case (i, j, s) => + if (i != j) { + Seq(Edge(i, j, s), Edge(j, i, s)) + } else { + None + } + } + val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0) + + val model = new PowerIterationClustering() + .setK(2) + .run(graph) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) - + val model2 = new PowerIterationClustering() .setK(2) .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - model2.assignments.collect().foreach { case (i, c) => - predictions2(c) += i + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id } assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) } @@ -91,11 +140,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext */ val similarities = Seq[(Long, Long, Double)]( (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + // scalastyle:off val expected = Array( Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + // scalastyle:on val w = normalize(sc.parallelize(similarities, 2)) w.edges.collect().foreach { case Edge(i, j, x) => assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) @@ -110,4 +161,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext assert(x ~== u1(i.toInt) absTol 1e-14) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val model = PowerIterationClusteringSuite.createModel(sc, 3, 10) + try { + model.save(sc, path) + val sameModel = PowerIterationClusteringModel.load(sc, path) + PowerIterationClusteringSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object PowerIterationClusteringSuite extends SparkFunSuite { + def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { + val assignments = sc.parallelize( + (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) + new PowerIterationClusteringModel(k, assignments) + } + + def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = { + assert(a.k === b.k) + + val aAssignments = a.assignments.map(x => (x.id, x.cluster)) + val bAssignments = b.assignments.map(x => (x.id, x.cluster)) + val unequalElements = aAssignments.join(bAssignments).filter { + case (id, (c1, c2)) => c1 != c2 }.count() + assert(unequalElements === 0L) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 850c9fce507c..3645d29dccdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,17 +17,25 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends FunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis: Int = 30000 - override def maxWaitTimeMillis = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } test("accuracy for single center and equivalence to grand average") { // set parameters @@ -47,7 +55,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -59,7 +67,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because the decay factor is set to 1.0 val grandMean = - input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) } @@ -83,7 +91,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -115,7 +123,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -133,6 +141,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(math.abs(c1) ~== 0.8 absTol 0.6) } + test("SPARK-7946 setDecayFactor") { + val kMeans = new StreamingKMeans() + assert(kMeans.decayFactor === 1.0) + kMeans.setDecayFactor(2.0) + assert(kMeans.decayFactor === 2.0) + } + def StreamingKMeansDataGenerator( numPoints: Int, numBatches: Int, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 79847633ff0d..87ccc7eda44e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index e0224f960cc4..99d52fabc530 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 7dc4f3cfbc4e..d55bc8c3ec09 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 2537dd62c92f..f3b19aeb42f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 609eed983ff4..c0924a213a84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 670b4c34e609..4b7f1be58f99 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -17,31 +17,91 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("regression metrics for unbiased (includes intercept term) predictor") { + /* Verify results in R: + preds = c(2.25, -0.25, 1.75, 7.75) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.796875 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.3125 + rmse = sqrt(meanSquaredError) + rmse + > [1] 0.559017 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9571734 + */ + val predictionAndObservations = sc.parallelize( + Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics for biased (no intercept term) predictor") { + /* Verify results in R: + preds = c(2.5, 0.0, 2.0, 8.0) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) - test("regression metrics") { + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.859375 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.375 + rmse = sqrt(meanSquaredError) + rmse + > [1] 0.6123724 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9486081 + */ val predictionAndObservations = sc.parallelize( - Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") } test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( - Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 747f5914598e..889727fb5582 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { /* * Contingency tables diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala new file mode 100644 index 000000000000..ccbf8a91cdd3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("elementwise (hadamard) product should properly apply vector to dense data set") { + val denseData = Array( + Vectors.dense(1.0, 4.0, 1.9, -9.0) + ) + val scalingVec = Vectors.dense(2.0, 0.5, 0.0, 0.25) + val transformer = new ElementwiseProduct(scalingVec) + val transformedData = transformer.transform(sc.makeRDD(denseData)) + val transformedVecs = transformedData.collect() + val transformedVec = transformedVecs(0) + val expectedVec = Vectors.dense(2.0, 2.0, 0.0, -2.25) + assert(transformedVec ~== expectedVec absTol 1E-5, + s"Expected transformed vector $expectedVec but found $transformedVec") + } + + test("elementwise (hadamard) product should properly apply vector to sparse data set") { + val sparseData = Array( + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))) + ) + val dataRDD = sc.parallelize(sparseData, 3) + val scalingVec = Vectors.dense(1.0, 0.0, 0.5) + val transformer = new ElementwiseProduct(scalingVec) + val data2 = sparseData.map(transformer.transform) + val data2RDD = transformer.transform(dataRDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after hadamard product") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert(data2(0) ~== Vectors.sparse(3, Seq((1, 0.0), (2, -1.5))) absTol 1E-5) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 0c4dfb7b97c7..cf279c02334e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 0a5cad7caf8e..21163633051e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 85fdd271b5ed..34122d6ed2e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - import breeze.linalg.{norm => brzNorm} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), @@ -106,10 +105,10 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { assert((dataInf, dataInfRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(dataInf(0).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) - assert(dataInf(2).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) - assert(dataInf(3).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) - assert(dataInf(4).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(0).toArray.map(math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(2).toArray.map(math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(3).toArray.map(math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(4).toArray.map(math.abs).max ~== 1.0 absTol 1E-5) assert(dataInf(0) ~== Vectors.sparse(3, Seq((0, -0.86956522), (1, 1.0))) absTol 1E-5) assert(dataInf(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala new file mode 100644 index 000000000000..e57f49191378 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { + + private val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + + private lazy val dataRDD = sc.parallelize(data, 2) + + test("Correct computing use a PCA wrapper") { + val k = dataRDD.count().toInt + val pca = new PCA(k).fit(dataRDD) + + val mat = new RowMatrix(dataRDD) + val pc = mat.computePrincipalComponents(k) + + val pca_transform = pca.transform(dataRDD).collect() + val mat_multiply = mat.multiply(pc).rows.collect() + + assert(pca_transform.toSet === mat_multiply.toSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 7f94564b2a3a..6ab2fa677012 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { // When the input data is all constant, the variance is zero. The standardization against // zero variance is not well-defined, but we decide to just set it into zero here. @@ -360,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } withClue("model needs std and mean vectors to be equal size when both are provided") { intercept[IllegalArgumentException] { - val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 52278690dbd8..a864eec460f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: add more tests @@ -35,6 +37,22 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") + + // Test that model built using Word2Vec, i.e wordVectors and wordIndec + // and a Word2VecMap give the same values. + val word2VecMap = model.getVectors + val newModel = new Word2VecModel(word2VecMap) + assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq)) + } + + test("Word2Vec throws exception when vocabulary is empty") { + intercept[IllegalArgumentException] { + val sentence = "a b c" + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + new Word2Vec().setMinCount(10).fit(doc) + } } test("Word2VecModel") { @@ -51,4 +69,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { assert(syms(0)._1 == "taiwan") assert(syms(1)._1 == "japan") } + + test("model load / save") { + + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala new file mode 100644 index 000000000000..77a2773c36f5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("association rules using String type") { + val freqItemsets = sc.parallelize(Seq( + (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), + (Set("r"), 3L), + (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), + (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), + (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), + (Set("t", "y", "x"), 3L), + (Set("t", "y", "x", "z"), 3L) + ).map { + case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq) + }) + + val ar = new AssociationRules() + + val results1 = ar + .setMinConfidence(0.9) + .run(freqItemsets) + .collect() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + assert(results1.size === 23) + assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + + val results2 = ar + .setMinConfidence(0) + .run(freqItemsets) + .collect() + + /* Verify results using the `R` code: + ars = apriori(transactions, + parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + nrow(arsDF) + sum(arsDF$confidence == 1) + > nrow(arsDF) + [1] 30 + > sum(arsDF$confidence == 1) + [1] 23 + */ + assert(results2.size === 30) + assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 71ef60da6dd3..4a9bfdb348d9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,13 +16,13 @@ */ package org.apache.spark.mllib.fpm -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { + - test("FP-Growth") { + test("FP-Growth using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -39,15 +39,58 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .setMinSupport(0.9) .setNumPartitions(1) .run(rdd) + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + > eclat(transactions, parameter = list(support = 0.9)) + ... + eclat - zero frequent items + set of 0 itemsets + */ assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) .run(rdd) - val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => - (items.toSet, count) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) } + + /* Verify results using the `R` code: + fp = eclat(transactions, parameter = list(support = 0.5)) + fpDF = as(sort(fp), "data.frame") + fpDF$support = fpDF$support * length(transactions) + names(fpDF)[names(fpDF) == "support"] = "freq" + > fpDF + items freq + 13 {z} 5 + 14 {x} 4 + 1 {s,x} 3 + 2 {t,x,y,z} 3 + 3 {t,y,z} 3 + 4 {t,x,y} 3 + 5 {x,y,z} 3 + 6 {y,z} 3 + 7 {x,y} 3 + 8 {t,y} 3 + 9 {t,x,z} 3 + 10 {t,z} 3 + 11 {t,x} 3 + 12 {x,z} 3 + 15 {t} 3 + 16 {y} 3 + 17 {s} 3 + 18 {r} 3 + */ val expected = Set( (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), (Set("r"), 3L), @@ -62,12 +105,173 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .setMinSupport(0.3) .setNumPartitions(4) .run(rdd) + + /* Verify results using the `R` code: + fp = eclat(transactions, parameter = list(support = 0.3)) + fpDF = as(fp, "data.frame") + fpDF$support = fpDF$support * length(transactions) + names(fpDF)[names(fpDF) == "support"] = "freq" + > nrow(fpDF) + [1] 54 + */ assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) .run(rdd) + + /* Verify results using the `R` code: + fp = eclat(transactions, parameter = list(support = 0.1)) + fpDF = as(fp, "data.frame") + fpDF$support = fpDF$support * length(transactions) + names(fpDF)[names(fpDF) == "support"] = "freq" + > nrow(fpDF) + [1] 625 + */ assert(model1.freqItemsets.count() === 625) } + + test("FP-Growth String type association rule generation") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + val rules = (new FPGrowth()) + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + .generateAssociationRules(0.9) + .collect() + + assert(rules.size === 23) + assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } + + test("FP-Growth using Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + + /* Verify results using the `R` code: + transactions = as(sapply( + list("1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + > eclat(transactions, parameter = list(support = 0.9)) + ... + eclat - zero frequent items + set of 0 itemsets + */ + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, + "frequent itemsets should use primitive arrays") + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + /* Verify results using the `R` code: + fp = eclat(transactions, parameter = list(support = 0.5)) + fpDF = as(sort(fp), "data.frame") + fpDF$support = fpDF$support * length(transactions) + names(fpDF)[names(fpDF) == "support"] = "freq" + > fpDF + items freq + 6 {1} 6 + 3 {1,3} 5 + 7 {2} 5 + 8 {3} 5 + 1 {2,4} 4 + 2 {1,2,3} 4 + 4 {2,3} 4 + 5 {1,2} 4 + 9 {4} 4 + */ + val expected = Set( + (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), + (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L), + (Set(2, 4), 4L), (Set(1, 2, 3), 4L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + + /* Verify results using the `R` code: + fp = eclat(transactions, parameter = list(support = 0.3)) + fpDF = as(fp, "data.frame") + fpDF$support = fpDF$support * length(transactions) + names(fpDF)[names(fpDF) == "support"] = "freq" + > nrow(fpDF) + [1] 15 + */ + assert(model2.freqItemsets.count() === 15) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + + /* Verify results using the `R` code: + fp = eclat(transactions, parameter = list(support = 0.1)) + fpDF = as(fp, "data.frame") + fpDF$support = fpDF$support * length(transactions) + names(fpDF)[names(fpDF) == "support"] = "freq" + > nrow(fpDF) + [1] 65 + */ + assert(model1.freqItemsets.count() === 65) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala index 04017f67c311..a56d7b357921 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm import scala.language.existentials -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPTreeSuite extends FunSuite with MLlibTestSparkContext { +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala new file mode 100644 index 000000000000..a83e543859b8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("PrefixSpan internal (integer seq, 0 delim) run, singleton itemsets") { + + /* + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade( + prefixSpanSeqs, + parameter = list(support = + 2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + */ + + val sequences = Array( + Array(0, 1, 0, 3, 0, 4, 0, 5, 0), + Array(0, 2, 0, 3, 0, 1, 0), + Array(0, 2, 0, 4, 0, 1, 0), + Array(0, 3, 0, 1, 0, 3, 0, 4, 0, 5, 0), + Array(0, 3, 0, 4, 0, 4, 0, 3, 0), + Array(0, 6, 0, 5, 0, 3, 0)) + + val rdd = sc.parallelize(sequences, 2).cache() + + val result1 = PrefixSpan.genFreqPatterns( + rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L) + val expectedValue1 = Array( + (Array(0, 1, 0), 4L), + (Array(0, 1, 0, 3, 0), 2L), + (Array(0, 1, 0, 3, 0, 4, 0), 2L), + (Array(0, 1, 0, 3, 0, 4, 0, 5, 0), 2L), + (Array(0, 1, 0, 3, 0, 5, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 1, 0, 4, 0, 5, 0), 2L), + (Array(0, 1, 0, 5, 0), 2L), + (Array(0, 2, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 3, 0, 3, 0), 2L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 3, 0, 4, 0, 5, 0), 2L), + (Array(0, 3, 0, 5, 0), 2L), + (Array(0, 4, 0), 4L), + (Array(0, 4, 0, 5, 0), 2L), + (Array(0, 5, 0), 3L) + ) + compareInternalResults(expectedValue1, result1.collect()) + + val result2 = PrefixSpan.genFreqPatterns( + rdd, minCount = 3, maxPatternLength = 50, maxLocalProjDBSize = 32L) + val expectedValue2 = Array( + (Array(0, 1, 0), 4L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 4, 0), 4L), + (Array(0, 5, 0), 3L) + ) + compareInternalResults(expectedValue2, result2.collect()) + + val result3 = PrefixSpan.genFreqPatterns( + rdd, minCount = 2, maxPatternLength = 2, maxLocalProjDBSize = 32L) + val expectedValue3 = Array( + (Array(0, 1, 0), 4L), + (Array(0, 1, 0, 3, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 1, 0, 5, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 2, 0), 2L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 3, 0, 3, 0), 2L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 3, 0, 5, 0), 2L), + (Array(0, 4, 0), 4L), + (Array(0, 4, 0, 5, 0), 2L), + (Array(0, 5, 0), 3L) + ) + compareInternalResults(expectedValue3, result3.collect()) + } + + test("PrefixSpan internal (integer seq, -1 delim) run, variable-size itemsets") { + val sequences = Array( + Array(0, 1, 0, 1, 2, 3, 0, 1, 3, 0, 4, 0, 3, 6, 0), + Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0), + Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0), + Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0)) + val rdd = sc.parallelize(sequences, 2).cache() + val result = PrefixSpan.genFreqPatterns( + rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L) + + /* + To verify results, create file "prefixSpanSeqs" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 1 1 + 1 2 3 1 2 3 + 1 3 2 1 3 + 1 4 1 4 + 1 5 2 3 6 + 2 1 2 1 4 + 2 2 1 3 + 2 3 2 2 3 + 2 4 2 1 5 + 3 1 2 5 6 + 3 2 2 1 2 + 3 3 2 4 6 + 3 4 1 3 + 3 5 1 2 + 4 1 1 5 + 4 2 1 7 + 4 3 2 1 6 + 4 4 1 3 + 4 5 1 2 + 4 6 1 3 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = list(support = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 1.00 + 2 <{2}> 1.00 + 3 <{3}> 1.00 + 4 <{4}> 0.75 + 5 <{5}> 0.75 + 6 <{6}> 0.75 + 7 <{1},{6}> 0.50 + 8 <{2},{6}> 0.50 + 9 <{5},{6}> 0.50 + 10 <{1,2},{6}> 0.50 + 11 <{1},{4}> 0.50 + 12 <{2},{4}> 0.50 + 13 <{1,2},{4}> 0.50 + 14 <{1},{3}> 1.00 + 15 <{2},{3}> 0.75 + 16 <{2,3}> 0.50 + 17 <{3},{3}> 0.75 + 18 <{4},{3}> 0.75 + 19 <{5},{3}> 0.50 + 20 <{6},{3}> 0.50 + 21 <{5},{6},{3}> 0.50 + 22 <{6},{2},{3}> 0.50 + 23 <{5},{2},{3}> 0.50 + 24 <{5},{1},{3}> 0.50 + 25 <{2},{4},{3}> 0.50 + 26 <{1},{4},{3}> 0.50 + 27 <{1,2},{4},{3}> 0.50 + 28 <{1},{3},{3}> 0.75 + 29 <{1,2},{3}> 0.50 + 30 <{1},{2},{3}> 0.50 + 31 <{1},{2,3}> 0.50 + 32 <{1},{2}> 1.00 + 33 <{1,2}> 0.50 + 34 <{3},{2}> 0.75 + 35 <{4},{2}> 0.50 + 36 <{5},{2}> 0.50 + 37 <{6},{2}> 0.50 + 38 <{5},{6},{2}> 0.50 + 39 <{6},{3},{2}> 0.50 + 40 <{5},{3},{2}> 0.50 + 41 <{5},{1},{2}> 0.50 + 42 <{4},{3},{2}> 0.50 + 43 <{1},{3},{2}> 0.75 + 44 <{5},{6},{3},{2}> 0.50 + 45 <{5},{1},{3},{2}> 0.50 + 46 <{1},{1}> 0.50 + 47 <{2},{1}> 0.50 + 48 <{3},{1}> 0.50 + 49 <{5},{1}> 0.50 + 50 <{2,3},{1}> 0.50 + 51 <{1},{3},{1}> 0.50 + 52 <{1},{2,3},{1}> 0.50 + 53 <{1},{2},{1}> 0.50 + */ + val expectedValue = Array( + (Array(0, 1, 0), 4L), + (Array(0, 2, 0), 4L), + (Array(0, 3, 0), 4L), + (Array(0, 4, 0), 3L), + (Array(0, 5, 0), 3L), + (Array(0, 6, 0), 3L), + (Array(0, 1, 0, 6, 0), 2L), + (Array(0, 2, 0, 6, 0), 2L), + (Array(0, 5, 0, 6, 0), 2L), + (Array(0, 1, 2, 0, 6, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 2, 0, 4, 0), 2L), + (Array(0, 1, 2, 0, 4, 0), 2L), + (Array(0, 1, 0, 3, 0), 4L), + (Array(0, 2, 0, 3, 0), 3L), + (Array(0, 2, 3, 0), 2L), + (Array(0, 3, 0, 3, 0), 3L), + (Array(0, 4, 0, 3, 0), 3L), + (Array(0, 5, 0, 3, 0), 2L), + (Array(0, 6, 0, 3, 0), 2L), + (Array(0, 5, 0, 6, 0, 3, 0), 2L), + (Array(0, 6, 0, 2, 0, 3, 0), 2L), + (Array(0, 5, 0, 2, 0, 3, 0), 2L), + (Array(0, 5, 0, 1, 0, 3, 0), 2L), + (Array(0, 2, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 2, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 0, 3, 0, 3, 0), 3L), + (Array(0, 1, 2, 0, 3, 0), 2L), + (Array(0, 1, 0, 2, 0, 3, 0), 2L), + (Array(0, 1, 0, 2, 3, 0), 2L), + (Array(0, 1, 0, 2, 0), 4L), + (Array(0, 1, 2, 0), 2L), + (Array(0, 3, 0, 2, 0), 3L), + (Array(0, 4, 0, 2, 0), 2L), + (Array(0, 5, 0, 2, 0), 2L), + (Array(0, 6, 0, 2, 0), 2L), + (Array(0, 5, 0, 6, 0, 2, 0), 2L), + (Array(0, 6, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 1, 0, 2, 0), 2L), + (Array(0, 4, 0, 3, 0, 2, 0), 2L), + (Array(0, 1, 0, 3, 0, 2, 0), 3L), + (Array(0, 5, 0, 6, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 1, 0, 3, 0, 2, 0), 2L), + (Array(0, 1, 0, 1, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 5, 0, 1, 0), 2L), + (Array(0, 2, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 2, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 2, 0, 1, 0), 2L)) + + compareInternalResults(expectedValue, result.collect()) + } + + test("PrefixSpan projections with multiple partial starts") { + val sequences = Seq( + Array(Array(1, 2), Array(1, 2, 3))) + val rdd = sc.parallelize(sequences, 2) + val prefixSpan = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + val model = prefixSpan.run(rdd) + val expected = Array( + (Array(Array(1)), 1L), + (Array(Array(1, 2)), 1L), + (Array(Array(1), Array(1)), 1L), + (Array(Array(1), Array(2)), 1L), + (Array(Array(1), Array(3)), 1L), + (Array(Array(1, 3)), 1L), + (Array(Array(2)), 1L), + (Array(Array(2, 3)), 1L), + (Array(Array(2), Array(1)), 1L), + (Array(Array(2), Array(2)), 1L), + (Array(Array(2), Array(3)), 1L), + (Array(Array(3)), 1L)) + compareResults(expected, model.freqSequences.collect()) + } + + test("PrefixSpan Integer type, variable-size itemsets") { + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + + /* + To verify results, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + + val model = prefixSpan.run(rdd) + val expected = Array( + (Array(Array(1)), 3L), + (Array(Array(2)), 3L), + (Array(Array(3)), 2L), + (Array(Array(1), Array(3)), 2L), + (Array(Array(1, 2)), 3L) + ) + compareResults(expected, model.freqSequences.collect()) + } + + test("PrefixSpan String type, variable-size itemsets") { + // This is the same test as "PrefixSpan Int type, variable-size itemsets" except + // mapped to Strings + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))).map(seq => seq.map(itemSet => itemSet.map(intToString))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + + val model = prefixSpan.run(rdd) + val expected = Array( + (Array(Array(1)), 3L), + (Array(Array(2)), 3L), + (Array(Array(3)), 2L), + (Array(Array(1), Array(3)), 2L), + (Array(Array(1, 2)), 3L) + ).map { case (pattern, count) => + (pattern.map(itemSet => itemSet.map(intToString)), count) + } + compareResults(expected, model.freqSequences.collect()) + } + + private def compareResults[Item]( + expectedValue: Array[(Array[Array[Item]], Long)], + actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { + val expectedSet = expectedValue.map { case (pattern: Array[Array[Item]], count: Long) => + (pattern.map(itemSet => itemSet.toSet).toSeq, count) + }.toSet + val actualSet = actualValue.map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + assert(expectedSet === actualSet) + } + + private def compareInternalResults( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Unit = { + val expectedSet = expectedValue.map(x => (x._1.toSeq, x._2)).toSet + val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet + assert(expectedSet === actualSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index dac28a369b5b..e331c7598918 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,35 +17,33 @@ package org.apache.spark.mllib.impl -import org.scalatest.FunSuite - import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { import PeriodicGraphCheckpointerSuite._ - // TODO: Do I need to call count() on the graphs' RDDs? - test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) checkPersistence(graphsToCheck, iteration) iteration += 1 @@ -57,9 +55,11 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString val checkpointInterval = 2 var graphsToCheck = Seq.empty[GraphToCheck] - + sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) @@ -68,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graph.vertices.count() graph.edges.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) @@ -170,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite { } else { // Graph should never be checkpointed assert(!graph.isCheckpointed, "Graph should never have been checkpointed") - assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") } } catch { case e: AssertionError => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala new file mode 100644 index 000000000000..b2a459a68b5f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PeriodicRDDCheckpointerSuite._ + + test("Persisting") { + var rddsToCheck = Seq.empty[RDDToCheck] + + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkPersistence(rddsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkPersistence(rddsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicRDDCheckpointerSuite { + + case class RDDToCheck(rdd: RDD[Double], gIndex: Int) + + def createRDD(sc: SparkContext): RDD[Double] = { + sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) + } + + def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { + rdds.foreach { g => + checkPersistence(g.rdd, g.gIndex, iteration) + } + } + + /** + * Check storage level of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(rdd.getStorageLevel == StorageLevel.NONE) + } else { + assert(rdd.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") + } + } + + def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { + rdds.reverse.foreach { g => + checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { + // Note: We cannot check rdd.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this rdd.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + rdd.getCheckpointFile.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkCheckpoint( + rdd: RDD[_], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) + // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(rdd.isCheckpointed, "RDD should be checkpointed") + assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(rdd) + } + } else { + // RDD should never be checkpointed + assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") + assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + + s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index b0b78acd6df1..d119e0b50a39 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ -class BLASSuite extends FunSuite { +class BLASSuite extends SparkFunSuite { test("copy") { val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) @@ -140,7 +139,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dA) assert(dA ~== expected absTol 1e-15) - + val dB = new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) @@ -149,7 +148,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dB) } } - + val dC = new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) @@ -158,7 +157,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dC) } } - + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) withClue("Size of vector must match the rank of matrix") { @@ -166,6 +165,14 @@ class BLASSuite extends FunSuite { syr(alpha, y, dA) } } + + val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0)) + val dD = new DenseMatrix(4, 4, + Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) + syr(0.1, xSparse, dD) + val expectedSparse = new DenseMatrix(4, 4, + Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4)) + assert(dD ~== expectedSparse absTol 1e-15) } test("gemm") { @@ -193,8 +200,14 @@ class BLASSuite extends FunSuite { val C10 = C1.copy val C11 = C1.copy val C12 = C1.copy + val C13 = C1.copy + val C14 = C1.copy + val C15 = C1.copy + val C16 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) + val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0)) + val expected5 = C1.copy gemm(1.0, dA, B, 2.0, C1) gemm(1.0, sA, B, 2.0, C2) @@ -241,6 +254,16 @@ class BLASSuite extends FunSuite { assert(C10 ~== expected2 absTol 1e-15) assert(C11 ~== expected3 absTol 1e-15) assert(C12 ~== expected3 absTol 1e-15) + + gemm(0, dA, B, 5, C13) + gemm(0, sA, B, 5, C14) + gemm(0, dA, B, 1, C15) + gemm(0, sA, B, 1, C16) + assert(C13 ~== expected4 absTol 1e-15) + assert(C14 ~== expected4 absTol 1e-15) + assert(C15 ~== expected5 absTol 1e-15) + assert(C16 ~== expected5 absTol 1e-15) + } test("gemv") { @@ -249,32 +272,96 @@ class BLASSuite extends FunSuite { new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) - val x = new DenseVector(Array(1.0, 2.0, 3.0)) + val dA2 = + new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) + val sA2 = + new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), + true) + + val dx = new DenseVector(Array(1.0, 2.0, 3.0)) + val sx = dx.toSparse val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA.multiply(x) ~== expected absTol 1e-15) - assert(sA.multiply(x) ~== expected absTol 1e-15) + assert(dA.multiply(dx) ~== expected absTol 1e-15) + assert(sA.multiply(dx) ~== expected absTol 1e-15) + assert(dA.multiply(sx) ~== expected absTol 1e-15) + assert(sA.multiply(sx) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + val y8 = y1.copy + val y9 = y1.copy + val y10 = y1.copy + val y11 = y1.copy + val y12 = y1.copy + val y13 = y1.copy + val y14 = y1.copy + val y15 = y1.copy + val y16 = y1.copy + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) - gemv(1.0, dA, x, 2.0, y1) - gemv(1.0, sA, x, 2.0, y2) - gemv(2.0, dA, x, 2.0, y3) - gemv(2.0, sA, x, 2.0, y4) + gemv(1.0, dA, dx, 2.0, y1) + gemv(1.0, sA, dx, 2.0, y2) + gemv(1.0, dA, sx, 2.0, y3) + gemv(1.0, sA, sx, 2.0, y4) + + gemv(1.0, dA2, dx, 2.0, y5) + gemv(1.0, sA2, dx, 2.0, y6) + gemv(1.0, dA2, sx, 2.0, y7) + gemv(1.0, sA2, sx, 2.0, y8) + + gemv(2.0, dA, dx, 2.0, y9) + gemv(2.0, sA, dx, 2.0, y10) + gemv(2.0, dA, sx, 2.0, y11) + gemv(2.0, sA, sx, 2.0, y12) + + gemv(2.0, dA2, dx, 2.0, y13) + gemv(2.0, sA2, dx, 2.0, y14) + gemv(2.0, dA2, sx, 2.0, y15) + gemv(2.0, sA2, sx, 2.0, y16) + assert(y1 ~== expected2 absTol 1e-15) assert(y2 ~== expected2 absTol 1e-15) - assert(y3 ~== expected3 absTol 1e-15) - assert(y4 ~== expected3 absTol 1e-15) + assert(y3 ~== expected2 absTol 1e-15) + assert(y4 ~== expected2 absTol 1e-15) + + assert(y5 ~== expected2 absTol 1e-15) + assert(y6 ~== expected2 absTol 1e-15) + assert(y7 ~== expected2 absTol 1e-15) + assert(y8 ~== expected2 absTol 1e-15) + + assert(y9 ~== expected3 absTol 1e-15) + assert(y10 ~== expected3 absTol 1e-15) + assert(y11 ~== expected3 absTol 1e-15) + assert(y12 ~== expected3 absTol 1e-15) + + assert(y13 ~== expected3 absTol 1e-15) + assert(y14 ~== expected3 absTol 1e-15) + assert(y15 ~== expected3 absTol 1e-15) + assert(y16 ~== expected3 absTol 1e-15) + withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(1.0, dA.transpose, x, 2.0, y1) + gemv(1.0, dA.transpose, dx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, sA.transpose, dx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, dA.transpose, sx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, sA.transpose, sx, 2.0, y1) } } + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = @@ -283,7 +370,9 @@ class BLASSuite extends FunSuite { val dATT = dAT.transpose val sATT = sAT.transpose - assert(dATT.multiply(x) ~== expected absTol 1e-15) - assert(sATT.multiply(x) ~== expected absTol 1e-15) + assert(dATT.multiply(dx) ~== expected absTol 1e-15) + assert(sATT.multiply(dx) ~== expected absTol 1e-15) + assert(dATT.multiply(sx) ~== expected absTol 1e-15) + assert(sATT.multiply(sx) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 203103237397..dc04258e41d2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} -class BreezeMatrixConversionSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 8abdac72902c..3772c9235ad3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.SparkFunSuite + /** * Test Breeze vector conversions. */ -class BreezeVectorConversionSuite extends FunSuite { +class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index b1ebfde0e5e5..bfd6d5495f5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg import java.util.Random import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class MatricesSuite extends FunSuite { +class MatricesSuite extends SparkFunSuite { test("dense matrix construction") { val m = 3 val n = 2 @@ -74,6 +74,24 @@ class MatricesSuite extends FunSuite { } } + test("equals") { + val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)) + assert(dm1 === dm1) + assert(dm1 !== dm1.transpose) + + val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) + assert(dm1 === dm2.transpose) + + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm1) + assert(sm1 === dm1) + assert(sm1 !== sm1.transpose) + + val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm2.transpose) + assert(sm1 === dm2.transpose) + } + test("matrix copies are deep copies") { val m = 3 val n = 2 @@ -137,8 +155,8 @@ class MatricesSuite extends FunSuite { val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) - val spMat2 = deMat1.toSparse() - val deMat2 = spMat1.toDense() + val spMat2 = deMat1.toSparse + val deMat2 = spMat1.toDense assert(spMat1.toBreeze === spMat2.toBreeze) assert(deMat1.toBreeze === deMat2.toBreeze) @@ -185,8 +203,8 @@ class MatricesSuite extends FunSuite { assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") - assert(dAT.toSparse().toBreeze === sATexpected.toBreeze) - assert(sAT.toDense().toBreeze === dATexpected.toBreeze) + assert(dAT.toSparse.toBreeze === sATexpected.toBreeze) + assert(sAT.toDense.toBreeze === dATexpected.toBreeze) } test("foreachActive") { @@ -424,4 +442,45 @@ class MatricesSuite extends FunSuite { assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) } + + test("MatrixUDT") { + val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8)) + val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0)) + val dm3 = new DenseMatrix(0, 0, Array()) + val sm1 = dm1.toSparse + val sm2 = dm2.toSparse + val sm3 = dm3.toSparse + val mUDT = new MatrixUDT() + Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach { + mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray) + } + assert(mUDT.typeName == "matrix") + assert(mUDT.simpleString == "matrix") + } + + test("toString") { + val empty = Matrices.ones(0, 0) + empty.toString(0, 0) + + val mat = Matrices.rand(5, 10, new Random()) + mat.toString(-1, -5) + mat.toString(0, 0) + mat.toString(Int.MinValue, Int.MinValue) + mat.toString(Int.MaxValue, Int.MaxValue) + var lines = mat.toString(6, 50).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 50)) + + lines = mat.toString(5, 100).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 100)) + } + + test("numNonzeros and numActives") { + val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1)) + assert(dm1.numNonzeros === 3) + assert(dm1.numActives === 6) + + val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)) + assert(sm1.numNonzeros === 1) + assert(sm1.numActives === 3) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 5def899cea11..6508ddeba420 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends FunSuite { +class VectorsSuite extends SparkFunSuite with Logging { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -58,16 +57,70 @@ class VectorsSuite extends FunSuite { assert(vec.values === values) } + test("sparse vector construction with mismatched indices/values array") { + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0)) + } + } + + test("sparse vector construction with too many indices vs size") { + intercept[IllegalArgumentException] { + Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0)) + } + } + test("dense to array") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.toArray.eq(arr)) } + test("dense argmax") { + val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector] + assert(vec.argmax === -1) + + val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector] + assert(vec2.argmax === 3) + + val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector] + assert(vec3.argmax === 3) + } + test("sparse to array") { val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] assert(vec.toArray === arr) } + test("sparse argmax") { + val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec.argmax === -1) + + val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] + assert(vec2.argmax === 3) + + val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7)) + assert(vec3.argmax === 2) + + // check for case that sparse vector is created with + // only negative values {0.0, 0.0,-1.0, -0.7, 0.0} + val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7)) + assert(vec4.argmax === 0) + + val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0)) + assert(vec5.argmax === 1) + + val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0)) + assert(vec6.argmax === 2) + + val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7)) + assert(vec7.argmax === 1) + + val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) + assert(vec8.argmax === 0) + } + test("vector equals") { val dv1 = Vectors.dense(arr.clone()) val dv2 = Vectors.dense(arr.clone()) @@ -143,7 +196,7 @@ class VectorsSuite extends FunSuite { malformatted.foreach { s => intercept[SparkException] { Vectors.parse(s) - println(s"Didn't detect malformatted string $s.") + logInfo(s"Didn't detect malformatted string $s.") } } } @@ -187,6 +240,8 @@ class VectorsSuite extends FunSuite { for (v <- Seq(dv0, dv1, sv0, sv1)) { assert(v === udt.deserialize(udt.serialize(v))) } + assert(udt.typeName == "vector") + assert(udt.simpleString == "vector") } test("fromBreeze") { @@ -213,13 +268,13 @@ class VectorsSuite extends FunSuite { val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) - // SparseVector vs. SparseVector - assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) + // SparseVector vs. SparseVector + assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. SparseVector assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. DenseVector assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8) - } + } } test("foreachActive") { @@ -268,4 +323,55 @@ class VectorsSuite extends FunSuite { assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) => a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8) } + + test("Vector numActive and numNonzeros") { + val dv = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv.numActives === 4) + assert(dv.numNonzeros === 2) + + val sv = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv.numActives === 3) + assert(sv.numNonzeros === 2) + } + + test("Vector toSparse and toDense") { + val dv0 = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv0.toDense === dv0) + val dv0s = dv0.toSparse + assert(dv0s.numActives === 2) + assert(dv0s === dv0) + + val sv0 = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv0.toDense === sv0) + val sv0s = sv0.toSparse + assert(sv0s.numActives === 2) + assert(sv0s === sv0) + } + + test("Vector.compressed") { + val dv0 = Vectors.dense(1.0, 2.0, 3.0, 0.0) + val dv0c = dv0.compressed.asInstanceOf[DenseVector] + assert(dv0c === dv0) + + val dv1 = Vectors.dense(0.0, 2.0, 0.0, 0.0) + val dv1c = dv1.compressed.asInstanceOf[SparseVector] + assert(dv1 === dv1c) + assert(dv1c.numActives === 1) + + val sv0 = Vectors.sparse(4, Array(1, 2), Array(2.0, 0.0)) + val sv0c = sv0.compressed.asInstanceOf[SparseVector] + assert(sv0 === sv0c) + assert(sv0c.numActives === 1) + + val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) + val sv1c = sv1.compressed.asInstanceOf[DenseVector] + assert(sv1 === sv1c) + } + + test("SparseVector.slice") { + val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4)) + assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2))) + assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) + assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 949d1c993957..93fe04c139b9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 @@ -57,11 +56,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { val random = new ju.Random() // This should generate a 4x4 grid of 1x2 blocks. val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + // scalastyle:off val expected0 = Array( Array(0, 0, 4, 4, 8, 8, 12), Array(1, 1, 5, 5, 9, 9, 13), Array(2, 2, 6, 6, 10, 10, 14), Array(3, 3, 7, 7, 11, 11, 15)) + // scalastyle:on for (i <- 0 until 4; j <- 0 until 7) { assert(part0.getPartition((i, j)) === expected0(i)(j)) assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 04b36a9ef999..f3728cd036a3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 2ab53cc13db7..0ecb7a221a50 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -136,6 +135,17 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate matrix sizes of svd") { + val k = 2 + val A = new IndexedRowMatrix(indexedRows) + val svd = A.computeSVD(k, computeU = true) + assert(svd.U.numRows() === m) + assert(svd.U.numCols() === k) + assert(svd.s.size === k) + assert(svd.V.numRows === n) + assert(svd.V.numCols === k) + } + test("validate k in svd") { val A = new IndexedRowMatrix(indexedRows) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 3309713e91f8..283ffec1d49d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random +import breeze.numerics.abs import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -96,7 +97,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } test("similar columns") { - val colMags = Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)) + val colMags = Vectors.dense(math.sqrt(126), math.sqrt(66), math.sqrt(94)) val expected = BDM( (0.0, 54.0, 72.0), (0.0, 0.0, 78.0), @@ -232,15 +233,31 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") - assert(summary.normL2 === Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)), + assert(summary.normL2 === Vectors.dense(math.sqrt(126), math.sqrt(66), math.sqrt(94)), "magnitude mismatch.") assert(summary.normL1 === Vectors.dense(18.0, 12.0, 16.0), "L1 norm mismatch") } } } + + test("QR Decomposition") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = mat.tallSkinnyQR(true) + val expected = breeze.linalg.qr.reduced(mat.toBreeze()) + val calcQ = result.Q + val calcR = result.R + assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) + assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze())) + // Decomposition without computing Q + val rOnly = mat.tallSkinnyQR(computeQ = false) + assert(rOnly.Q == null) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) + } + } } -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { var mat: RowMatrix = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 86481c6e6620..36ac7d267243 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.mllib.optimization -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -34,7 +35,7 @@ object GradientDescentSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateGDInput(offset, scale, nPoints, seed)) + generateGDInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale * X) @@ -42,7 +43,7 @@ object GradientDescentSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -61,7 +62,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 @@ -81,11 +82,11 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0 +: features.toArray) + label -> MLUtils.appendBias(features) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) + val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -138,9 +139,48 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc "The different between newWeights with/without regularization " + "should be initialWeightsWithIntercept.") } + + test("iteration should end with convergence tolerance") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val initialB = -1.0 + val initialWeights = Array(initialB) + + val gradient = new LogisticGradient() + val updater = new SimpleUpdater() + val stepSize = 1.0 + val numIterations = 10 + val regParam = 0 + val miniBatchFrac = 1.0 + val convergenceTolerance = 5.0e-1 + + // Add a extra variable consisting of all 1.0's for the intercept. + val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) + val data = testData.map { case LabeledPoint(label, features) => + label -> MLUtils.appendBias(features) + } + + val dataRDD = sc.parallelize(data, 2).cache() + val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) + + val (_, loss) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept, + convergenceTolerance) + + assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early") + } } -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 70c64775e4c0..75ae0eb32fb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 @@ -89,7 +90,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { // it requires 90 iterations in GD. No matter how hard we increase // the number of iterations in GD here, the lossGD will be always // larger than lossLBFGS. This is based on observation, no theoretically guaranteed - assert(Math.abs((lossGD.last - loss.last) / loss.last) < 0.02, + assert(math.abs((lossGD.last - loss.last) / loss.last) < 0.02, "LBFGS should match GD result within 2% difference.") } @@ -121,7 +122,8 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { numGDIterations, regParam, miniBatchFrac, - initialWeightsWithIntercept) + initialWeightsWithIntercept, + convergenceTol) assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") @@ -220,7 +222,8 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { numGDIterations, regParam, miniBatchFrac, - initialWeightsWithIntercept) + initialWeightsWithIntercept, + convergenceTol) // for class LBFGS and the optimize method, we only look at the weights assert( @@ -229,7 +232,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { } } -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 82c327bd49fc..d8f9b8c33963 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.FunSuite - import org.jblas.{DoubleMatrix, SimpleBlas} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class NNLSSuite extends FunSuite { +class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) @@ -55,7 +54,7 @@ class NNLSSuite extends FunSuite { for (k <- 0 until 100) { val (ata, atb) = genOnesData(n, rand) - val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val x = new DoubleMatrix(NNLS.solve(ata.data, atb.data, ws)) assert(x.length === n) val answer = DoubleMatrix.ones(n, 1) SimpleBlas.axpy(-1.0, answer, x) @@ -68,18 +67,20 @@ class NNLSSuite extends FunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 + // scalastyle:off val ata = new DoubleMatrix(Array( Array( 4.377, -3.531, -1.306, -0.139, 3.418), Array(-3.531, 4.344, 0.934, 0.305, -2.140), Array(-1.306, 0.934, 2.644, -0.203, -0.170), Array(-0.139, 0.305, -0.203, 5.883, 1.428), Array( 3.418, -2.140, -0.170, 1.428, 4.684))) + // scalastyle:on val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) val ws = NNLS.createWorkspace(n) - val x = NNLS.solve(ata, atb, ws) + val x = NNLS.solve(ata.data, atb.data, ws) for (i <- 0 until n) { assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) @@ -104,7 +105,7 @@ class NNLSSuite extends FunSuite { val ws = NNLS.createWorkspace(n) - val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val x = new DoubleMatrix(NNLS.solve(ata.data, atb.data, ws)) val obj = computeObjectiveValue(ata, atb, x) assert(obj < refObj + 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala new file mode 100644 index 000000000000..4c6e76e47419 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionNormalizationMethodType + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.util.LinearDataGenerator + +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { + + test("logistic regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + + // assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === logisticRegressionModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) + } + + test("linear SVM PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + + // assert that the PMML format is as expected + assert(svmModelExport.isInstanceOf[PMMLModelExport]) + val pmml = svmModelExport.getPmml + assert(pmml.getHeader.getDescription + === "linear SVM") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === svmModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure linear SVM has normalization method set to NONE + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala new file mode 100644 index 000000000000..1d3230948178 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} +import org.apache.spark.mllib.util.LinearDataGenerator + +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite { + + test("linear regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + // assert that the PMML format is as expected + assert(linearModelExport.isInstanceOf[PMMLModelExport]) + val pmml = linearModelExport.getPmml + assert(pmml.getHeader.getDescription === "linear regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === linearRegressionModel.weights.size + 1) + // This verifies that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === linearRegressionModel.weights.size) + } + + test("ridge regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + // assert that the PMML format is as expected + assert(ridgeModelExport.isInstanceOf[PMMLModelExport]) + val pmml = ridgeModelExport.getPmml + assert(pmml.getHeader.getDescription === "ridge regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === ridgeRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === ridgeRegressionModel.weights.size) + } + + test("lasso PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + // assert that the PMML format is as expected + assert(lassoModelExport.isInstanceOf[PMMLModelExport]) + val pmml = lassoModelExport.getPmml + assert(pmml.getHeader.getDescription === "lasso regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === lassoModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === lassoModel.weights.size) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala new file mode 100644 index 000000000000..b3f9750afa73 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.ClusteringModel + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.linalg.Vectors + +class KMeansPMMLModelExportSuite extends SparkFunSuite { + + test("KMeansPMMLModelExport generate PMML format") { + val clusterCenters = Array( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) + + // assert that the PMML format is as expected + assert(modelExport.isInstanceOf[PMMLModelExport]) + val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala new file mode 100644 index 000000000000..af4945096175 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} +import org.apache.spark.mllib.util.LinearDataGenerator + +class PMMLModelExportFactorySuite extends SparkFunSuite { + + test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { + val clusterCenters = Array( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) + + assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) + } + + test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + + "LinearRegressionModel, RidgeRegressionModel or LassoModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + } + + test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + + "when passing a LogisticRegressionModel or SVMModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + val logisticRegressionModelExport = + PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + } + + test("PMMLModelExportFactory throw IllegalArgumentException " + + "when passing a Multinomial Logistic Regression") { + /** 3 classes, 2 features */ + val multiclassLogisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + numFeatures = 2, numClasses = 3) + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) + } + } + + test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { + val invalidModel = new Object + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(invalidModel) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b792d819fdab..a5ca1518f82f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random import scala.math -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class RandomDataGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite { def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 6395188a0842..413db2000d6d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, @@ -181,7 +180,8 @@ class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializa val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) - val exponential = RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) + val exponential = + RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) testGeneratedVectorRDD(exponential, rows, cols, parts, exponentialMean, exponentialMean, 0.1) val gamma = RandomRDDs.gammaVectorRDD(sc, gammaShape, gammaScale, rows, cols, parts, seed) @@ -197,7 +197,7 @@ private[random] class MockDistro extends RandomDataGenerator[Double] { // This allows us to check that each partition has a different seed override def nextValue(): Double = seed.toDouble - override def setSeed(seed: Long) = this.seed = seed + override def setSeed(seed: Long): Unit = this.seed = seed override def copy(): MockDistro = new MockDistro } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala new file mode 100644 index 000000000000..10f5a2be48f7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ + +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { + test("topByKey") { + val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, + 1), (3, 5)), 2) + .topByKey(5) + .collectAsMap() + + assert(topMap.size === 3) + assert(topMap(1) === Array(7, 6, 3, 2, 1)) + assert(topMap(3) === Array(7, 5, 2)) + assert(topMap(5) === Array(1)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 6d6c0aa5be81..bc6417261483 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 8775c0ca9df8..045135f7f8d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.recommendation -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.math.abs import scala.util.Random -import org.scalatest.FunSuite import org.jblas.DoubleMatrix +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel @@ -38,7 +38,7 @@ object ALSSuite { negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { val (sampledRatings, trueRatings, truePrefs) = generateRatings(users, products, features, samplingRate, implicitPrefs) - (seqAsJavaList(sampledRatings), trueRatings, truePrefs) + (sampledRatings.asJava, trueRatings, truePrefs) } def generateRatings( @@ -84,7 +84,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with MLlibTestSparkContext { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) @@ -203,6 +203,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ + // scalastyle:off def testALS( users: Int, products: Int, @@ -216,6 +217,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { numUserBlocks: Int = -1, numProductBlocks: Int = -1, negativeFactors: Boolean = true) { + // scalastyle:on + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc904a2..2c8ed057a516 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.recommendation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext { val rank = 2 var userFeatures: RDD[(Int, Array[Double])] = _ @@ -53,4 +53,42 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) + assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("batch predict API recommendProductsForUsers") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val topK = 10 + val recommendations = model.recommendProductsForUsers(topK).collectAsMap() + + assert(recommendations(0)(0).rating ~== 17.0 relTol 1e-14) + assert(recommendations(1)(0).rating ~== 39.0 relTol 1e-14) + } + + test("batch predict API recommendUsersForProducts") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val topK = 10 + val recommendations = model.recommendUsersForProducts(topK).collectAsMap() + + assert(recommendations(2)(0).user == 1) + assert(recommendations(2)(0).rating ~== 39.0 relTol 1e-14) + assert(recommendations(2)(1).user == 0) + assert(recommendations(2)(1).rating ~== 17.0 relTol 1e-14) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 7ef45248281e..ea4f2865757c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.mllib.regression -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { - Math.round(d * 100).toDouble / 100 + math.round(d * 100).toDouble / 100 } private def generateIsotonicInput(labels: Seq[Double]): Seq[(Double, Double, Double)] = { @@ -73,6 +75,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(model.isotonic) } + test("model save/load") { + val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) + val predictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0) + val model = new IsotonicRegressionModel(boundaries, predictions, true) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = IsotonicRegressionModel.load(sc, path) + assert(model.boundaries === sameModel.boundaries) + assert(model.predictions === sameModel.predictions) + assert(model.isotonic === model.isotonic) + } finally { + Utils.deleteRecursively(tempDir) + } + } + test("isotonic regression with size 0") { val model = runIsotonicRegression(Seq(), true) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 110c44a7193f..f8d0af8820e6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -class LabeledPointSuite extends FunSuite { +class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( @@ -32,6 +31,11 @@ class LabeledPointSuite extends FunSuite { } } + test("parse labeled points with whitespaces") { + val point = LabeledPoint.parse("(0.0, [1.0, 2.0])") + assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0))) + } + test("parse labeled points with v0.9 format") { val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2668dcc14a84..39537e7bb4c7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,13 +19,19 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LassoSuite { -class LassoSuite extends FunSuite with MLlibTestSparkContext { + /** 3 features */ + val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} + +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -60,11 +66,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -93,7 +100,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005) val model = ls.run(testRDD, initialWeights) val weight0 = model.weights(0) @@ -103,11 +110,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData,2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -115,9 +123,26 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("model save/load") { + val model = LassoSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LassoModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 864622a9296a..f88a1c33c9f7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,13 +19,19 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LinearRegressionSuite { -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { + /** 3 features */ + val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} + +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -124,9 +130,26 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { validatePrediction( sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } + + test("model save/load") { + val model = LinearRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LinearRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 18d3bf5ea4ec..7a781fee634c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -20,15 +20,22 @@ package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { +private object RidgeRegressionSuite { - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + /** 3 features */ + val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} + +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) }.reduceLeft(_ + _) / predictions.size @@ -75,9 +82,26 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("model save/load") { + val model = RidgeRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RidgeRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf..34c07ed17081 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -19,17 +19,25 @@ package org.apache.spark.mllib.regression import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 20000 + override def maxWaitTimeMillis: Int = 20000 + + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -54,6 +62,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.2) .setNumIterations(25) + .setConvergenceTol(0.0001) // generate sequence of simulated data val numBatches = 10 @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) inputDStream.count() @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) // collect the output as (true, estimated) tuples @@ -139,4 +148,50 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + val numBatches = 10 + val nPoints = 100 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index d20a09b4b492..c3eeda012571 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { // test input data val xData = Array(1.0, 0.0, -2.0) @@ -96,11 +95,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { val X = sc.parallelize(data) val defaultMat = Statistics.corr(X) val pearsonMat = Statistics.corr(X, "pearson") + // scalastyle:off val expected = BDM( (1.00000000, 0.05564149, Double.NaN, 0.4004714), (0.05564149, 1.00000000, Double.NaN, 0.9135959), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), - (0.40047142, 0.91359586, Double.NaN,1.0000000)) + (0.40047142, 0.91359586, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(defaultMat.toBreeze, expected)) assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) } @@ -108,11 +109,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { test("corr(X) spearman") { val X = sc.parallelize(data) val spearmanMat = Statistics.corr(X, "spearman") + // scalastyle:off val expected = BDM( (1.0000000, 0.1054093, Double.NaN, 0.4000000), (0.1054093, 1.0000000, Double.NaN, 0.9486833), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) } @@ -143,7 +146,7 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = { for (i <- 0 until A.rows; j <- 0 until A.cols) { if (!approxEqual(A(i, j), B(i, j), threshold)) { - println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) + logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) return false } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 15418e603596..142b90e764a7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -19,16 +19,18 @@ package org.apache.spark.mllib.stat import java.util.Random -import org.scalatest.FunSuite +import org.apache.commons.math3.distribution.{ExponentialDistribution, + NormalDistribution, UniformRealDistribution} +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { @@ -155,4 +157,101 @@ class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) } } + + test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") { + // Create theoretical distributions + val stdNormalDist = new NormalDistribution(0, 1) + val expDist = new ExponentialDistribution(0.6) + val unifDist = new UniformRealDistribution() + + // set seeds + val seed = 10L + stdNormalDist.reseedRandomGenerator(seed) + expDist.reseedRandomGenerator(seed) + unifDist.reseedRandomGenerator(seed) + + // Sample data from the distributions and parallelize it + val n = 100000 + val sampledNorm = sc.parallelize(stdNormalDist.sample(n), 10) + val sampledExp = sc.parallelize(expDist.sample(n), 10) + val sampledUnif = sc.parallelize(unifDist.sample(n), 10) + + // Use a apache math commons local KS test to verify calculations + val ksTest = new KolmogorovSmirnovTest() + val pThreshold = 0.05 + + // Comparing a standard normal sample to a standard normal distribution + val result1 = Statistics.kolmogorovSmirnovTest(sampledNorm, "norm", 0, 1) + val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledNorm.collect()) + val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n) + // Verify vs apache math commons ks test + assert(result1.statistic ~== referenceStat1 relTol 1e-4) + assert(result1.pValue ~== referencePVal1 relTol 1e-4) + // Cannot reject null hypothesis + assert(result1.pValue > pThreshold) + + // Comparing an exponential sample to a standard normal distribution + val result2 = Statistics.kolmogorovSmirnovTest(sampledExp, "norm", 0, 1) + val referenceStat2 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledExp.collect()) + val referencePVal2 = 1 - ksTest.cdf(referenceStat2, n) + // verify vs apache math commons ks test + assert(result2.statistic ~== referenceStat2 relTol 1e-4) + assert(result2.pValue ~== referencePVal2 relTol 1e-4) + // reject null hypothesis + assert(result2.pValue < pThreshold) + + // Testing the use of a user provided CDF function + // Distribution is not serializable, so will have to create in the lambda + val expCDF = (x: Double) => new ExponentialDistribution(0.2).cumulativeProbability(x) + + // Comparing an exponential sample with mean X to an exponential distribution with mean Y + // Where X != Y + val result3 = Statistics.kolmogorovSmirnovTest(sampledExp, expCDF) + val referenceStat3 = ksTest.kolmogorovSmirnovStatistic(new ExponentialDistribution(0.2), + sampledExp.collect()) + val referencePVal3 = 1 - ksTest.cdf(referenceStat3, sampledNorm.count().toInt) + // verify vs apache math commons ks test + assert(result3.statistic ~== referenceStat3 relTol 1e-4) + assert(result3.pValue ~== referencePVal3 relTol 1e-4) + // reject null hypothesis + assert(result3.pValue < pThreshold) + } + + test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") { + /* + Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample + > sessionInfo() + R version 3.2.0 (2015-04-16) + Platform: x86_64-apple-darwin13.4.0 (64-bit) + > set.seed(20) + > v <- rnorm(20) + > v + [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612 + [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222 + [13] -0.62812676 1.32322085 -1.52135057 -0.43742787 0.97057758 0.02822264 + [19] -0.08578219 0.38921440 + > ks.test(v, pnorm, alternative = "two.sided") + + One-sample Kolmogorov-Smirnov test + + data: v + D = 0.18874, p-value = 0.4223 + alternative hypothesis: two-sided + */ + + val rKSStat = 0.18874 + val rKSPVal = 0.4223 + val rData = sc.parallelize( + Array( + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ) + ) + val rCompResult = Statistics.kolmogorovSmirnovTest(rData, "norm", 0, 1) + assert(rCompResult.statistic ~== rKSStat relTol 1e-4) + assert(rCompResult.pValue ~== rKSPVal relTol 1e-4) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala new file mode 100644 index 000000000000..5feccdf33681 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.commons.math3.distribution.NormalDistribution + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { + test("kernel density single sample") { + val rdd = sc.parallelize(Array(5.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) + val normal = new NormalDistribution(5.0, 3.0) + val acceptableErr = 1e-6 + assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) + assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) + } + + test("kernel density multiple samples") { + val rdd = sc.parallelize(Array(5.0, 10.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) + val normal1 = new NormalDistribution(5.0, 3.0) + val normal2 = new NormalDistribution(10.0, 3.0) + val acceptableErr = 1e-6 + assert(math.abs( + densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) + assert(math.abs( + densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 23b0eec865de..07efde4f5e6d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateOnlineSummarizerSuite extends FunSuite { +class MultivariateOnlineSummarizerSuite extends SparkFunSuite { test("basic error handing") { val summarizer = new MultivariateOnlineSummarizer diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index fac2498e4dcb..aa60deb665ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,49 +17,48 @@ package org.apache.spark.mllib.stat.distribution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) - + val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } - + test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } - + test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9347eaf9221a..356d957f1590 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -29,10 +28,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + + +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { + ///////////////////////////////////////////////////////////////////////////// + // Tests examining individual elements of training + ///////////////////////////////////////////////////////////////////////////// test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() @@ -188,7 +193,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 3) - assert(bins(0).length === 6) + assert(bins(0).length === 0) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -226,41 +231,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(1.0)) - // Check bins. - - assert(bins(0)(0).category === Double.MinValue) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(0).category === Double.MinValue) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(0)(1).category === Double.MinValue) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(0.0)) - assert(bins(0)(1).highSplit.categories.length === 1) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(1).category === Double.MinValue) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 1) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(0)(2).category === Double.MinValue) - assert(bins(0)(2).lowSplit.categories.length === 1) - assert(bins(0)(2).lowSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.length === 2) - assert(bins(0)(2).highSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.contains(0.0)) - assert(bins(1)(2).category === Double.MinValue) - assert(bins(1)(2).lowSplit.categories.length === 1) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length === 2) - assert(bins(1)(2).highSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.contains(0.0)) - } test("Multiclass classification with ordered categorical features: split and bin calculations") { @@ -287,6 +257,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(bins(0).length === 0) } + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Second level node building with vs. without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClasses = 2, maxBins = 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) + + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + val children1 = new Array[Node](2) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. + val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict.predict === children2(i).predict.predict) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -471,76 +600,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(rootNode.predict.predict === 1) } - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. - val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -561,11 +620,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) - arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) @@ -577,11 +636,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 2 continuous features") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -701,11 +760,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min instances per node requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInstancesPerNode = 2) @@ -728,11 +786,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("do not choose split that does not satisfy min instance per node requirements") { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -748,10 +806,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min info gain requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, @@ -772,94 +830,35 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("Avoid aggregation on the last level") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) + test("Node.subtreeIterator") { + val model = DecisionTreeSuite.createModel(Classification) + val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted + assert(nodeIds === DecisionTreeSuite.createdModelNodeIds) } - test("Avoid aggregation if impurity is 0.0") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val model = DecisionTreeSuite.createModel(algo) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = DecisionTreeModel.load(sc, path) + DecisionTreeSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } } } -object DecisionTreeSuite { +object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, @@ -979,4 +978,96 @@ object DecisionTreeSuite { arr } + /** Create a leaf node with the given node ID */ + private def createLeafNode(id: Int): Node = { + Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true) + } + + /** + * Create an internal node with the given node ID and feature type. + * Note: This does NOT set the child nodes. + */ + private def createInternalNode(id: Int, featureType: FeatureType): Node = { + val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false) + featureType match { + case Continuous => + node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous, + categories = List.empty[Double])) + case Categorical => + node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, + categories = List(0.0, 1.0))) + } + // TODO: The information gain stats should be consistent with info in children: SPARK-7131 + node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, + leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) + node + } + + /** + * Create a tree model. This is deterministic and contains a variety of node and feature types. + * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131 + */ + private[spark] def createModel(algo: Algo): DecisionTreeModel = { + val topNode = createInternalNode(id = 1, Continuous) + val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) + val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) + topNode.leftNode = Some(node2) + topNode.rightNode = Some(node3) + node3.leftNode = Some(node6) + node3.rightNode = Some(node7) + new DecisionTreeModel(topNode, algo) + } + + /** Sorted Node IDs matching the model returned by [[createModel()]] */ + private val createdModelNodeIds = Array(1, 2, 3, 6, 7) + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + assert(a.algo === b.algo) + checkEqual(a.topNode, b.topNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendents are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.id === b.id) + assert(a.predict === b.predict) + assert(a.impurity === b.impurity) + assert(a.isLeaf === b.isLeaf) + assert(a.split === b.split) + (a.stats, b.stats) match { + // TODO: Check other fields besides the infomation gain. + case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) + case (None, None) => + case _ => throw new AssertionError( + s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + } + (a.leftNode, b.leftNode) match { + case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has leftNode defined. " + + s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + } + (a.rightNode, b.rightNode) match { + case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has rightNode defined. " + + s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 8972c229b7ec..334bf3790fc7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -70,7 +70,7 @@ object EnsembleTestHelper { metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => - prediction - label + label - prediction } val metric = metricName match { case "mse" => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index e8341a5d0d10..6fc9e8df621d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,50 +17,49 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} - +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) - - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) - val boostingStrategy = - new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) - - val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) - - assert(gbt.trees.size === numIterations) - try { - EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) - } catch { - case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + - s" subsamplingRate=$subsamplingRate") - throw e - } - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val dt = DecisionTree.train(remappedInput, treeStrategy) - - // Make sure trees are the same. - assert(gbt.trees.head.toString == dt.toString) + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) } } @@ -81,7 +80,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } @@ -112,7 +111,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } @@ -133,14 +132,100 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { BoostingStrategy.defaultParams(algo) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + + Array(Classification, Regression).foreach { algo => + val model = new GradientBoostedTreesModel(algo, trees, treeWeights) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = GradientBoostedTreesModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + assert(model.treeWeights === sameModel.treeWeights) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("runWithValidation stops early and performs better on a validation dataset") { + // Set numIterations large enough so that it stops early. + val numIterations = 20 + val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) + val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) + + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } -object GradientBoostedTreesSuite { +private object GradientBoostedTreesSuite { // Combinations for estimators, learning rates and subsamplingRate val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) - val randomSeeds = Array(681283, 4398) - val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 92b498580af0..49aff21fe791 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends FunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 55e963977b54..e6df5d974bf3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -19,21 +19,22 @@ package org.apache.spark.mllib.tree import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.Node +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) @@ -194,7 +195,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - EnsembleTestHelper.validateClassifier(model, arr, 1.0) } test("subsampling rate in RandomForest"){ @@ -212,6 +212,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { assert(rf1.toDebugString != rf2.toDebugString) } -} - + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray + val model = new RandomForestModel(algo, trees) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RandomForestModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index b184e936672c..9d756da41032 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impl -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. */ -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 5e9101cdd380..525ab68c7921 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -26,7 +26,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => override def beforeAll() { val conf = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("test-cluster") .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data sc = new SparkContext(conf) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 668fc1d43c5d..70219e9ad9d3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -21,19 +21,19 @@ import java.io.File import scala.io.Source -import org.scalatest.FunSuite - import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { +class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") - if (m > 10) { + if (m > 10) { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) @@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), @@ -168,7 +202,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { "Each training+validation set combined should contain all of the data.") } // K fold cross validation should only have each element in the validation set exactly once - assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted === data.collect().sorted) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index b658889476d3..5d1796ef6572 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.util -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SQLContext trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ + @transient var sqlContext: SQLContext = _ override def beforeAll() { super.beforeAll() @@ -31,12 +32,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => .setMaster("local[2]") .setAppName("MLlibUnitTest") sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) } override def afterAll() { + sqlContext = null if (sc != null) { sc.stop() } + sc = null super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index f68fb95eac4e..16d7c3ab39b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.util -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.SparkException - -class NumericParserSuite extends FunSuite { +class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" @@ -35,8 +33,15 @@ class NumericParserSuite extends FunSuite { malformatted.foreach { s => intercept[SparkException] { NumericParser.parse(s) - println(s"Didn't detect malformatted string $s.") + throw new RuntimeException(s"Didn't detect malformatted string $s.") } } } + + test("parser with whitespaces") { + val s = "(0.0, [1.0, 2.0])" + val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] + assert(parsed(0).asInstanceOf[Double] === 0.0) + assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index e957fa5d25f4..352193a67860 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -95,16 +95,16 @@ object TestingUtils { /** * Comparison using absolute tolerance. */ - def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, - x, eps, ABS_TOL_MSG) + def absTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG) /** * Comparison using relative tolerance. */ - def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, - x, eps, REL_TOL_MSG) + def relTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareVectorRightSide( @@ -166,7 +166,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareMatrixRightSide( @@ -229,7 +229,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index b0ecb33c2848..8f475f30249d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.util +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.scalatest.exceptions.TestFailedException -class TestingUtilsSuite extends FunSuite { +class TestingUtilsSuite extends SparkFunSuite { test("Comparing doubles using relative error.") { @@ -88,16 +88,20 @@ class TestingUtilsSuite extends FunSuite { assert(!(17.8 ~= 17.59 absTol 0.2)) // Comparisons of numbers very close to zero, and both side of zeros - assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - - assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert( + -Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) } test("Comparing vectors using relative error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) @@ -130,7 +134,7 @@ class TestingUtilsSuite extends FunSuite { test("Comparing vectors using absolute error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) diff --git a/network/common/pom.xml b/network/common/pom.xml index 8f7c924d6b3a..7dc3068ab8cb 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -77,7 +77,12 @@ org.mockito - mockito-all + mockito-core + test + + + org.slf4j + slf4j-log4j12 test @@ -90,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 5bc6e5a2418a..b8d073fa16b4 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,7 +36,7 @@ import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -82,13 +83,21 @@ public TransportClientFactory createClientFactory() { } /** Create a server which will attempt to bind to a specific port. */ - public TransportServer createServer(int port) { - return new TransportServer(this, port); + public TransportServer createServer(int port, List bootstraps) { + return new TransportServer(this, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ + public TransportServer createServer(List bootstraps) { + return createServer(0, bootstraps); + } + public TransportServer createServer() { - return new TransportServer(this, 0); + return createServer(0, Lists.newArrayList()); + } + + public TransportChannelHandler initializePipeline(SocketChannel channel) { + return initializePipeline(channel, rpcHandler); } /** @@ -96,17 +105,23 @@ public TransportServer createServer() { * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or * response messages. * + * @param channel The channel to initialize. + * @param channelRpcHandler The RPC handler to use for the channel. + * * @return Returns the created TransportChannelHandler, which includes a TransportClient that can * be used to communicate on this channel. The TransportClient is directly associated with a * ChannelHandler to ensure all users of the same channel get the same TransportClient object. */ - public TransportChannelHandler initializePipeline(SocketChannel channel) { + public TransportChannelHandler initializePipeline( + SocketChannel channel, + RpcHandler channelRpcHandler) { try { - TransportChannelHandler channelHandler = createChannelHandler(channel); + TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) + .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); @@ -122,12 +137,13 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { * ResponseMessages. The channel is expected to have been successfully created, though certain * properties (such as the remoteAddress()) may not be available yet. */ - private TransportChannelHandler createChannelHandler(Channel channel) { + private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); - return new TransportChannelHandler(client, responseHandler, requestHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler, + conf.connectionTimeoutMs()); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 37f2e34ceb24..e8e7f06247d3 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,6 +19,7 @@ import java.io.Closeable; import java.io.IOException; +import java.net.SocketAddress; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -79,6 +80,10 @@ public boolean isActive() { return channel.isOpen() || channel.isActive(); } + public SocketAddress getSocketAddress() { + return channel.remoteAddress(); + } + /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. * diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java index 65e8020e3412..eaae2ee043c5 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.client; +import io.netty.channel.Channel; + /** * A bootstrap which is executed on a TransportClient before it is returned to the user. * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- @@ -28,5 +30,5 @@ */ public interface TransportClientBootstrap { /** Performs the bootstrapping operation, throwing an exception on failure. */ - public void doBootstrap(TransportClient client) throws RuntimeException; + void doBootstrap(TransportClient client, Channel channel) throws RuntimeException; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index d26b9b4d6055..4952ffb44bb8 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -172,12 +172,14 @@ private TransportClient createClient(InetSocketAddress address) throws IOExcepti .option(ChannelOption.ALLOCATOR, pooledAllocator); final AtomicReference clientRef = new AtomicReference(); + final AtomicReference channelRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); clientRef.set(clientHandler.getClient()); + channelRef.set(ch); } }); @@ -192,6 +194,7 @@ public void initChannel(SocketChannel ch) { } TransportClient client = clientRef.get(); + Channel channel = channelRef.get(); assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. @@ -199,7 +202,7 @@ public void initChannel(SocketChannel ch) { logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { - clientBootstrap.doBootstrap(client); + clientBootstrap.doBootstrap(client, channel); } } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2044afb0d85d..94fc21af5e60 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; -import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ + private final AtomicLong timeOfLastRequestNs; + public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingFetches.put(streamChunkId, callback); } @@ -65,6 +70,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingRpcs.put(requestId, callback); } @@ -161,8 +167,12 @@ public void handle(ResponseMessage message) { } /** Returns total number of outstanding requests (fetch requests + rpcs) */ - @VisibleForTesting public int numOutstandingRequests() { return outstandingFetches.size() + outstandingRpcs.size(); } + + /** Returns the time in nanoseconds of when the last request was sent out. */ + public long getTimeOfLastRequestNs() { + return timeOfLastRequestNs.get(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 986957c1509f..f0363830b61a 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -17,7 +17,6 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -53,6 +52,11 @@ public static ChunkFetchFailure decode(ByteBuf buf) { return new ChunkFetchFailure(streamChunkId, errorString); } + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, errorString); + } + @Override public boolean equals(Object other) { if (other instanceof ChunkFetchFailure) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 980947cf13f6..5a173af54f61 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -48,6 +48,11 @@ public static ChunkFetchRequest decode(ByteBuf buf) { return new ChunkFetchRequest(StreamChunkId.decode(buf)); } + @Override + public int hashCode() { + return streamChunkId.hashCode(); + } + @Override public boolean equals(Object other) { if (other instanceof ChunkFetchRequest) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index ff4936470c69..c962fb7ecf76 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -61,6 +61,11 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { return new ChunkFetchSuccess(streamChunkId, managedBuf); } + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, buffer); + } + @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 873c69425094..9162d0b977f8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -20,7 +20,6 @@ import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; /** Provides a canonical set of Encoders for simple types. */ public class Encoders { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a7..0f999f5dfe8d 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 000000000000..d686a951467c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index ebd764eb5eb5..2dfc7876ba32 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -17,7 +17,6 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -51,6 +50,11 @@ public static RpcFailure decode(ByteBuf buf) { return new RpcFailure(requestId, errorString); } + @Override + public int hashCode() { + return Objects.hashCode(requestId, errorString); + } + @Override public boolean equals(Object other) { if (other instanceof RpcFailure) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index cdee0b0e0316..745039db742f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -59,6 +59,11 @@ public static RpcRequest decode(ByteBuf buf) { return new RpcRequest(requestId, message); } + @Override + public int hashCode() { + return Objects.hashCode(requestId, Arrays.hashCode(message)); + } + @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 0a62e09a8115..1671cd444f03 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -50,6 +50,11 @@ public static RpcResponse decode(ByteBuf buf) { return new RpcResponse(requestId, response); } + @Override + public int hashCode() { + return Objects.hashCode(requestId, Arrays.hashCode(response)); + } + @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 000000000000..185ba2ef3bb1 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final boolean encrypt; + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this(conf, appId, secretKeyHolder, false); + } + + public SaslClientBootstrap( + TransportConf conf, + String appId, + SecretKeyHolder secretKeyHolder, + boolean encrypt) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + this.encrypt = encrypt; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client, Channel channel) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); + payload = saslClient.response(response); + } + + if (encrypt) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { + throw new RuntimeException( + new SaslException("Encryption requests by negotiated non-encrypted connection.")); + } + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + saslClient = null; + logger.debug("Channel {} configured for SASL encryption.", client); + } + } finally { + if (saslClient != null) { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java new file mode 100644 index 000000000000..127335e4d35f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.FileRegion; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.NettyUtils; + +/** + * Provides SASL-based encription for transport channels. The single method exposed by this + * class installs the needed channel handlers on a connected channel. + */ +class SaslEncryption { + + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "saslEncryption"; + + /** + * Adds channel handlers that perform encryption / decryption of data using SASL. + * + * @param channel The channel. + * @param backend The SASL backend. + * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control + * memory usage. + */ + static void addToChannel( + Channel channel, + SaslEncryptionBackend backend, + int maxOutboundBlockSize) { + channel.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) + .addFirst("saslDecryption", new DecryptionHandler(backend)) + .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder()); + } + + private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + + private final int maxOutboundBlockSize; + private final SaslEncryptionBackend backend; + + EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) { + this.backend = backend; + this.maxOutboundBlockSize = maxOutboundBlockSize; + } + + /** + * Wrap the incoming message in an implementation that will perform encryption lazily. This is + * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in + * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it + * does not guarantee any ordering. + */ + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + + ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + backend.dispose(); + } finally { + super.handlerRemoved(ctx); + } + } + + } + + private static class DecryptionHandler extends MessageToMessageDecoder { + + private final SaslEncryptionBackend backend; + + DecryptionHandler(SaslEncryptionBackend backend) { + this.backend = backend; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) + throws Exception { + + byte[] data; + int offset; + int length = msg.readableBytes(); + if (msg.hasArray()) { + data = msg.array(); + offset = msg.arrayOffset(); + msg.skipBytes(length); + } else { + data = new byte[length]; + msg.readBytes(data); + offset = 0; + } + + out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length))); + } + + } + + @VisibleForTesting + static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + + private final SaslEncryptionBackend backend; + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + + /** + * A channel used to buffer input data for encryption. The channel has an upper size bound + * so that if the input is larger than the allowed buffer, it will be broken into multiple + * chunks. + */ + private final ByteArrayWritableChannel byteChannel; + + private ByteBuf currentHeader; + private ByteBuffer currentChunk; + private long currentChunkSize; + private long currentReportedBytes; + private long unencryptedChunkSize; + private long transferred; + + EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.backend = backend; + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + } + + /** + * Returns the size of the original (unencrypted) message. + * + * This makes assumptions about how netty treats FileRegion instances, because there's no way + * to know beforehand what will be the size of the encrypted message. Namely, it assumes + * that netty will try to transfer data from this message while + * transfered() < count(). So these two methods return, technically, wrong data, + * but netty doesn't know better. + */ + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + /** + * Returns an approximation of the amount of data transferred. See {@link #count()}. + */ + @Override + public long transfered() { + return transferred; + } + + /** + * Transfers data from the original message to the channel, encrypting it in the process. + * + * This method also breaks down the original message into smaller chunks when needed. This + * is done to keep memory usage under control. This avoids having to copy the whole message + * data into memory at once, and can avoid ballooning memory usage when transferring large + * messages such as shuffle blocks. + * + * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward + * until a whole chunk has been written. This is done because the code can't use the actual + * number of bytes written to the channel as the transferred count (see {@link #count()}). + * Instead, once an encrypted chunk is written to the output (including its header), the + * size of the original block will be added to the {@link #transfered()} amount. + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) + throws IOException { + + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + long reportedWritten = 0L; + long actuallyWritten = 0L; + do { + if (currentChunk == null) { + nextChunk(); + } + + if (currentHeader.readableBytes() > 0) { + int bytesWritten = target.write(currentHeader.nioBuffer()); + currentHeader.skipBytes(bytesWritten); + actuallyWritten += bytesWritten; + if (currentHeader.readableBytes() > 0) { + // Break out of loop if there are still header bytes left to write. + break; + } + } + + actuallyWritten += target.write(currentChunk); + if (!currentChunk.hasRemaining()) { + // Only update the count of written bytes once a full chunk has been written. + // See method javadoc. + long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes; + reportedWritten += chunkBytesRemaining; + transferred += chunkBytesRemaining; + currentHeader.release(); + currentHeader = null; + currentChunk = null; + currentChunkSize = 0; + currentReportedBytes = 0; + } + } while (currentChunk == null && transfered() + reportedWritten < count()); + + // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, + // we return 1 until we can (i.e. until the reported count would actually match the size + // of the current chunk), at which point we resort to returning 0 so that the counts still + // match, at the cost of some performance. That situation should be rare, though. + if (reportedWritten != 0L) { + return reportedWritten; + } + + if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) { + transferred += 1L; + currentReportedBytes += 1L; + return 1L; + } + + return 0L; + } + + private void nextChunk() throws IOException { + byteChannel.reset(); + if (isByteBuf) { + int copied = byteChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteChannel, region.transfered()); + } + + byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length()); + this.currentChunk = ByteBuffer.wrap(encrypted); + this.currentChunkSize = encrypted.length; + this.currentHeader = Unpooled.copyLong(8 + currentChunkSize); + this.unencryptedChunkSize = byteChannel.length(); + } + + @Override + protected void deallocate() { + if (currentHeader != null) { + currentHeader.release(); + } + if (buf != null) { + buf.release(); + } + if (region != null) { + region.release(); + } + } + + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java new file mode 100644 index 000000000000..89b78bc7e1df --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.sasl.SaslException; + +interface SaslEncryptionBackend { + + /** Disposes of resources used by the backend. */ + void dispose(); + + /** Encrypt data. */ + byte[] wrap(byte[] data, int offset, int len) throws SaslException; + + /** Decrypt data. */ + byte[] unwrap(byte[] data, int offset, int len) throws SaslException; + +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 000000000000..be6165caf3c7 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.sasl.Sasl; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.TransportConf; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +class SaslRpcHandler extends RpcHandler { + private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** Transport configuration. */ + private final TransportConf conf; + + /** The client channel. */ + private final Channel channel; + + /** RpcHandler we will delegate to for authenticated connections. */ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + private SparkSaslServer saslServer; + private boolean isComplete; + + SaslRpcHandler( + TransportConf conf, + Channel channel, + RpcHandler delegate, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.channel = channel; + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.saslServer = null; + this.isComplete = false; + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + if (isComplete) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); + } + + byte[] response = saslServer.response(saslMessage.payload); + callback.onSuccess(response); + + // Setup encryption after the SASL response is sent, otherwise the client can't parse the + // response. It's ok to change the channel pipeline here since we are processing an incoming + // message, so the pipeline is busy and no new incoming messages will be fed to it before this + // method returns. This assumes that the code ensures, through other means, that no outbound + // messages are being written to the channel while negotiation is still going on. + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + isComplete = true; + if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("Enabling encryption for channel {}", client); + SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); + saslServer = null; + } else { + saslServer.dispose(); + saslServer = null; + } + } + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + if (saslServer != null) { + saslServer.dispose(); + } + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java new file mode 100644 index 000000000000..f2f983856f44 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.channel.Channel; + +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server. This allows customizing the client channel to allow for things such as SASL + * authentication. + */ +public class SaslServerBootstrap implements TransportServerBootstrap { + + private final TransportConf conf; + private final SecretKeyHolder secretKeyHolder; + + public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL + * negotiation. + */ + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder); + } + +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java similarity index 84% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 9abad1f30a25..94685e91b862 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.util.Map; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; @@ -27,9 +29,9 @@ import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import java.io.IOException; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,19 +42,25 @@ * initial state to the "authenticated" state. This client initializes the protocol via a * firstToken, which is then followed by a set of challenges and responses. */ -public class SparkSaslClient { +public class SparkSaslClient implements SaslEncryptionBackend { private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; + private final String expectedQop; private SaslClient saslClient; - public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH; + + Map saslProps = ImmutableMap.builder() + .put(Sasl.QOP, expectedQop) + .build(); try { this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, - SASL_PROPS, new ClientCallbackHandler()); + saslProps, new ClientCallbackHandler()); } catch (SaslException e) { throw Throwables.propagate(e); } @@ -76,6 +84,11 @@ public synchronized boolean isComplete() { return saslClient != null && saslClient.isComplete(); } + /** Returns the value of a negotiated property. */ + public Object getNegotiatedProperty(String name) { + return saslClient.getNegotiatedProperty(name); + } + /** * Respond to server's SASL token. * @param token contains server's SASL token @@ -93,6 +106,7 @@ public synchronized byte[] response(byte[] token) { * Disposes of any system resources or security-sensitive information the * SaslClient might be using. */ + @Override public synchronized void dispose() { if (saslClient != null) { try { @@ -134,4 +148,15 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback } } } + + @Override + public byte[] wrap(byte[] data, int offset, int len) throws SaslException { + return saslClient.wrap(data, offset, len); + } + + @Override + public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { + return saslClient.unwrap(data, offset, len); + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java similarity index 81% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index e87b17ead1e1..431cb67a2ae0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -44,7 +44,7 @@ * initial state to the "authenticated" state. (It is not a server in the sense of accepting * connections on some socket.) */ -public class SparkSaslServer { +public class SparkSaslServer implements SaslEncryptionBackend { private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); /** @@ -60,26 +60,37 @@ public class SparkSaslServer { static final String DIGEST = "DIGEST-MD5"; /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. + * Quality of protection value that includes encryption. */ - static final Map SASL_PROPS = ImmutableMap.builder() - .put(Sasl.QOP, "auth") - .put(Sasl.SERVER_AUTH, "true") - .build(); + static final String QOP_AUTH_CONF = "auth-conf"; + + /** + * Quality of protection value that does not include encryption. + */ + static final String QOP_AUTH = "auth"; /** Identifier for a certain secret key within the secretKeyHolder. */ private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; private SaslServer saslServer; - public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + public SparkSaslServer( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + + // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption + // is listed first since it's preferred over the non-encrypted one (if the client also + // lists both in the request). + String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH); + Map saslProps = ImmutableMap.builder() + .put(Sasl.SERVER_AUTH, "true") + .put(Sasl.QOP, qop) + .build(); try { - this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, new DigestCallbackHandler()); } catch (SaslException e) { throw Throwables.propagate(e); @@ -93,6 +104,11 @@ public synchronized boolean isComplete() { return saslServer != null && saslServer.isComplete(); } + /** Returns the value of a negotiated property. */ + public Object getNegotiatedProperty(String name) { + return saslServer.getNegotiatedProperty(name); + } + /** * Used to respond to server SASL tokens. * @param token Server's SASL token @@ -110,6 +126,7 @@ public synchronized byte[] response(byte[] token) { * Disposes of any system resources or security-sensitive information the * SaslServer might be using. */ + @Override public synchronized void dispose() { if (saslServer != null) { try { @@ -122,6 +139,16 @@ public synchronized void dispose() { } } + @Override + public byte[] wrap(byte[] data, int offset, int len) throws SaslException { + return saslServer.wrap(data, offset, len); + } + + @Override + public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { + return saslServer.unwrap(data, offset, len); + } + /** * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index a6d390e13f39..c95e64e8e2cd 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,14 +20,18 @@ import java.util.Iterator; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import com.google.common.base.Preconditions; + /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually * fetched as chunks by the client. Each registered buffer is one chunk. @@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager { private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; - private final Map streams; + private final ConcurrentHashMap streams; /** State of a single stream. */ private static class StreamState { final Iterator buffers; + // The channel associated to the stream + Channel associatedChannel = null; + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. int curChunk = 0; StreamState(Iterator buffers) { - this.buffers = buffers; + this.buffers = Preconditions.checkNotNull(buffers); } } @@ -58,6 +65,13 @@ public OneForOneStreamManager() { streams = new ConcurrentHashMap(); } + @Override + public void registerChannel(Channel channel, long streamId) { + if (streams.containsKey(streamId)) { + streams.get(streamId).associatedChannel = channel; + } + } + @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); @@ -80,12 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { } @Override - public void connectionTerminated(long streamId) { - // Release all remaining buffers. - StreamState state = streams.remove(streamId); - if (state != null && state.buffers != null) { - while (state.buffers.hasNext()) { - state.buffers.next().release(); + public void connectionTerminated(Channel channel) { + // Close all streams which have been associated with the channel. + for (Map.Entry entry: streams.entrySet()) { + StreamState state = entry.getValue(); + if (state.associatedChannel == channel) { + streams.remove(entry.getKey()); + + // Release all remaining buffers. + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } } } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 5a9a14a180c1..929f789bf9d2 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import io.netty.channel.Channel; + import org.apache.spark.network.buffer.ManagedBuffer; /** @@ -44,9 +46,18 @@ public abstract class StreamManager { public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); /** - * Indicates that the TCP connection that was tied to the given stream has been terminated. After - * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned - * up. + * Associates a stream with a single client connection, which is guaranteed to be the only reader + * of the stream. The getChunk() method will be called serially on this connection and once the + * connection is closed, the stream will never be used again, enabling cleanup. + * + * This must be called before the first getChunk() on the stream, but it may be invoked multiple + * times with the same channel and stream id. + */ + public void registerChannel(Channel channel, long streamId) { } + + /** + * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not + * to read from the associated streams again, so any state can be cleaned up. */ - public void connectionTerminated(long streamId) { } + public void connectionTerminated(Channel channel) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index e491367fa452..8e0ee709e38e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -19,6 +19,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +42,11 @@ * Client. * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. + * + * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. + * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic + * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not + * timeout if the client is continuously sending but getting no responses, for simplicity. */ public class TransportChannelHandler extends SimpleChannelInboundHandler { private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); @@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 1580180cc17e..e5159ab56d0d 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,10 +17,7 @@ package org.apache.spark.network.server; -import java.util.Set; - import com.google.common.base.Throwables; -import com.google.common.collect.Sets; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler { /** Returns each chunk part of a stream. */ private final StreamManager streamManager; - /** List of all stream ids that have been read on this handler, used for cleanup. */ - private final Set streamIds; - public TransportRequestHandler( Channel channel, TransportClient reverseClient, @@ -73,7 +67,6 @@ public TransportRequestHandler( this.reverseClient = reverseClient; this.rpcHandler = rpcHandler; this.streamManager = rpcHandler.getStreamManager(); - this.streamIds = Sets.newHashSet(); } @Override @@ -82,10 +75,7 @@ public void exceptionCaught(Throwable cause) { @Override public void channelUnregistered() { - // Inform the StreamManager that these streams will no longer be read from. - for (long streamId : streamIds) { - streamManager.connectionTerminated(streamId); - } + streamManager.connectionTerminated(channel); rpcHandler.connectionTerminated(reverseClient); } @@ -102,12 +92,12 @@ public void handle(RequestMessage request) { private void processFetchRequest(final ChunkFetchRequest req) { final String client = NettyUtils.getRemoteAddress(channel); - streamIds.add(req.streamChunkId.streamId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); ManagedBuffer buf; try { + streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format( diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 625c3257d764..f4fadb1ee3b8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -19,8 +19,11 @@ import java.io.Closeable; import java.net.InetSocketAddress; +import java.util.List; import java.util.concurrent.TimeUnit; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelFuture; @@ -28,7 +31,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; -import io.netty.util.internal.PlatformDependent; +import org.apache.spark.network.util.JavaUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,17 +48,30 @@ public class TransportServer implements Closeable { private final TransportContext context; private final TransportConf conf; + private final RpcHandler appRpcHandler; + private final List bootstraps; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port = -1; /** Creates a TransportServer that binds to the given port, or to any available if 0. */ - public TransportServer(TransportContext context, int portToBind) { + public TransportServer( + TransportContext context, + int portToBind, + RpcHandler appRpcHandler, + List bootstraps) { this.context = context; this.conf = context.getConf(); - - init(portToBind); + this.appRpcHandler = appRpcHandler; + this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); + + try { + init(portToBind); + } catch (RuntimeException e) { + JavaUtils.closeQuietly(this); + throw e; + } } public int getPort() { @@ -96,7 +112,11 @@ private void init(int portToBind) { bootstrap.childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel ch) throws Exception { - context.initializePipeline(ch); + RpcHandler rpcHandler = appRpcHandler; + for (TransportServerBootstrap bootstrap : bootstraps) { + rpcHandler = bootstrap.doBootstrap(ch, rpcHandler); + } + context.initializePipeline(ch, rpcHandler); } }); @@ -122,5 +142,4 @@ public void close() { } bootstrap = null; } - } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java new file mode 100644 index 000000000000..05803ab1bb05 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.Channel; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server. This allows customizing the client channel to allow for things such as SASL + * authentication. + */ +public interface TransportServerBootstrap { + /** + * Customizes the channel to include new features, if needed. + * + * @param channel The connected channel opened by the client. + * @param rpcHandler The RPC handler for the server. + * @return The RPC handler to use for the channel. + */ + RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler); +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java new file mode 100644 index 000000000000..b1415720045e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +/** + * A writable channel that stores the written data in a byte array in memory. + */ +public class ByteArrayWritableChannel implements WritableByteChannel { + + private final byte[] data; + private int offset; + + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + } + + public byte[] getData() { + return data; + } + + public int length() { + return offset; + } + + /** Resets the channel so that writing to it will overwrite the existing buffer. */ + public void reset() { + offset = 0; + } + + /** + * Reads from the given buffer into the internal byte array. + */ + @Override + public int write(ByteBuffer src) { + int toTransfer = Math.min(src.remaining(), data.length - offset); + src.get(data, offset, toTransfer); + offset += toTransfer; + return toTransfer; + } + + @Override + public void close() { + + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 000000000000..36d655017fb0 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index bf8a1fc42fc6..7d27439cfde7 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,19 +17,17 @@ package org.apache.spark.network.util; -import java.nio.ByteBuffer; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.File; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; -import com.google.common.base.Preconditions; -import com.google.common.io.Closeables; import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,6 +39,12 @@ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + /** + * Define a default value for driver memory here since this value is referenced across the code + * base and nearly all files already use Utils.scala + */ + public static final long DEFAULT_DRIVER_MEM_MB = 1024; + /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { @@ -127,4 +131,157 @@ private static boolean isSymlink(File file) throws IOException { } return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } + + private static final ImmutableMap timeSuffixes = + ImmutableMap.builder() + .put("us", TimeUnit.MICROSECONDS) + .put("ms", TimeUnit.MILLISECONDS) + .put("s", TimeUnit.SECONDS) + .put("m", TimeUnit.MINUTES) + .put("min", TimeUnit.MINUTES) + .put("h", TimeUnit.HOURS) + .put("d", TimeUnit.DAYS) + .build(); + + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + + /** + * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for + * internal use. If no suffix is provided a direct conversion is attempted. + */ + private static long parseTimeString(String str, TimeUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); + if (!m.matches()) { + throw new NumberFormatException("Failed to parse time string: " + str); + } + + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !timeSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); + } catch (NumberFormatException e) { + String timeError = "Time must be specified as seconds (s), " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + + "E.g. 50s, 100ms, or 250us."; + + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + public static long timeStringAsMs(String str) { + return parseTimeString(str, TimeUnit.MILLISECONDS); + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + public static long timeStringAsSec(String str) { + return parseTimeString(str, TimeUnit.SECONDS); + } + + /** + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java new file mode 100644 index 000000000000..668d2356b955 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.collect.Maps; + +import java.util.Map; +import java.util.NoSuchElementException; + +/** ConfigProvider based on a Map (copied in the constructor). */ +public class MapConfigProvider extends ConfigProvider { + private final Map config; + + public MapConfigProvider(Map config) { + this.config = Maps.newHashMap(config); + } + + @Override + public String get(String name) { + String value = config.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 2a4b88b64cdc..26c6399ce7db 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -25,7 +25,6 @@ import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; -import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollServerSocketChannel; import io.netty.channel.epoll.EpollSocketChannel; @@ -99,7 +98,7 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } - /** Returns the remote address on the channel or "<remote address>" if none exists. */ + /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { return channel.remoteAddress().toString(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6c9178688693..3b2eff377955 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import com.google.common.primitives.Ints; + /** * A central location that tracks all the settings we expose to users. */ @@ -37,8 +39,11 @@ public boolean preferDirectBufs() { /** Connect timeout in milliseconds. Default 120 secs. */ public int connectionTimeoutMs() { - int defaultTimeout = conf.getInt("spark.network.timeout", 120); - return conf.getInt("spark.shuffle.io.connectionTimeout", defaultTimeout) * 1000; + long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( + conf.get("spark.network.timeout", "120s")); + long defaultTimeoutMs = JavaUtils.timeStringAsSec( + conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ @@ -68,7 +73,9 @@ public int numConnectionsPerPeer() { public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { return conf.getInt("spark.shuffle.sasl.timeout", 30) * 1000; } + public int saslRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. @@ -80,7 +87,9 @@ public int numConnectionsPerPeer() { * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ - public int ioRetryWaitTimeMs() { return conf.getInt("spark.shuffle.io.retryWait", 5) * 1000; } + public int ioRetryWaitTimeMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + } /** * Minimum size of a block that we should start using memory map rather than reading in through @@ -98,4 +107,27 @@ public int memoryMapBytes() { public boolean lazyFileDescriptor() { return conf.getBoolean("spark.shuffle.io.lazyFD", true); } + + /** + * Maximum number of retries when binding to a port before giving up. + */ + public int portMaxRetries() { + return conf.getInt("spark.port.maxRetries", 16); + } + + /** + * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + */ + public int maxSaslEncryptedBlockSize() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + } + + /** + * Whether the server should enforce encryption on SASL-authenticated connections. + */ + public boolean saslServerAlwaysEncrypt() { + return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c719..d500bc3c98a7 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,35 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +60,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +93,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. + */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java new file mode 100644 index 000000000000..84ebb337e6d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.*; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Suite which ensures that requests that go without a response for the network timeout period are + * failed, and the connection closed. + * + * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, + * to ensure stability in different test environments. + */ +public class RequestTimeoutIntegrationSuite { + + private TransportServer server; + private TransportClientFactory clientFactory; + + private StreamManager defaultManager; + private TransportConf conf; + + // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. + private final int FOREVER = 60 * 1000; + + @Before + public void setUp() throws Exception { + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.connectionTimeout", "2s"); + conf = new TransportConf(new MapConfigProvider(configMap)); + + defaultManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + }; + } + + @After + public void tearDown() { + if (server != null) { + server.close(); + } + if (clientFactory != null) { + clientFactory.close(); + } + } + + // Basic suite: First request completes quickly, and second waits for longer than network timeout. + @Test + public void timeoutInactiveRequests() throws Exception { + final Semaphore semaphore = new Semaphore(1); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // First completes quickly (semaphore starts at 1). + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.success.length == response.length); + } + + // Second times out after 2 seconds, with slack. Must be IOException. + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client.sendRpc(new byte[0], callback1); + callback1.wait(4 * 1000); + assert (callback1.failure != null); + assert (callback1.failure instanceof IOException); + } + semaphore.release(); + } + + // A timeout will cause the connection to be closed, invalidating the current TransportClient. + // It should be the case that requesting a client from the factory produces a new, valid one. + @Test + public void timeoutCleanlyClosesClient() throws Exception { + final Semaphore semaphore = new Semaphore(0); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + + // First request should eventually fail. + TransportClient client0 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client0.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.failure instanceof IOException); + assert (!client0.isActive()); + } + + // Increment the semaphore and the second request should succeed quickly. + semaphore.release(2); + TransportClient client1 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client1.sendRpc(new byte[0], callback1); + callback1.wait(FOREVER); + assert (callback1.success.length == response.length); + assert (callback1.failure == null); + } + } + + // The timeout is relative to the LAST request sent, which is kinda weird, but still. + // This test also makes sure the timeout works for Fetch requests as well as RPCs. + @Test + public void furtherRequestsDelay() throws Exception { + final byte[] response = new byte[16]; + final StreamManager manager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); + return new NioManagedBuffer(ByteBuffer.wrap(response)); + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return manager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // Send one request, which will eventually fail. + TestCallback callback0 = new TestCallback(); + client.fetchChunk(0, 0, callback0); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // Send a second request before the first has failed. + TestCallback callback1 = new TestCallback(); + client.fetchChunk(0, 1, callback1); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + synchronized (callback0) { + // not complete yet, but should complete soon + assert (callback0.success == null && callback0.failure == null); + callback0.wait(2 * 1000); + assert (callback0.failure instanceof IOException); + } + + synchronized (callback1) { + // failed at same time as previous + assert (callback0.failure instanceof IOException); + } + } + + /** + * Callback which sets 'success' or 'failure' on completion. + * Additionally notifies all waiters on this callback when invoked. + */ + class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + + byte[] success; + Throwable failure; + + @Override + public void onSuccess(byte[] response) { + synchronized(this) { + success = response; + this.notifyAll(); + } + } + + @Override + public void onFailure(Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + synchronized(this) { + try { + success = buffer.nioByteBuffer().array(); + this.notifyAll(); + } catch (IOException e) { + // weird + } + } + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java index 38113a918f79..83c90f9eff2b 100644 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -80,6 +80,11 @@ public Object convertToNetty() throws IOException { return underlying.convertToNetty(); } + @Override + public int hashCode() { + return underlying.hashCode(); + } + @Override public boolean equals(Object other) { if (other instanceof ManagedBuffer) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 416dc1b969fa..35de5e57ccb9 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.NoSuchElementException; +import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.Maps; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -36,9 +37,9 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -70,16 +71,10 @@ public void tearDown() { */ private void testClientReuse(final int maxConnections, boolean concurrent) throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { - @Override - public String get(String name) { - if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) { - return Integer.toString(maxConnections); - } else { - throw new NoSuchElementException(); - } - } - }); + + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); + TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 000000000000..6c98e733b462 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.util.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 000000000000..8104004847a2 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.security.sasl.SaslException; + +import com.google.common.collect.Lists; +import com.google.common.io.ByteStreams; +import com.google.common.io.Files; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. + */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } + + @Test + public void testSaslAuthentication() throws Exception { + testBasicSasl(false); + } + + @Test + public void testSaslEncryption() throws Exception { + testBasicSasl(true); + } + + private void testBasicSasl(boolean encrypt) throws Exception { + RpcHandler rpcHandler = mock(RpcHandler.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + byte[] message = (byte[]) invocation.getArguments()[1]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + assertEquals("Ping", new String(message, StandardCharsets.UTF_8)); + cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8)); + return null; + } + }) + .when(rpcHandler) + .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + try { + byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); + } finally { + ctx.close(); + } + } + + @Test + public void testEncryptedMessage() throws Exception { + SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); + byte[] data = new byte[1024]; + new Random().nextBytes(data); + when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); + + ByteBuf msg = Unpooled.buffer(); + try { + msg.writeBytes(data); + + // Create a channel with a really small buffer compared to the data. This means that on each + // call, the outbound data will not be fully written, so the write() method should return a + // dummy count to keep the channel alive when possible. + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32); + + SaslEncryption.EncryptedMessage emsg = + new SaslEncryption.EncryptedMessage(backend, msg, 1024); + long count = emsg.transferTo(channel, emsg.transfered()); + assertTrue(count < data.length); + assertTrue(count > 0); + + // Here, the output buffer is full so nothing should be transferred. + assertEquals(0, emsg.transferTo(channel, emsg.transfered())); + + // Now there's room in the buffer, but not enough to transfer all the remaining data, + // so the dummy count should be returned. + channel.reset(); + assertEquals(1, emsg.transferTo(channel, emsg.transfered())); + + // Eventually, the whole message should be transferred. + for (int i = 0; i < data.length / 32 - 2; i++) { + channel.reset(); + assertEquals(1, emsg.transferTo(channel, emsg.transfered())); + } + + channel.reset(); + count = emsg.transferTo(channel, emsg.transfered()); + assertTrue("Unexpected count: " + count, count > 1 && count < data.length); + assertEquals(data.length, emsg.transfered()); + } finally { + msg.release(); + } + } + + @Test + public void testEncryptedMessageChunking() throws Exception { + File file = File.createTempFile("sasltest", ".txt"); + try { + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + byte[] data = new byte[8 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); + // It doesn't really matter what we return here, as long as it's not null. + when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); + + FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length()); + SaslEncryption.EncryptedMessage emsg = + new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + while (emsg.transfered() < emsg.count()) { + channel.reset(); + emsg.transferTo(channel, emsg.transfered()); + } + + verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt()); + } finally { + file.delete(); + } + } + + @Test + public void testFileRegionEncryption() throws Exception { + final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; + System.setProperty(blockSizeConf, "1k"); + + final AtomicReference response = new AtomicReference<>(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(conf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[8 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + System.clearProperty(blockSizeConf); + } + } + + @Test + public void testServerAlwaysEncrypt() throws Exception { + final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt"; + System.setProperty(alwaysEncryptConfName, "true"); + + SaslTestCtx ctx = null; + try { + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + fail("Should have failed to connect without encryption."); + } catch (Exception e) { + assertTrue(e.getCause() instanceof SaslException); + } finally { + if (ctx != null) { + ctx.close(); + } + System.clearProperty(alwaysEncryptConfName); + } + } + + @Test + public void testDataEncryptionIsActuallyEnabled() throws Exception { + // This test sets up an encrypted connection but then, using a client bootstrap, removes + // the encryption handler from the client side. This should cause the server to not be + // able to understand RPCs sent to it and thus close the connection. + SaslTestCtx ctx = null; + try { + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), + TimeUnit.SECONDS.toMillis(10)); + fail("Should have failed to send RPC to server."); + } catch (Exception e) { + assertFalse(e.getCause() instanceof TimeoutException); + } finally { + if (ctx != null) { + ctx.close(); + } + } + } + + private static class SaslTestCtx { + + final TransportClient client; + final TransportServer server; + + private final boolean encrypt; + private final boolean disableClientEncryption; + private final EncryptionCheckerBootstrap checker; + + SaslTestCtx( + RpcHandler rpcHandler, + boolean encrypt, + boolean disableClientEncryption) + throws Exception { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); + when(keyHolder.getSaslUser(anyString())).thenReturn("user"); + when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); + + TransportContext ctx = new TransportContext(conf, rpcHandler); + + this.checker = new EncryptionCheckerBootstrap(); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), + checker)); + + try { + List clientBootstraps = Lists.newArrayList(); + clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); + if (disableClientEncryption) { + clientBootstraps.add(new EncryptionDisablerBootstrap()); + } + + this.client = ctx.createClientFactory(clientBootstraps) + .createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + close(); + throw e; + } + + this.encrypt = encrypt; + this.disableClientEncryption = disableClientEncryption; + } + + void close() { + if (!disableClientEncryption) { + assertEquals(encrypt, checker.foundEncryptionHandler); + } + if (client != null) { + client.close(); + } + if (server != null) { + server.close(); + } + } + + } + + private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter + implements TransportServerBootstrap { + + boolean foundEncryptionHandler; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (!foundEncryptionHandler) { + foundEncryptionHandler = + ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + } + ctx.write(msg, promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + } + + @Override + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + channel.pipeline().addFirst("encryptionChecker", this); + return rpcHandler; + } + + } + + private static class EncryptionDisablerBootstrap implements TransportClientBootstrap { + + @Override + public void doBootstrap(TransportClient client, Channel channel) { + channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME); + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 000000000000..e8da774f7ca9 --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index c2d0300ecd90..3d2edf9d9451 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -43,6 +43,22 @@ ${project.version} + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + org.slf4j @@ -79,7 +95,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java deleted file mode 100644 index 33aa1344345f..000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.util.TransportConf; - -/** - * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The - * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. - */ -public class SaslClientBootstrap implements TransportClientBootstrap { - private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); - - private final TransportConf conf; - private final String appId; - private final SecretKeyHolder secretKeyHolder; - - public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { - this.conf = conf; - this.appId = appId; - this.secretKeyHolder = secretKeyHolder; - } - - /** - * Performs SASL authentication by sending a token, and then proceeding with the SASL - * challenge-response tokens until we either successfully authenticate or throw an exception - * due to mismatch. - */ - @Override - public void doBootstrap(TransportClient client) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); - try { - byte[] payload = saslClient.firstToken(); - - while (!saslClient.isComplete()) { - SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); - msg.encode(buf); - - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); - } - } finally { - try { - // Once authentication is complete, the server will trust all remaining communication. - saslClient.dispose(); - } catch (RuntimeException e) { - logger.error("Error while disposing SASL client", e); - } - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java deleted file mode 100644 index 3777a18e33f7..000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.util.concurrent.ConcurrentMap; - -import com.google.common.base.Charsets; -import com.google.common.collect.Maps; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; - -/** - * RPC Handler which performs SASL authentication before delegating to a child RPC handler. - * The delegate will only receive messages if the given connection has been successfully - * authenticated. A connection may be authenticated at most once. - * - * Note that the authentication process consists of multiple challenge-response pairs, each of - * which are individual RPCs. - */ -public class SaslRpcHandler extends RpcHandler { - private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); - - /** RpcHandler we will delegate to for authenticated connections. */ - private final RpcHandler delegate; - - /** Class which provides secret keys which are shared by server and client on a per-app basis. */ - private final SecretKeyHolder secretKeyHolder; - - /** Maps each channel to its SASL authentication state. */ - private final ConcurrentMap channelAuthenticationMap; - - public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { - this.delegate = delegate; - this.secretKeyHolder = secretKeyHolder; - this.channelAuthenticationMap = Maps.newConcurrentMap(); - } - - @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - SparkSaslServer saslServer = channelAuthenticationMap.get(client); - if (saslServer != null && saslServer.isComplete()) { - // Authentication complete, delegate to base handler. - delegate.receive(client, message, callback); - return; - } - - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); - channelAuthenticationMap.put(client, saslServer); - } - - byte[] response = saslServer.response(saslMessage.payload); - if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - } - callback.onSuccess(response); - } - - @Override - public StreamManager getStreamManager() { - return delegate.getStreamManager(); - } - - @Override - public void connectionTerminated(TransportClient client) { - SparkSaslServer saslServer = channelAuthenticationMap.remove(client); - if (saslServer != null) { - saslServer.dispose(); - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 46ca9708621b..0df1dd621f6e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -17,11 +17,12 @@ package org.apache.spark.network.shuffle; +import java.io.File; +import java.io.IOException; import java.util.List; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; -import org.apache.spark.network.util.TransportConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,10 +32,10 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; +import org.apache.spark.network.shuffle.protocol.*; +import org.apache.spark.network.util.TransportConf; + /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -46,18 +47,20 @@ public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); - private final ExternalShuffleBlockManager blockManager; + @VisibleForTesting + final ExternalShuffleBlockResolver blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler(TransportConf conf) { - this(new OneForOneStreamManager(), new ExternalShuffleBlockManager(conf)); + public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException { + this(new OneForOneStreamManager(), + new ExternalShuffleBlockResolver(conf, registeredExecutorFile)); } /** Enables mocking out the StreamManager and BlockManager. */ @VisibleForTesting ExternalShuffleBlockHandler( OneForOneStreamManager streamManager, - ExternalShuffleBlockManager blockManager) { + ExternalShuffleBlockResolver blockManager) { this.streamManager = streamManager; this.blockManager = blockManager; } @@ -65,7 +68,13 @@ public ExternalShuffleBlockHandler(TransportConf conf) { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + handleMessage(msgObj, client, callback); + } + protected void handleMessage( + BlockTransferMessage msgObj, + TransportClient client, + RpcResponseCallback callback) { if (msgObj instanceof OpenBlocks) { OpenBlocks msg = (OpenBlocks) msgObj; List blocks = Lists.newArrayList(); @@ -99,4 +108,22 @@ public StreamManager getStreamManager() { public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + + /** + * Register an (application, executor) with the given shuffle info. + * + * The "re-" is meant to highlight the intended use of this method -- when this service is + * restarted, this is used to restore the state of executors from before the restart. Normal + * registration will happen via a message handled in receive() + * + * @param appExecId + * @param executorInfo + */ + public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) { + blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo); + } + + public void close() { + blockManager.close(); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java deleted file mode 100644 index 93e6fdd7161f..000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.DataInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.Iterator; -import java.util.Map; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Objects; -import com.google.common.collect.Maps; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * Manages converting shuffle BlockIds into physical segments of local files, from a process outside - * of Executors. Each Executor must register its own configuration about where it stores its files - * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated - * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockManager. - */ -public class ExternalShuffleBlockManager { - private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); - - // Map containing all registered executors' metadata. - private final ConcurrentMap executors; - - // Single-threaded Java executor used to perform expensive recursive directory deletion. - private final Executor directoryCleaner; - - private final TransportConf conf; - - public ExternalShuffleBlockManager(TransportConf conf) { - this(conf, Executors.newSingleThreadExecutor( - // Add `spark` prefix because it will run in NM in Yarn mode. - NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))); - } - - // Allows tests to have more control over when directories are cleaned up. - @VisibleForTesting - ExternalShuffleBlockManager(TransportConf conf, Executor directoryCleaner) { - this.conf = conf; - this.executors = Maps.newConcurrentMap(); - this.directoryCleaner = directoryCleaner; - } - - /** Registers a new Executor with all the configuration we need to find its shuffle files. */ - public void registerExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - AppExecId fullId = new AppExecId(appId, execId); - logger.info("Registered executor {} with {}", fullId, executorInfo); - executors.put(fullId, executorInfo); - } - - /** - * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the - * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make - * assumptions about how the hash and sort based shuffles store their data. - */ - public ManagedBuffer getBlockData(String appId, String execId, String blockId) { - String[] blockIdParts = blockId.split("_"); - if (blockIdParts.length < 4) { - throw new IllegalArgumentException("Unexpected block id format: " + blockId); - } else if (!blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); - } - int shuffleId = Integer.parseInt(blockIdParts[1]); - int mapId = Integer.parseInt(blockIdParts[2]); - int reduceId = Integer.parseInt(blockIdParts[3]); - - ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); - if (executor == null) { - throw new RuntimeException( - String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); - } - - if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); - } else { - throw new UnsupportedOperationException( - "Unsupported shuffle manager: " + executor.shuffleManager); - } - } - - /** - * Removes our metadata of all executors registered for the given application, and optionally - * also deletes the local directories associated with the executors of that application in a - * separate thread. - * - * It is not valid to call registerExecutor() for an executor with this appId after invoking - * this method. - */ - public void applicationRemoved(String appId, boolean cleanupLocalDirs) { - logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); - Iterator> it = executors.entrySet().iterator(); - while (it.hasNext()) { - Map.Entry entry = it.next(); - AppExecId fullId = entry.getKey(); - final ExecutorShuffleInfo executor = entry.getValue(); - - // Only touch executors associated with the appId that was removed. - if (appId.equals(fullId.appId)) { - it.remove(); - - if (cleanupLocalDirs) { - logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); - - // Execute the actual deletion in a different thread, as it may take some time. - directoryCleaner.execute(new Runnable() { - @Override - public void run() { - deleteExecutorDirs(executor.localDirs); - } - }); - } - } - } - } - - /** - * Synchronously deletes each directory one at a time. - * Should be executed in its own thread, as this may take a long time. - */ - private void deleteExecutorDirs(String[] dirs) { - for (String localDir : dirs) { - try { - JavaUtils.deleteRecursively(new File(localDir)); - logger.debug("Successfully cleaned up directory: " + localDir); - } catch (Exception e) { - logger.error("Failed to delete directory: " + localDir, e); - } - } - } - - /** - * Hash-based shuffle data is simply stored as one file per block. - * This logic is from FileShuffleBlockManager. - */ - // TODO: Support consolidated hash shuffle files - private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { - File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); - return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); - } - - /** - * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file - * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockManager, - * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. - */ - private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { - File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.index"); - - DataInputStream in = null; - try { - in = new DataInputStream(new FileInputStream(indexFile)); - in.skipBytes(reduceId * 8); - long offset = in.readLong(); - long nextOffset = in.readLong(); - return new FileSegmentManagedBuffer( - conf, - getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.data"), - offset, - nextOffset - offset); - } catch (IOException e) { - throw new RuntimeException("Failed to open file: " + indexFile, e); - } finally { - if (in != null) { - JavaUtils.closeQuietly(in); - } - } - } - - /** - * Hashes a filename into the corresponding local directory, in a manner consistent with - * Spark's DiskBlockManager.getFile(). - */ - @VisibleForTesting - static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { - int hash = JavaUtils.nonNegativeHash(filename); - String localDir = localDirs[hash % localDirs.length]; - int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); - } - - /** Simply encodes an executor's full ID, which is appId + execId. */ - private static class AppExecId { - final String appId; - final String execId; - - private AppExecId(String appId, String execId) { - this.appId = appId; - this.execId = execId; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - AppExecId appExecId = (AppExecId) o; - return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .toString(); - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java new file mode 100644 index 000000000000..79beec4429a9 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.*; +import java.util.*; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBIterator; +import org.iq80.leveldb.Options; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Manages converting shuffle BlockIds into physical segments of local files, from a process outside + * of Executors. Each Executor must register its own configuration about where it stores its files + * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated + * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. + * + * Executors with shuffle file consolidation are not currently supported, as the index is stored in + * the Executor's memory, unlike the IndexShuffleBlockResolver. + */ +public class ExternalShuffleBlockResolver { + private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); + + private static final ObjectMapper mapper = new ObjectMapper(); + /** + * This a common prefix to the key for each app registration we stick in leveldb, so they + * are easy to find, since leveldb lets you search based on prefix. + */ + private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; + private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + + // Map containing all registered executors' metadata. + @VisibleForTesting + final ConcurrentMap executors; + + // Single-threaded Java executor used to perform expensive recursive directory deletion. + private final Executor directoryCleaner; + + private final TransportConf conf; + + @VisibleForTesting + final File registeredExecutorFile; + @VisibleForTesting + final DB db; + + public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) + throws IOException { + this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( + // Add `spark` prefix because it will run in NM in Yarn mode. + NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockResolver( + TransportConf conf, + File registeredExecutorFile, + Executor directoryCleaner) throws IOException { + this.conf = conf; + this.registeredExecutorFile = registeredExecutorFile; + if (registeredExecutorFile != null) { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + DB tmpDb; + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + logger.info("Creating state database at " + registeredExecutorFile); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new + // one, so we can keep processing new apps + logger.error("error opening leveldb file {}. Creating new file, will not be able to " + + "recover state for existing applications", registeredExecutorFile, e); + if (registeredExecutorFile.isDirectory()) { + for (File f : registeredExecutorFile.listFiles()) { + f.delete(); + } + } + registeredExecutorFile.delete(); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + + } + } + // if there is a version mismatch, we throw an exception, which means the service is unusable + checkVersion(tmpDb); + executors = reloadRegisteredExecutors(tmpDb); + db = tmpDb; + } else { + db = null; + executors = Maps.newConcurrentMap(); + } + this.directoryCleaner = directoryCleaner; + } + + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ + public void registerExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + AppExecId fullId = new AppExecId(appId, execId); + logger.info("Registered executor {} with {}", fullId, executorInfo); + try { + if (db != null) { + byte[] key = dbAppExecKey(fullId); + byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); + db.put(key, value); + } + } catch (Exception e) { + logger.error("Error saving registered executors", e); + } + executors.put(fullId, executorInfo); + } + + /** + * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the + * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make + * assumptions about how the hash and sort based shuffles store their data. + */ + public ManagedBuffer getBlockData(String appId, String execId, String blockId) { + String[] blockIdParts = blockId.split("_"); + if (blockIdParts.length < 4) { + throw new IllegalArgumentException("Unexpected block id format: " + blockId); + } else if (!blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); + } + int shuffleId = Integer.parseInt(blockIdParts[1]); + int mapId = Integer.parseInt(blockIdParts[2]); + int reduceId = Integer.parseInt(blockIdParts[3]); + + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + + if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) + || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else { + throw new UnsupportedOperationException( + "Unsupported shuffle manager: " + executor.shuffleManager); + } + } + + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + if (db != null) { + try { + db.delete(dbAppExecKey(fullId)); + } catch (IOException e) { + logger.error("Error deleting {} from executor state db", appId, e); + } + } + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + + /** + * Hash-based shuffle data is simply stored as one file per block. + * This logic is from FileShuffleBlockResolver. + */ + // TODO: Support consolidated hash shuffle files + private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { + File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); + } + + /** + * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file + * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, + * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + */ + private ManagedBuffer getSortBasedShuffleBlockData( + ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + + DataInputStream in = null; + try { + in = new DataInputStream(new FileInputStream(indexFile)); + in.skipBytes(reduceId * 8); + long offset = in.readLong(); + long nextOffset = in.readLong(); + return new FileSegmentManagedBuffer( + conf, + getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + offset, + nextOffset - offset); + } catch (IOException e) { + throw new RuntimeException("Failed to open file: " + indexFile, e); + } finally { + if (in != null) { + JavaUtils.closeQuietly(in); + } + } + } + + /** + * Hashes a filename into the corresponding local directory, in a manner consistent with + * Spark's DiskBlockManager.getFile(). + */ + @VisibleForTesting + static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + int hash = JavaUtils.nonNegativeHash(filename); + String localDir = localDirs[hash % localDirs.length]; + int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; + return new File(new File(localDir, String.format("%02x", subDirId)), filename); + } + + void close() { + if (db != null) { + try { + db.close(); + } catch (IOException e) { + logger.error("Exception closing leveldb with registered executors", e); + } + } + } + + /** Simply encodes an executor's full ID, which is appId + execId. */ + public static class AppExecId { + public final String appId; + public final String execId; + + @JsonCreator + public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } + } + + private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { + // we stick a common prefix on all the keys so we can find them in the DB + String appExecJson = mapper.writeValueAsString(appExecId); + String key = (APP_KEY_PREFIX + ";" + appExecJson); + return key.getBytes(Charsets.UTF_8); + } + + private static AppExecId parseDbAppExecKey(String s) throws IOException { + if (!s.startsWith(APP_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX); + } + String json = s.substring(APP_KEY_PREFIX.length() + 1); + AppExecId parsed = mapper.readValue(json, AppExecId.class); + return parsed; + } + + @VisibleForTesting + static ConcurrentMap reloadRegisteredExecutors(DB db) + throws IOException { + ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); + if (db != null) { + DBIterator itr = db.iterator(); + itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), Charsets.UTF_8); + if (!key.startsWith(APP_KEY_PREFIX)) { + break; + } + AppExecId id = parseDbAppExecKey(key); + ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); + registeredExecutors.put(id, shuffleInfo); + } + } + return registeredExecutors; + } + + private static class LevelDBLogger implements org.iq80.leveldb.Logger { + private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + /** + * Simple major.minor versioning scheme. Any incompatible changes should be across major + * versions. Minor version differences are allowed -- meaning we should be able to read + * dbs that are either earlier *or* later on the minor version. + */ + private static void checkVersion(DB db) throws IOException { + byte[] bytes = db.get(StoreVersion.KEY); + if (bytes == null) { + storeVersion(db); + } else { + StoreVersion version = mapper.readValue(bytes, StoreVersion.class); + if (version.major != CURRENT_VERSION.major) { + throw new IOException("cannot read state DB with version " + version + ", incompatible " + + "with current version " + CURRENT_VERSION); + } + storeVersion(db); + } + } + + private static void storeVersion(DB db) throws IOException { + db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); + } + + + public static class StoreVersion { + + final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + + public final int major; + public final int minor; + + @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StoreVersion that = (StoreVersion) o; + + return major == that.major && minor == that.minor; + } + + @Override + public int hashCode() { + int result = major; + result = 31 * result + minor; + return result; + } + } + +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6e8018b723dc..ea6d248d66be 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,10 +47,11 @@ public class ExternalShuffleClient extends ShuffleClient { private final TransportConf conf; private final boolean saslEnabled; + private final boolean saslEncryptionEnabled; private final SecretKeyHolder secretKeyHolder; - private TransportClientFactory clientFactory; - private String appId; + protected TransportClientFactory clientFactory; + protected String appId; /** * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, @@ -58,10 +60,19 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled) { + boolean saslEnabled, + boolean saslEncryptionEnabled) { + Preconditions.checkArgument( + !saslEncryptionEnabled || saslEnabled, + "SASL encryption can only be enabled if SASL is also enabled."); this.conf = conf; this.secretKeyHolder = secretKeyHolder; this.saslEnabled = saslEnabled; + this.saslEncryptionEnabled = saslEncryptionEnabled; + } + + protected void checkInit() { + assert appId != null : "Called before init()"; } @Override @@ -70,7 +81,7 @@ public void init(String appId) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); List bootstraps = Lists.newArrayList(); if (saslEnabled) { - bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); } clientFactory = context.createClientFactory(bootstraps); } @@ -82,7 +93,7 @@ public void fetchBlocks( final String execId, String[] blockIds, BlockFetchingListener listener) { - assert appId != null : "Called before init()"; + checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = @@ -125,7 +136,7 @@ public void registerWithShuffleServer( int port, String execId, ExecutorShuffleInfo executorInfo) throws IOException { - assert appId != null : "Called before init()"; + checkInit(); TransportClient client = clientFactory.createClient(host, port); byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 8ed2e0b39ad2..e653f5cb147e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -29,7 +29,6 @@ import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; -import org.apache.spark.network.util.JavaUtils; /** * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java new file mode 100644 index 000000000000..7543b6be4f2a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.mesos; + +import java.io.IOException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.shuffle.ExternalShuffleClient; +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.util.TransportConf; + +/** + * A client for talking to the external shuffle service in Mesos coarse-grained mode. + * + * This is used by the Spark driver to register with each external shuffle service on the cluster. + * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably + * after the application exits. Mesos does not provide a great alternative to do this, so Spark + * has to detect this itself. + */ +public class MesosExternalShuffleClient extends ExternalShuffleClient { + private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); + + /** + * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. + * Please refer to docs on {@link ExternalShuffleClient} for more information. + */ + public MesosExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled, + boolean saslEncryptionEnabled) { + super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); + } + + public void registerDriverWithShuffleService(String host, int port) throws IOException { + checkInit(); + byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerDriver, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + + "Please manually remove shuffle data after driver exit. Error: " + e); + } + }); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 6c1210b33268..fcb52363e632 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -21,6 +21,7 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -37,7 +38,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public static enum Type { - OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4); private final byte id; @@ -60,6 +61,7 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { case 1: return UploadBlock.decode(buf); case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); + case 4: return RegisterDriver.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index cadc8e8369c6..102d4efb8bf3 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -19,6 +19,8 @@ import java.util.Arrays; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -34,7 +36,11 @@ public class ExecutorShuffleInfo implements Encodable { /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ public final String shuffleManager; - public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + @JsonCreator + public ExecutorShuffleInfo( + @JsonProperty("localDirs") String[] localDirs, + @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, + @JsonProperty("shuffleManager") String shuffleManager) { this.localDirs = localDirs; this.subDirsPerLocalDir = subDirsPerLocalDir; this.shuffleManager = shuffleManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java index 62fce9b0d16c..ce954b8a289e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -23,7 +23,9 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ public class OpenBlocks extends BlockTransferMessage { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index 7eb438504407..cca8b17c4f12 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -21,7 +21,9 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** * Initial registration message between an executor and its local shuffle server. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index bc9daa6158ba..1915295aa6cc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -20,7 +20,8 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java index 0b23e112bd51..3caed59d508f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -23,7 +23,9 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java new file mode 100644 index 000000000000..94a61d6caadc --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol.mesos; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A message sent from the driver to register with the MesosExternalShuffleService. + */ +public class RegisterDriver extends BlockTransferMessage { + private final String appId; + + public RegisterDriver(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.REGISTER_DRIVER; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId); + } + + public static RegisterDriver decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + return new RegisterDriver(appId); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index d25283e46ef9..382f613ecbb1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.util.Arrays; import com.google.common.collect.Lists; import org.junit.After; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -72,10 +74,11 @@ public String getSecretKey(String appId) { @BeforeClass public static void beforeAll() throws IOException { SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); - SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); conf = new TransportConf(new SystemPropertyConfigProvider()); - context = new TransportContext(conf, handler); - server = context.createServer(); + context = new TransportContext(conf, new TestRpcHandler()); + + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); + server = context.createServer(Arrays.asList(bootstrap)); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java deleted file mode 100644 index 67a07f38eb5a..000000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.util.Map; - -import com.google.common.collect.ImmutableMap; -import org.junit.Test; - -import static org.junit.Assert.*; - -/** - * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. - */ -public class SparkSaslSuite { - - /** Provides a secret key holder which returns secret key == appId */ - private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { - @Override - public String getSaslUser(String appId) { - return "user"; - } - - @Override - public String getSecretKey(String appId) { - return appId; - } - }; - - @Test - public void testMatching() { - SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); - SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); - - assertFalse(client.isComplete()); - assertFalse(server.isComplete()); - - byte[] clientMessage = client.firstToken(); - - while (!client.isComplete()) { - clientMessage = client.response(server.response(clientMessage)); - } - assertTrue(server.isComplete()); - - // Disposal should invalidate - server.dispose(); - assertFalse(server.isComplete()); - client.dispose(); - assertFalse(client.isComplete()); - } - - - @Test - public void testNonMatching() { - SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); - SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); - - assertFalse(client.isComplete()); - assertFalse(server.isComplete()); - - byte[] clientMessage = client.firstToken(); - - try { - while (!client.isComplete()) { - clientMessage = client.response(server.response(clientMessage)); - } - fail("Should not have completed"); - } catch (Exception e) { - assertTrue(e.getMessage().contains("Mismatched response")); - assertFalse(client.isComplete()); - assertFalse(server.isComplete()); - } - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 3f9fe1681cf2..1d197497b7c8 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -45,14 +45,14 @@ public class ExternalShuffleBlockHandlerSuite { TransportClient client = mock(TransportClient.class); OneForOneStreamManager streamManager; - ExternalShuffleBlockManager blockManager; + ExternalShuffleBlockResolver blockResolver; RpcHandler handler; @Before public void beforeEach() { streamManager = mock(OneForOneStreamManager.class); - blockManager = mock(ExternalShuffleBlockManager.class); - handler = new ExternalShuffleBlockHandler(streamManager, blockManager); + blockResolver = mock(ExternalShuffleBlockResolver.class); + handler = new ExternalShuffleBlockHandler(streamManager, blockResolver); } @Test @@ -62,7 +62,7 @@ public void testRegisterExecutor() { ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); handler.receive(client, registerMessage, callback); - verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); + verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); verify(callback, times(1)).onSuccess((byte[]) any()); verify(callback, never()).onFailure((Throwable) any()); @@ -75,12 +75,12 @@ public void testOpenShuffleBlocks() { ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); - when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); + when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); handler.receive(client, openBlocks, callback); - verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); verify(callback, times(1)).onSuccess(response.capture()); @@ -90,9 +90,11 @@ public void testOpenShuffleBlocks() { (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); assertEquals(2, handle.numChunks); - ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> stream = (ArgumentCaptor>) + (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); verify(streamManager, times(1)).registerStream(stream.capture()); - Iterator buffers = (Iterator) stream.getValue(); + Iterator buffers = stream.getValue(); assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java deleted file mode 100644 index dad6428a836f..000000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; - -import com.google.common.io.CharStreams; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import static org.junit.Assert.*; - -public class ExternalShuffleBlockManagerSuite { - static String sortBlock0 = "Hello!"; - static String sortBlock1 = "World!"; - - static String hashBlock0 = "Elementary"; - static String hashBlock1 = "Tabular"; - - static TestShuffleDataContext dataContext; - - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); - - @BeforeClass - public static void beforeAll() throws IOException { - dataContext = new TestShuffleDataContext(2, 5); - - dataContext.create(); - // Write some sort and hash data. - dataContext.insertSortShuffleData(0, 0, - new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); - dataContext.insertHashShuffleData(1, 0, - new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); - } - - @AfterClass - public static void afterAll() { - dataContext.cleanup(); - } - - @Test - public void testBadRequests() { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); - // Unregistered executor - try { - manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); - fail("Should have failed"); - } catch (RuntimeException e) { - assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); - } - - // Invalid shuffle manager - manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); - try { - manager.getBlockData("app0", "exec2", "shuffle_1_1_0"); - fail("Should have failed"); - } catch (UnsupportedOperationException e) { - // pass - } - - // Nonexistent shuffle block - manager.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); - try { - manager.getBlockData("app0", "exec3", "shuffle_1_1_0"); - fail("Should have failed"); - } catch (Exception e) { - // pass - } - } - - @Test - public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); - manager.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); - - InputStream block0Stream = - manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); - block0Stream.close(); - assertEquals(sortBlock0, block0); - - InputStream block1Stream = - manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); - block1Stream.close(); - assertEquals(sortBlock1, block1); - } - - @Test - public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); - manager.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); - - InputStream block0Stream = - manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); - block0Stream.close(); - assertEquals(hashBlock0, block0); - - InputStream block1Stream = - manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); - block1Stream.close(); - assertEquals(hashBlock1, block1); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java new file mode 100644 index 000000000000..3c6cb367dea4 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.CharStreams; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExternalShuffleBlockResolverSuite { + static String sortBlock0 = "Hello!"; + static String sortBlock1 = "World!"; + + static String hashBlock0 = "Elementary"; + static String hashBlock1 = "Tabular"; + + static TestShuffleDataContext dataContext; + + static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + @BeforeClass + public static void beforeAll() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort and hash data. + dataContext.insertSortShuffleData(0, 0, + new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); + dataContext.insertHashShuffleData(1, 0, + new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void testBadRequests() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + // Unregistered executor + try { + resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (RuntimeException e) { + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); + } + + // Invalid shuffle manager + resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); + try { + resolver.getBlockData("app0", "exec2", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (UnsupportedOperationException e) { + // pass + } + + // Nonexistent shuffle block + resolver.registerExecutor("app0", "exec3", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + try { + resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (Exception e) { + // pass + } + } + + @Test + public void testSortShuffleBlocks() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + resolver.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + + InputStream block0Stream = + resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(sortBlock1, block1); + } + + @Test + public void testHashShuffleBlocks() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + resolver.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + + InputStream block0Stream = + resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(hashBlock0, block0); + + InputStream block1Stream = + resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(hashBlock1, block1); + } + + @Test + public void jsonSerializationOfExecutorRegistration() throws IOException { + ObjectMapper mapper = new ObjectMapper(); + AppExecId appId = new AppExecId("foo", "bar"); + String appIdJson = mapper.writeValueAsString(appId); + AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class); + assertEquals(parsedAppId, appId); + + ExecutorShuffleInfo shuffleInfo = + new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); + String shuffleJson = mapper.writeValueAsString(shuffleInfo); + ExecutorShuffleInfo parsedShuffleInfo = + mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); + assertEquals(parsedShuffleInfo, shuffleInfo); + + // Intentionally keep these hard-coded strings in here, to check backwards-compatability. + // its not legacy yet, but keeping this here in case anybody changes it + String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; + assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); + String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + + "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; + assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 254e3a7a32b9..2f4f1d0df478 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -41,14 +41,15 @@ public class ExternalShuffleCleanupSuite { public void noCleanupAndCleanup() throws IOException { TestShuffleDataContext dataContext = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); - manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); - manager.applicationRemoved("app", false /* cleanup */); + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + resolver.applicationRemoved("app", false /* cleanup */); assertStillThere(dataContext); - manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); - manager.applicationRemoved("app", true /* cleanup */); + resolver.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + resolver.applicationRemoved("app", true /* cleanup */); assertCleanedUp(dataContext); } @@ -64,7 +65,8 @@ public void cleanupUsesExecutor() throws IOException { @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } }; - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, noThreadExecutor); + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", true); @@ -81,11 +83,12 @@ public void cleanupMultipleExecutors() throws IOException { TestShuffleDataContext dataContext0 = createSomeData(); TestShuffleDataContext dataContext1 = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); - manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); - manager.applicationRemoved("app", true); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + resolver.applicationRemoved("app", true); assertCleanedUp(dataContext0); assertCleanedUp(dataContext1); @@ -96,25 +99,26 @@ public void cleanupOnlyRemovedApp() throws IOException { TestShuffleDataContext dataContext0 = createSomeData(); TestShuffleDataContext dataContext1 = createSomeData(); - ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); - manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); - manager.applicationRemoved("app-nonexistent", true); + resolver.applicationRemoved("app-nonexistent", true); assertStillThere(dataContext0); assertStillThere(dataContext1); - manager.applicationRemoved("app-0", true); + resolver.applicationRemoved("app-0", true); assertCleanedUp(dataContext0); assertStillThere(dataContext1); - manager.applicationRemoved("app-1", true); + resolver.applicationRemoved("app-1", true); assertCleanedUp(dataContext0); assertCleanedUp(dataContext1); // Make sure it's not an error to cleanup multiple times - manager.applicationRemoved("app-1", true); + resolver.applicationRemoved("app-1", true); assertCleanedUp(dataContext0); assertCleanedUp(dataContext1); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 02c10bcb7b26..a3f9a38b1aeb 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -92,7 +92,7 @@ public static void beforeAll() throws IOException { dataContext1.insertHashShuffleData(1, 0, exec1Blocks); conf = new TransportConf(new SystemPropertyConfigProvider()); - handler = new ExternalShuffleBlockHandler(conf); + handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } @@ -136,7 +136,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -274,7 +274,7 @@ public void testFetchNoServer() throws Exception { private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 759a12910c94..aa99efda9494 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.util.Arrays; import org.junit.After; import org.junit.Before; @@ -27,10 +28,11 @@ import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -41,11 +43,12 @@ public class ExternalShuffleSecuritySuite { TransportServer server; @Before - public void beforeEach() { - RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), - new TestSecretKeyHolder("my-app-id", "secret")); - TransportContext context = new TransportContext(conf, handler); - this.server = context.createServer(); + public void beforeEach() throws IOException { + TransportContext context = + new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, + new TestSecretKeyHolder("my-app-id", "secret")); + this.server = context.createServer(Arrays.asList(bootstrap)); } @After @@ -58,13 +61,13 @@ public void afterEach() { @Test public void testValid() throws IOException { - validate("my-app-id", "secret"); + validate("my-app-id", "secret", false); } @Test public void testBadAppId() { try { - validate("wrong-app-id", "secret"); + validate("wrong-app-id", "secret", false); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); } @@ -73,16 +76,21 @@ public void testBadAppId() { @Test public void testBadSecret() { try { - validate("my-app-id", "bad-secret"); + validate("my-app-id", "bad-secret", false); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); } } + @Test + public void testEncryption() throws IOException { + validate("my-app-id", "secret", true); + } + /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey) throws IOException { + private void validate(String appId, String secretKey, boolean encrypt) throws IOException { ExternalShuffleClient client = - new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); client.init(appId); // Registration either succeeds or throws an exception. client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 842741e3d354..b35a6d685dd0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -28,11 +28,16 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.junit.Assert.*; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 1ad0d72ae5ec..06e46f924109 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -20,7 +20,9 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import com.google.common.collect.ImmutableMap; @@ -67,13 +69,13 @@ public void afterEach() { public void testNoFailures() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // Immediately return both blocks successfully. ImmutableMap.builder() .put("b0", block0) .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -86,13 +88,13 @@ public void testNoFailures() throws IOException { public void testUnrecoverableFailure() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0 throws a non-IOException error, so it will be failed without retry. ImmutableMap.builder() .put("b0", new RuntimeException("Ouch!")) .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -105,7 +107,7 @@ public void testUnrecoverableFailure() throws IOException { public void testSingleIOExceptionOnFirst() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // IOException will cause a retry. Since b0 fails, we will retry both. ImmutableMap.builder() .put("b0", new IOException("Connection failed or something")) @@ -114,8 +116,8 @@ public void testSingleIOExceptionOnFirst() throws IOException { ImmutableMap.builder() .put("b0", block0) .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -128,7 +130,7 @@ public void testSingleIOExceptionOnFirst() throws IOException { public void testSingleIOExceptionOnSecond() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // IOException will cause a retry. Since b1 fails, we will not retry b0. ImmutableMap.builder() .put("b0", block0) @@ -136,8 +138,8 @@ public void testSingleIOExceptionOnSecond() throws IOException { .build(), ImmutableMap.builder() .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -150,7 +152,7 @@ public void testSingleIOExceptionOnSecond() throws IOException { public void testTwoIOExceptions() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0's IOException will trigger retry, b1's will be ignored. ImmutableMap.builder() .put("b0", new IOException()) @@ -164,8 +166,8 @@ public void testTwoIOExceptions() throws IOException { // b1 returns successfully within 2 retries. ImmutableMap.builder() .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -178,7 +180,7 @@ public void testTwoIOExceptions() throws IOException { public void testThreeIOExceptions() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0's IOException will trigger retry, b1's will be ignored. ImmutableMap.builder() .put("b0", new IOException()) @@ -196,8 +198,8 @@ public void testThreeIOExceptions() throws IOException { // This is not reached -- b1 has failed. ImmutableMap.builder() .put("b1", block1) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -210,7 +212,7 @@ public void testThreeIOExceptions() throws IOException { public void testRetryAndUnrecoverable() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); - Map[] interactions = new Map[] { + List> interactions = Arrays.asList( // b0's IOException will trigger retry, subsequent messages will be ignored. ImmutableMap.builder() .put("b0", new IOException()) @@ -226,8 +228,8 @@ public void testRetryAndUnrecoverable() throws IOException { // b2 succeeds in its last retry. ImmutableMap.builder() .put("b2", block2) - .build(), - }; + .build() + ); performInteractions(interactions, listener); @@ -248,7 +250,8 @@ public void testRetryAndUnrecoverable() throws IOException { * subset of the original blocks in a second interaction. */ @SuppressWarnings("unchecked") - private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + private static void performInteractions(List> interactions, + BlockFetchingListener listener) throws IOException { TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 76639114df5d..3fdde054ab6c 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -29,7 +29,7 @@ /** * Manages some sort- and hash-based shuffle data, including the creation - * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. + * and cleanup of directories that can be read by the {@link ExternalShuffleBlockResolver}. */ public class TestShuffleDataContext { public final String[] localDirs; @@ -61,9 +61,9 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); long offset = 0; indexStream.writeLong(offset); @@ -82,7 +82,7 @@ public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) thr for (int i = 0; i < blocks.length; i ++) { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; Files.write(blocks[i], - ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId)); + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId)); } } diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index acec8f18f2b5..a99f7c4392d3 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -33,6 +33,8 @@ http://spark.apache.org/ network-yarn + + provided @@ -47,7 +49,6 @@ org.apache.hadoop hadoop-client - provided diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index a34aabe9e78a..11ea7f3fd3cf 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -17,25 +17,23 @@ package org.apache.spark.network.yarn; -import java.lang.Override; +import java.io.File; import java.nio.ByteBuffer; +import java.util.List; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; -import org.apache.hadoop.yarn.server.api.AuxiliaryService; -import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; -import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; -import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; -import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.apache.hadoop.yarn.server.api.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; -import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.util.TransportConf; import org.apache.spark.network.yarn.util.HadoopConfigProvider; @@ -76,9 +74,27 @@ public class YarnShuffleService extends AuxiliaryService { // The actual server that serves shuffle files private TransportServer shuffleServer = null; + // Handles registering executors and opening shuffle blocks + @VisibleForTesting + ExternalShuffleBlockHandler blockHandler; + + // Where to store & reload executor info for recovering state after an NM restart + @VisibleForTesting + File registeredExecutorFile; + + // just for testing when you want to find an open port + @VisibleForTesting + static int boundPort = -1; + + // just for integration tests that want to look at this file -- in general not sensible as + // a static + @VisibleForTesting + static YarnShuffleService instance; + public YarnShuffleService() { super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); + instance = this; } /** @@ -95,23 +111,42 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. Even if + // an application was stopped while the NM was down, we expect yarn to call stopApplication() + // when it comes back + registeredExecutorFile = + findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); + try { + blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + } catch (Exception e) { + logger.error("Failed to initialize external shuffle service", e); + } + + List bootstraps = Lists.newArrayList(); if (authEnabled) { secretManager = new ShuffleSecretManager(); - rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); } int port = conf.getInt( SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, rpcHandler); - shuffleServer = transportContext.createServer(port); + TransportContext transportContext = new TransportContext(transportConf, blockHandler); + shuffleServer = transportContext.createServer(port, bootstraps); + // the port should normally be fixed, but for tests its useful to find an open port + port = shuffleServer.getPort(); + boundPort = port; String authEnabledString = authEnabled ? "enabled" : "not enabled"; logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}.", port, authEnabledString); + "Authentication is {}. Registered executor file is {}", port, authEnabledString, + registeredExecutorFile); } @Override @@ -136,6 +171,7 @@ public void stopApplication(ApplicationTerminationContext context) { if (isAuthenticationEnabled()) { secretManager.unregisterApp(appId); } + blockHandler.applicationRemoved(appId, false /* clean up local dirs */); } catch (Exception e) { logger.error("Exception when stopping application {}", appId, e); } @@ -153,6 +189,16 @@ public void stopContainer(ContainerTerminationContext context) { logger.info("Stopping container {}", containerId); } + private File findRegisteredExecutorFile(String[] localDirs) { + for (String dir: localDirs) { + File f = new File(dir, "registeredExecutors.ldb"); + if (f.exists()) { + return f; + } + } + return new File(localDirs[0], "registeredExecutors.ldb"); + } + /** * Close the shuffle server to clean up any associated state. */ @@ -162,6 +208,9 @@ protected void serviceStop() { if (shuffleServer != null) { shuffleServer.close(); } + if (blockHandler != null) { + blockHandler.close(); + } } catch (Exception e) { logger.error("Exception when stopping service", e); } @@ -172,5 +221,4 @@ protected void serviceStop() { public ByteBuffer getMetaData() { return ByteBuffer.allocate(0); } - } diff --git a/pom.xml b/pom.xml index e25eced87757..0716016523ee 100644 --- a/pom.xml +++ b/pom.xml @@ -25,8 +25,8 @@ 14 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -59,7 +59,7 @@ - 3.0.4 + ${maven.version} @@ -97,60 +97,90 @@ sql/catalyst sql/core sql/hive + unsafe assembly external/twitter external/flume external/flume-sink + external/flume-assembly external/mqtt + external/mqtt-assembly external/zeromq examples repl + launcher + external/kafka + external/kafka-assembly UTF-8 UTF-8 - org.spark-project.akka - 2.3.4-spark - 1.6 + com.typesafe.akka + 2.3.11 + 1.7 + 3.3.3 spark - 2.0.1 - 0.21.0 + 0.21.1 shaded-protobuf 1.7.10 1.2.17 - 1.0.4 - 2.4.1 + 2.2.0 + 2.5.0 ${hadoop.version} - 0.98.7-hadoop1 + 0.98.7-hadoop2 hbase - 1.4.0 + 1.6.0 3.4.5 + 2.4.0 org.spark-project.hive - 0.13.1a + 1.2.1.spark - 0.13.1 + 1.2.1 10.10.1.1 - 1.6.0rc3 - 1.2.3 + 1.7.0 + 1.6.0 + 1.2.4 8.1.14.v20131031 + 3.0.0.v201112011016 0.5.0 - 3.1.0 - 1.7.6 - + 2.4.0 + 2.0.8 + 3.1.2 + 1.7.7 + hadoop2 0.7.1 - 1.8.3 - 1.1.0 - 4.2.6 - 3.1.1 - ${project.build.directory}/spark-test-classpath.txt + 1.9.16 + 1.2.1 + + 4.3.2 + + 3.1 + 3.4.1 2.10.4 2.10 ${scala.version} org.scala-lang - 1.8.8 - 1.1.1.6 + 1.9.13 + 2.4.4 + 1.1.1.7 + 1.1.2 + 1.2.0-incubating + 1.10 + + 2.6 + + 3.3.2 + 3.2.10 + 2.7.8 + 1.9 + 2.5 + 3.5.2 + 1.3.9 + 0.9.2 + + ${java.home} ${session.executionRootDirectory} @@ -176,7 +207,6 @@ 512m 512m - central @@ -234,10 +264,18 @@ false + + spark-hive-staging + Staging Repo for Hive 1.2.1 (Spark Version) + https://oss.sonatype.org/content/repositories/orgspark-project-1113 + + true + + mapr-repo MapR Repository - http://repository.mapr.com/maven + http://repository.mapr.com/maven/ true @@ -245,10 +283,23 @@ false + spring-releases Spring Release Repository https://repo.spring.io/libs-release + + false + + + false + + + + + twttr-repo + Twttr Repository + http://maven.twttr.com true @@ -279,17 +330,6 @@ unused 1.0.0 - - - org.codehaus.groovy - groovy-all - 2.3.7 - provided - @@ -390,32 +425,74 @@ provided - + org.apache.commons commons-lang3 - 3.3.2 + ${commons-lang3.version} + + + commons-lang + commons-lang + ${commons-lang2.version} commons-codec commons-codec - 1.5 + ${commons-codec.version} org.apache.commons commons-math3 ${commons.math3.version} + + org.apache.ivy + ivy + ${ivy.version} + com.google.code.findbugs jsr305 - 1.3.9 + ${jsr305.version} + + + commons-httpclient + commons-httpclient + ${httpclient.classic.version} + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + org.apache.httpcomponents + httpcore + ${commons.httpclient.version} org.seleniumhq.selenium selenium-java 2.42.2 test + + + com.google.guava + guava + + + io.netty + netty + + + + + + xml-apis + xml-apis + 1.4.01 + test org.slf4j @@ -449,7 +526,7 @@ com.ning compress-lzf - 1.0.0 + 1.0.3 org.xerial.snappy @@ -460,7 +537,7 @@ net.jpountz.lz4 lz4 - 1.2.0 + 1.3.0 com.clearspring.analytics @@ -541,7 +618,7 @@ io.netty netty-all - 4.0.23.Final + 4.0.29.Final org.apache.derby @@ -573,6 +650,52 @@ metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + + + com.fasterxml.jackson.core + jackson-annotations + ${fasterxml.jackson.version} + + + + com.fasterxml.jackson.module + jackson-module-scala_${scala.binary.version} + ${fasterxml.jackson.version} + + + com.google.guava + guava + + + + + com.sun.jersey + jersey-server + ${jersey.version} + ${hadoop.deps.scope} + + + com.sun.jersey + jersey-core + ${jersey.version} + ${hadoop.deps.scope} + + + com.sun.jersey + jersey-json + ${jersey.version} + + + stax + stax-api + + + org.scala-lang scala-compiler @@ -604,23 +727,10 @@ 2.2.1 test - - org.easymock - easymockclassextension - 3.1 - test - - - - asm - asm - 3.3.1 - test - org.mockito - mockito-all - 1.9.0 + mockito-core + 1.9.5 test @@ -635,6 +745,18 @@ 4.10 test + + org.hamcrest + hamcrest-core + 1.3 + test + + + org.hamcrest + hamcrest-library + 1.3 + test + com.novocode junit-interface @@ -644,7 +766,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} ${hadoop.deps.scope} @@ -653,6 +775,22 @@ + + org.apache.curator + curator-client + ${curator.version} + + + org.apache.curator + curator-framework + ${curator.version} + + + org.apache.curator + curator-test + ${curator.version} + test + org.apache.hadoop hadoop-client @@ -663,6 +801,10 @@ asm asm + + org.codehaus.jackson + jackson-mapper-asl + org.ow2.asm asm @@ -675,6 +817,10 @@ commons-logging commons-logging + + org.mockito + mockito-all + org.mortbay.jetty servlet-api-2.5 @@ -925,119 +1071,709 @@ ${hadoop.deps.scope} - ${hive.group} - hive-beeline - ${hive.version} - ${hive.deps.scope} + org.codehaus.jackson + jackson-xc + ${codehaus.jackson.version} - ${hive.group} - hive-cli - ${hive.version} - ${hive.deps.scope} + org.codehaus.jackson + jackson-jaxrs + ${codehaus.jackson.version} ${hive.group} - hive-exec + hive-beeline ${hive.version} ${hive.deps.scope} - commons-logging - commons-logging + ${hive.group} + hive-common - com.esotericsoftware.kryo - kryo + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging ${hive.group} - hive-jdbc - ${hive.version} - ${hive.deps.scope} - - - ${hive.group} - hive-metastore + hive-cli ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} - hive-serde + hive-common ${hive.version} ${hive.deps.scope} - commons-logging - commons-logging + ${hive.group} + hive-shims + + + org.apache.ant + ant + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j commons-logging - commons-logging-api + commons-logging + - com.twitter - parquet-column - ${parquet.version} - ${parquet.deps.scope} - - - com.twitter - parquet-hadoop - ${parquet.version} - ${parquet.deps.scope} - - - org.apache.flume - flume-ng-core - ${flume.version} - ${flume.deps.scope} + ${hive.group} + hive-exec + + ${hive.version} + ${hive.deps.scope} + + - io.netty - netty + ${hive.group} + hive-metastore - org.apache.thrift - libthrift + ${hive.group} + hive-shims - org.mortbay.jetty - servlet-api + ${hive.group} + hive-ant - - - - org.apache.flume - flume-ng-sdk - ${flume.version} - ${flume.deps.scope} - + - io.netty - netty + ${hive.group} + spark-client + + + + + ant + ant + + + org.apache.ant + ant + + + com.esotericsoftware.kryo + kryo + + + commons-codec + commons-codec + + + commons-httpclient + commons-httpclient + + + org.apache.avro + avro-mapred + + + + org.apache.calcite + calcite-core + + + org.apache.curator + apache-curator + + + org.apache.curator + curator-client + + + org.apache.curator + curator-framework org.apache.thrift libthrift - - - - - - - - - - org.apache.maven.plugins + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + + + ${hive.group} + hive-jdbc + ${hive.version} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-common + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + + + + ${hive.group} + hive-metastore + ${hive.version} + ${hive.deps.scope} + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-shims + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + com.google.guava + guava + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + + + + ${hive.group} + hive-serde + ${hive.version} + ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + com.google.code.findbugs + jsr305 + + + org.apache.avro + avro + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + + + + ${hive.group} + hive-service + ${hive.version} + ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + org.apache.curator + curator-framework + + + org.apache.curator + curator-recipes + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + + + + + ${hive.group} + hive-shims + ${hive.version} + ${hive.deps.scope} + + + com.google.guava + guava + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + + + org.apache.parquet + parquet-column + ${parquet.version} + ${parquet.deps.scope} + + + org.apache.parquet + parquet-hadoop + ${parquet.version} + ${parquet.deps.scope} + + + org.apache.parquet + parquet-avro + ${parquet.version} + ${parquet.test.deps.scope} + + + com.twitter + parquet-hadoop-bundle + ${hive.parquet.version} + compile + + + org.apache.flume + flume-ng-core + ${flume.version} + ${flume.deps.scope} + + + io.netty + netty + + + org.apache.flume + flume-ng-auth + + + org.apache.thrift + libthrift + + + org.mortbay.jetty + servlet-api + + + + + org.apache.flume + flume-ng-sdk + ${flume.version} + ${flume.deps.scope} + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.apache.calcite + calcite-core + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.google.guava + guava + + + com.google.code.findbugs + jsr305 + + + org.codehaus.janino + janino + + + + org.hsqldb + hsqldb + + + org.pentaho + pentaho-aggdesigner-algorithm + + + + + org.apache.calcite + calcite-avatica + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.codehaus.janino + janino + ${janino.version} + + + joda-time + joda-time + ${joda.version} + + + org.jodd + jodd-core + ${jodd.version} + + + org.datanucleus + datanucleus-core + ${datanucleus-core.version} + + + org.apache.thrift + libthrift + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + + + org.apache.thrift + libfb303 + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + + + + + + + + + org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1047,7 +1783,7 @@ - 3.0.4 + ${maven.version} ${java.version} @@ -1060,13 +1796,19 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven scala-maven-plugin - 3.2.0 + 3.2.2 + + eclipse-add-source + + add-source + + scala-compile-first process-resources @@ -1110,35 +1852,30 @@ ${java.version} -target ${java.version} + -Xlint:all,-serial,-path - - - - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} - - org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} UTF-8 1024m true + + -Xlint:all,-serial,-path + org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1148,15 +1885,28 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + + ${test_classpath} + ${test.java.home} + + test true - ${session.executionRootDirectory} + ${project.build.directory}/tmp + ${spark.test.home} 1 + false false false - ${test_classpath} true + true + + src false @@ -1179,14 +1929,20 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + ${test.java.home} + test true + ${project.build.directory}/tmp ${spark.test.home} 1 false false true + true + + __not_used__ @@ -1201,17 +1957,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1220,6 +1976,7 @@ create-source-jar jar-no-fork + test-jar-no-fork @@ -1227,7 +1984,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1245,7 +2002,84 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.5 + + + org.apache.maven.plugins + maven-shade-plugin + 2.4.1 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 + + + + + org.eclipse.m2e + lifecycle-mapping + 1.0.0 + + + + + + org.apache.maven.plugins + maven-dependency-plugin + [2.8,) + + build-classpath + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + [2.6,) + + test-jar + + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + [1.8,) + + run + + + + + + + + + @@ -1255,7 +2089,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1264,34 +2098,12 @@ test - ${test_classpath_file} + test_classpath - - - org.codehaus.gmavenplus - gmavenplus-plugin - 1.2 - - - process-test-classes - - execute - - - - - - - - - org.apache.maven.plugins @@ -1427,6 +2228,25 @@ org.scalatest scalatest-maven-plugin + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + prepare-package + + test-jar + + + + log4j.properties + + + + + @@ -1460,6 +2280,7 @@ kinesis-asl extras/kinesis-asl + extras/kinesis-asl-assembly @@ -1514,38 +2335,28 @@ --> - hadoop-0.23 - - - - org.apache.avro - avro - - + hadoop-1 - 0.23.10 + 1.2.1 + 2.4.1 + 0.98.7-hadoop1 + hadoop1 + 1.8.8 + org.spark-project.akka + 2.3.4-spark hadoop-2.2 - - 2.2.0 - 2.5.0 - 0.98.7-hadoop2 - hadoop2 - + hadoop-2.3 2.3.0 - 2.5.0 - 0.9.0 - 0.98.7-hadoop2 - 3.1.1 - hadoop2 + 0.9.3 @@ -1553,58 +2364,26 @@ hadoop-2.4 2.4.0 - 2.5.0 - 0.9.0 - 0.98.7-hadoop2 - 3.1.1 - hadoop2 + 0.9.3 - yarn - - yarn - network/yarn - - - - - mapr3 + hadoop-2.6 - 1.0.3-mapr-3.0.3 - 2.4.1-mapr-1408 - 0.94.17-mapr-1405 - 3.4.5-mapr-1406 + 2.6.0 + 0.9.3 + 3.4.6 + 2.6.0 - mapr4 - - 2.4.1-mapr-1408 - 2.4.1-mapr-1408 - 0.94.17-mapr-1405-4.0.0-FCS - 3.4.5-mapr-1406 - - - - org.apache.curator - curator-recipes - 2.4.0 - - - org.apache.zookeeper - zookeeper - - - - - org.apache.zookeeper - zookeeper - 3.4.5-mapr-1406 - - + yarn + + yarn + network/yarn + @@ -1613,22 +2392,6 @@ sql/hive-thriftserver - - hive-0.12.0 - - 0.12.0-protobuf-2.5 - 0.12.0 - 10.4.2.0 - - - - hive-0.13.1 - - 0.13.1a - 0.13.1 - 10.10.1.1 - - scala-2.10 @@ -1641,10 +2404,25 @@ ${scala.version} org.scala-lang - - external/kafka - external/kafka-assembly - + + + + ${jline.groupid} + jline + ${jline.version} + + + + + + + test-java-home + + env.JAVA_HOME + + + ${env.JAVA_HOME} + @@ -1653,10 +2431,8 @@ scala-2.11 - 2.11.2 + 2.11.7 2.11 - 2.12 - jline @@ -1680,5 +2456,8 @@ parquet-provided + + sparkr + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f0cbf4e57b8c..f16bf989f200 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.2.0" + val previousSparkVersion = "1.4.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b17532c1d814..88745dc086a0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -16,6 +16,7 @@ */ import com.typesafe.tools.mima.core._ +import com.typesafe.tools.mima.core.ProblemFilters._ /** * Additional excludes for checking of Spark's binary compatibility. @@ -33,9 +34,295 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // The old JSON RDD is removed in favor of streaming Jackson + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") + ) ++ Seq( + // SPARK-8479 Add numNonzeros and numActives to Matrix. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numActives") + ) ++ Seq( + // SPARK-8914 Remove RDDApi + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-7292 Provide operator to truncate lineage cheaply + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.RDDCheckpointData"), + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.CheckpointRDD") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") + ) ++ Seq( + // SPARK-6797 Support YARN modes for SparkR + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.PairwiseRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.createRWorker"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.StringRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.BaseRRDD.this") + ) ++ Seq( + // SPARK-7422 add argmax for sparse vectors + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.argmax") + ) ++ Seq( + // SPARK-8906 Move all internal data source classes into execution.datasources + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") + ) ++ Seq( + // SPARK-4751 Dynamic allocation for standalone mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) + + case v if v.startsWith("1.4") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives"), + // SPARK-7681 add SparseVector support for gemv + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.multiply") + ) ++ Seq( + // Execution should never be included as its always internal. + MimaBuild.excludeSparkPackage("sql.execution"), + // This `protected[sql]` method was removed in 1.3.1 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.checkAnalysis"), + // These `private[sql]` class were removed in 1.4.0: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), + // These test support classes were moved out of src/main and into src/test: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.TestGroupWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") + ) ++ Seq( + // SPARK-7530 Added StreamingContext.getState() + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.StreamingContext.state_=") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") + ) + case v if v.startsWith("1.3") => Seq( MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in the 1.2 build. MimaBuild.excludeSparkPackage("unused"), @@ -142,6 +429,47 @@ object MimaExcludes { "org.apache.spark.graphx.Graph.getCheckpointFiles"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-5814 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") + ) ++ Seq( + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index fbc8983b953b..ea52bfd67944 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,43 +15,46 @@ * limitations under the License. */ -import java.io.File +import java.io._ import scala.util.Properties -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion -import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} +import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings +import spray.revolver.RevolverPlugin._ + object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", + "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation + + val testTempDir = s"$sparkHome/target/tmp" } object SparkBuild extends PomBuild { @@ -66,6 +69,7 @@ object SparkBuild extends PomBuild { import scala.collection.mutable var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq("sbt") + // scalastyle:off println if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") @@ -85,6 +89,7 @@ object SparkBuild extends PomBuild { println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.") profiles ++= Seq("yarn") } + // scalastyle:on println profiles } @@ -93,8 +98,10 @@ object SparkBuild extends PomBuild { case None => backwardCompatibility case Some(v) => if (backwardCompatibility.nonEmpty) + // scalastyle:off println println("Note: We ignore environment variables, when use of profile is detected in " + "conjunction with environment variable.") + // scalastyle:on println v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } @@ -113,18 +120,25 @@ object SparkBuild extends PomBuild { case _ => } - override val userPropertiesMap = System.getProperties.toMap + override val userPropertiesMap = System.getProperties.asScala.toMap lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( - javaHome := Properties.envOrNone("JAVA_HOME").map(file), + lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq( + libraryDependencies += compilerPlugin( + "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), + scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) + + lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( + javaHome := sys.env.get("JAVA_HOME") + .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) + .map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - unidocGenjavadocVersion := "0.8", + unidocGenjavadocVersion := "0.9-spark0", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), @@ -138,6 +152,39 @@ object SparkBuild extends PomBuild { javacOptions in (Compile, doc) ++= { val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty + }, + + javacOptions in Compile ++= Seq("-encoding", "UTF-8"), + + // Implements -Xfatal-warnings, ignoring deprecation warnings. + // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. + compile in Compile := { + val analysis = (compile in Compile).value + val s = streams.value + + def logProblem(l: (=> String) => Unit, f: File, p: xsbti.Problem) = { + l(f.toString + ":" + p.position.line.fold("")(_ + ":") + " " + p.message) + l(p.position.lineContent) + l("") + } + + var failed = 0 + analysis.infos.allInfos.foreach { case (k, i) => + i.reportedProblems foreach { p => + val deprecation = p.message.contains("is deprecated") + + if (!deprecation) { + failed = failed + 1 + } + + logProblem(if (deprecation) s.log.warn else s.log.error, k, p) + } + } + + if (failed > 0) { + sys.error(s"$failed fatal warnings") + } + analysis } ) @@ -149,26 +196,31 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExludedDependencies.settings)) + .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: Add Sql to mima checks - allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn).contains(x)).foreach { + allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, + networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } + /* Unsafe settings */ + enable(Unsafe.settings)(unsafe) + /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Enable Assembly for streamingMqtt test */ + enable(inConfig(Test)(Assembly.settings))(streamingMqtt) + + /* Package pyspark artifacts in a separate zip file for YARN. */ + enable(PySparkAssembly.settings)(assembly) + /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst macro settings */ - enable(Catalyst.settings)(catalyst) - /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -177,6 +229,36 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + + /** + * Adds the ability to run the spark shell directly from SBT without building an assembly + * jar. + * + * Usage: `build/sbt sparkShell` + */ + val sparkShell = taskKey[Unit]("start a spark-shell.") + val sparkSql = taskKey[Unit]("starts the spark sql CLI.") + + enable(Seq( + connectInput in run := true, + fork := true, + outputStrategy in run := Some (StdoutOutput), + + javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m"), + + sparkShell := { + (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value + }, + + javaOptions in Compile += "-Dspark.master=local", + + sparkSql := { + (runMain in Compile).toTask(" org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver").value + } + ))(assembly) + + enable(Seq(sparkShell := sparkShell in "assembly"))(spark) + // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -187,6 +269,13 @@ object SparkBuild extends PomBuild { } +object Unsafe { + lazy val settings = Seq( + // This option is needed to suppress warnings from sun.misc.Unsafe usage + javacOptions in Compile += "-XDignore.symbol.file" + ) +} + object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } @@ -195,7 +284,7 @@ object Flume { This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. */ -object ExludedDependencies { +object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } ) @@ -226,18 +315,12 @@ object OldDeps { ) } -object Catalyst { - lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - // Quasiquotes break compiling scala doc... - // TODO: Investigate fixing this. - sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) -} - object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -246,21 +329,24 @@ object SQL { |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution - |import org.apache.spark.sql.test.TestSQLContext._ + |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.types._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } object Hive { lazy val settings = Seq( - javaOptions += "-XX:MaxPermSize=1g", + javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => @@ -268,6 +354,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -276,10 +363,11 @@ object Hive { |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution + |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ - |import org.apache.spark.sql.types._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + |import org.apache.spark.sql.hive.test.TestHive.implicits._ + |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new @@ -290,6 +378,7 @@ object Hive { } object Assembly { + import sbtassembly.AssemblyUtils._ import sbtassembly.Plugin._ import AssemblyKeys._ @@ -302,13 +391,16 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" } }, + jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + s"${mName}-test-${v}.jar" + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -321,6 +413,55 @@ object Assembly { ) } +object PySparkAssembly { + import sbtassembly.Plugin._ + import AssemblyKeys._ + import java.util.zip.{ZipOutputStream, ZipEntry} + + lazy val settings = Seq( + // Use a resource generator to copy all .py files from python/pyspark into a managed directory + // to be included in the assembly. We can't just add "python/" to the assembly's resource dir + // list since that will copy unneeded / unwanted files. + resourceGenerators in Compile <+= resourceManaged in Compile map { outDir: File => + val src = new File(BuildCommons.sparkHome, "python/pyspark") + val zipFile = new File(BuildCommons.sparkHome , "python/lib/pyspark.zip") + zipFile.delete() + zipRecursive(src, zipFile) + Seq[File]() + } + ) + + private def zipRecursive(source: File, destZipFile: File) = { + val destOutput = new ZipOutputStream(new FileOutputStream(destZipFile)) + addFilesToZipStream("", source, destOutput) + destOutput.flush() + destOutput.close() + } + + private def addFilesToZipStream(parent: String, source: File, output: ZipOutputStream): Unit = { + if (source.isDirectory()) { + output.putNextEntry(new ZipEntry(parent + source.getName())) + for (file <- source.listFiles()) { + addFilesToZipStream(parent + source.getName() + File.separator, file, output) + } + } else { + val in = new FileInputStream(source) + output.putNextEntry(new ZipEntry(parent + source.getName())) + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + output.write(buf, 0, n) + } + } + output.closeEntry() + in.close() + } + } + +} + object Unidoc { import BuildCommons._ @@ -332,25 +473,39 @@ object Unidoc { names.map(s => "org.apache.spark." + s).mkString(":") } + private def ignoreUndocumentedPackages(packages: Seq[Seq[File]]): Seq[Seq[File]] = { + packages + .map(_.filterNot(_.getName.contains("$"))) + .map(_.filterNot(_.getCanonicalPath.contains("akka"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/unsafe"))) + .map(_.filterNot(_.getCanonicalPath.contains("python"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) + } + lazy val settings = scalaJavaUnidocSettings ++ Seq ( publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + + // Skip actual catalyst, but include the subproject. + // Catalyst is not public API and contains quasiquotes which break scaladoc. + unidocAllSources in (ScalaUnidoc, unidoc) := { + ignoreUndocumentedPackages((unidocAllSources in (ScalaUnidoc, unidoc)).value) + }, // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { - (unidocAllSources in (JavaUnidoc, unidoc)).value - .map(_.filterNot(_.getName.contains("$"))) - .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) - .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) + ignoreUndocumentedPackages((unidocAllSources in (JavaUnidoc, unidoc)).value) }, // Javadoc options: create a window title, and group key packages on index page @@ -369,11 +524,15 @@ object Unidoc { "mllib.tree.impurity", "mllib.tree.model", "mllib.util", "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", - "ml", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", "ml.tuning" + "ml", "ml.attribute", "ml.classification", "ml.clustering", "ml.evaluation", "ml.feature", + "ml.param", "ml.recommendation", "ml.regression", "ml.tuning" ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" - ) + ), + + // Group similar methods together based on the @group annotation. + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") ) } @@ -383,22 +542,28 @@ object TestSettings { lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, + // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes + // launched by the tests have access to the correct test-time classpath. + envVars in Test ++= Map( + "SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", + javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", - javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") + javaOptions in Test += "-Dderby.system.durability=test", + javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark")) .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", - javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" + javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, - // This places test scope jars on the classpath of executors during tests. - javaOptions in Test += - "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. - map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), @@ -407,6 +572,13 @@ object TestSettings { libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, + // Make sure the test temp directory exists. + resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => + if (!new File(testTempDir).isDirectory()) { + require(new File(testTempDir).mkdirs()) + } + Seq[File]() + }, concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), // Remove certain packages from Scaladoc scalacOptions in (Compile, doc) := Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905..51820460ca1a 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,16 +19,18 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") + libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da41..471d00bd8223 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala deleted file mode 100644 index 3d43c3529955..000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} - -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit - -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList - } - - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() - -} diff --git a/pylintrc b/pylintrc new file mode 100644 index 000000000000..6a675770da69 --- /dev/null +++ b/pylintrc @@ -0,0 +1,404 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=pyspark.heapq3 + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + +# Allow optimization of some AST trees. This will activate a peephole AST +# optimizer, which will apply various small optimizations. For instance, it can +# be used to obtain the result of joining multiple strings with the addition +# operator. Joining a lot of strings can lead to a maximum recursion error in +# Pylint and this flag can prevent that. It has one side effect, the resulting +# AST will be different than the one from reality. +optimize-ast=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" + +# These errors are arranged in order of number of warning given in pylint. +# If you would like to improve the code quality of pyspark, remove any of these disabled errors +# run ./dev/lint-python and see if the errors raised by pylint can be fixed. + +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions= + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=__.*__ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_$|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis +ignored-modules= + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/python/docs/conf.py b/python/docs/conf.py index b00dce95d65b..163987dd8e5f 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -48,16 +48,16 @@ # General information about the project. project = u'PySpark' -copyright = u'2014, Author' +copyright = u'' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.3-SNAPSHOT' +version = 'master' # The full version, including alpha/beta/rc tags. -release = '1.3-SNAPSHOT' +release = os.environ.get('RELEASE_VERSION', version) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -97,6 +97,10 @@ # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False +# -- Options for autodoc -------------------------------------------------- + +# Look at the first line of the docstring for function and method signatures. +autodoc_docstring_signature = True # -- Options for HTML output ---------------------------------------------- diff --git a/python/docs/index.rst b/python/docs/index.rst index d150de9d5c50..f7eede9c3c82 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.sql.SQLContext` + + Main entry point for DataFrame and SQL functionality. + + :class:`pyspark.sql.DataFrame` + + A distributed collection of data grouped into named columns. + Indices and tables ================== diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index f10d1339a9a8..86d4186a2c79 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -1,17 +1,22 @@ pyspark.ml package -===================== +================== -Submodules ----------- - -pyspark.ml module ------------------ +ML Pipeline APIs +---------------- .. automodule:: pyspark.ml :members: :undoc-members: :inherited-members: +pyspark.ml.param module +----------------------- + +.. automodule:: pyspark.ml.param + :members: + :undoc-members: + :inherited-members: + pyspark.ml.feature module ------------------------- @@ -27,3 +32,43 @@ pyspark.ml.classification module :members: :undoc-members: :inherited-members: + +pyspark.ml.clustering module +---------------------------- + +.. automodule:: pyspark.ml.clustering + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.recommendation module +-------------------------------- + +.. automodule:: pyspark.ml.recommendation + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.regression module +---------------------------- + +.. automodule:: pyspark.ml.regression + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.tuning module +------------------------ + +.. automodule:: pyspark.ml.tuning + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.evaluation module +---------------------------- + +.. automodule:: pyspark.ml.evaluation + :members: + :undoc-members: + :inherited-members: diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 4548b8739ed9..2d54ab118b94 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -1,16 +1,13 @@ pyspark.mllib package ===================== -Submodules ----------- - pyspark.mllib.classification module ----------------------------------- .. automodule:: pyspark.mllib.classification :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.clustering module ------------------------------- @@ -18,7 +15,13 @@ pyspark.mllib.clustering module .. automodule:: pyspark.mllib.clustering :members: :undoc-members: - :show-inheritance: + +pyspark.mllib.evaluation module +------------------------------- + +.. automodule:: pyspark.mllib.evaluation + :members: + :undoc-members: pyspark.mllib.feature module ------------------------------- @@ -28,6 +31,13 @@ pyspark.mllib.feature module :undoc-members: :show-inheritance: +pyspark.mllib.fpm module +------------------------ + +.. automodule:: pyspark.mllib.fpm + :members: + :undoc-members: + pyspark.mllib.linalg module --------------------------- @@ -36,13 +46,20 @@ pyspark.mllib.linalg module :undoc-members: :show-inheritance: +pyspark.mllib.linalg.distributed module +--------------------------------------- + +.. automodule:: pyspark.mllib.linalg.distributed + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.random module --------------------------- .. automodule:: pyspark.mllib.random :members: :undoc-members: - :show-inheritance: pyspark.mllib.recommendation module ----------------------------------- @@ -50,7 +67,6 @@ pyspark.mllib.recommendation module .. automodule:: pyspark.mllib.recommendation :members: :undoc-members: - :show-inheritance: pyspark.mllib.regression module ------------------------------- @@ -58,7 +74,7 @@ pyspark.mllib.regression module .. automodule:: pyspark.mllib.regression :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.stat module ------------------------- @@ -66,7 +82,6 @@ pyspark.mllib.stat module .. automodule:: pyspark.mllib.stat :members: :undoc-members: - :show-inheritance: pyspark.mllib.tree module ------------------------- @@ -74,7 +89,7 @@ pyspark.mllib.tree module .. automodule:: pyspark.mllib.tree :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.util module ------------------------- @@ -82,4 +97,3 @@ pyspark.mllib.util module .. automodule:: pyspark.mllib.util :members: :undoc-members: - :show-inheritance: diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 65b3650ae10a..6259379ed05b 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -1,10 +1,23 @@ pyspark.sql module ================== -Module contents ---------------- +Module Context +-------------- .. automodule:: pyspark.sql :members: :undoc-members: - :show-inheritance: + + +pyspark.sql.types module +------------------------ +.. automodule:: pyspark.sql.types + :members: + :undoc-members: + + +pyspark.sql.functions module +---------------------------- +.. automodule:: pyspark.sql.functions + :members: + :undoc-members: diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index f08185627d0b..50822c93faba 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -8,3 +8,10 @@ Module contents :members: :undoc-members: :show-inheritance: + +pyspark.streaming.kafka module +------------------------------ +.. automodule:: pyspark.streaming.kafka + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index d3efcdf221d8..5f70ac6ed8fe 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -22,17 +22,17 @@ - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - :class:`RDD`: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - :class:`Broadcast`: A broadcast variable that gets reused across tasks. - - L{Accumulator} + - :class:`Accumulator`: An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - :class:`SparkConf`: For configuring Spark. - - L{SparkFiles} + - :class:`SparkFiles`: Access files shipped with jobs. - - L{StorageLevel} + - :class:`StorageLevel`: Finer-grained cache persistence levels. """ @@ -45,6 +45,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler # for back compatibility @@ -53,5 +54,5 @@ __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", - "Profiler", "BasicProfiler", + "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8..6ef8cf53cc74 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -54,7 +54,7 @@ ... def zero(self, value): ... return [0.0] * len(value) ... def addInPlace(self, val1, val2): -... for i in xrange(len(val1)): +... for i in range(len(val1)): ... val1[i] += val2[i] ... return val1 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) @@ -83,12 +83,16 @@ >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... -Exception:... +TypeError:... """ +import sys import select import struct -import SocketServer +if sys.version < '3': + import SocketServer +else: + import socketserver as SocketServer import threading from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer @@ -247,6 +251,7 @@ class AccumulatorServer(SocketServer.TCPServer): def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + self.server_close() def _start_update_server(): @@ -256,3 +261,9 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 6b8a8b256a89..663c9abe0881 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -16,10 +16,15 @@ # import os -import cPickle +import sys import gc from tempfile import NamedTemporaryFile +if sys.version < '3': + import cPickle as pickle +else: + import pickle + unicode = str __all__ = ['Broadcast'] @@ -70,33 +75,19 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None): self._path = path def dump(self, value, f): - if isinstance(value, basestring): - if isinstance(value, unicode): - f.write('U') - value = value.encode('utf8') - else: - f.write('S') - f.write(value) - else: - f.write('P') - cPickle.dump(value, f, 2) + pickle.dump(value, f, 2) f.close() return f.name def load(self, path): with open(path, 'rb', 1 << 20) as f: - flag = f.read(1) - data = f.read() - if flag == 'P': - # cPickle.loads() may create lots of objects, disable GC - # temporary for better performance - gc.disable() - try: - return cPickle.loads(data) - finally: - gc.enable() - else: - return data.decode('utf8') if flag == 'U' else data + # pickle.load() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return pickle.load(f) + finally: + gc.enable() @property def value(self): @@ -124,4 +115,6 @@ def __reduce__(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index bb0783555aa7..3b647985801b 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -40,164 +40,126 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ - +from __future__ import print_function import operator import os +import io import pickle import struct import sys import types from functools import partial import itertools -from copy_reg import _extension_registry, _inverted_registry, _extension_cache -import new import dis import traceback -import platform - -PyImp = platform.python_implementation() - -import logging -cloudLog = logging.getLogger("Cloud.Transport") +if sys.version < '3': + from pickle import Pickler + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + PY3 = False +else: + types.ClassType = type + from pickle import _Pickler as Pickler + from io import BytesIO as StringIO + PY3 = True #relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +STORE_GLOBAL = dis.opname.index('STORE_GLOBAL') +DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL') +LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL') GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] +HAVE_ARGUMENT = dis.HAVE_ARGUMENT +EXTENDED_ARG = dis.EXTENDED_ARG -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) -if PyImp == "PyPy": - # register builtin type in `new` - new.method = types.MethodType - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO - -# These helper functions were copied from PiCloud's util module. def islambda(func): - return getattr(func,'func_name') == '' + return getattr(func,'__name__') == '' -def xrange_params(xrangeobj): - """Returns a 3 element tuple describing the xrange start, step, and len - respectively - Note: Only guarentees that elements of xrange are the same. parameters may - be different. - e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same - though w/ iteration - """ +_BUILTIN_TYPE_NAMES = {} +for k, v in types.__dict__.items(): + if type(v) is type: + _BUILTIN_TYPE_NAMES[v] = k - xrange_len = len(xrangeobj) - if not xrange_len: #empty - return (0,1,0) - start = xrangeobj[0] - if xrange_len == 1: #one element - return start, 1, 1 - return (start, xrangeobj[1] - xrangeobj[0], xrange_len) -#debug variables intended for developer use: -printSerialization = False -printMemoization = False +def _builtin_type(name): + return getattr(types, name) -useForcedImports = True #Should I use forced imports for tracking? +class CloudPickler(Pickler): + dispatch = Pickler.dispatch.copy() -class CloudPickler(pickle.Pickler): - - dispatch = pickle.Pickler.dispatch.copy() - savedForceImports = False - savedDjangoEnv = False #hack tro transport django environment - - def __init__(self, file, protocol=None, min_size_to_save= 0): - pickle.Pickler.__init__(self,file,protocol) - self.modules = set() #set of modules needed to depickle - self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + def __init__(self, file, protocol=None): + Pickler.__init__(self, file, protocol) + # set of modules to unpickle + self.modules = set() + # map ids to dictionary. used to ensure that functions can share global env + self.globals_ref = {} def dump(self, obj): - # note: not thread safe - # minimal side-effects, so not fixing - recurse_limit = 3000 - base_recurse = sys.getrecursionlimit() - if base_recurse < recurse_limit: - sys.setrecursionlimit(recurse_limit) self.inject_addons() try: - return pickle.Pickler.dump(self, obj) - except RuntimeError, e: + return Pickler.dump(self, obj) + except RuntimeError as e: if 'recursion' in e.args[0]: - msg = """Could not pickle object as excessively deep recursion required. - Try _fast_serialization=2 or contact PiCloud support""" + msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - finally: - new_recurse = sys.getrecursionlimit() - if new_recurse == recurse_limit: - sys.setrecursionlimit(base_recurse) + + def save_memoryview(self, obj): + """Fallback to save_string""" + Pickler.save_string(self, str(obj)) def save_buffer(self, obj): """Fallback to save_string""" - pickle.Pickler.save_string(self,str(obj)) - dispatch[buffer] = save_buffer + Pickler.save_string(self,str(obj)) + if PY3: + dispatch[memoryview] = save_memoryview + else: + dispatch[buffer] = save_buffer - #block broken objects - def save_unsupported(self, obj, pack=None): + def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) dispatch[types.GeneratorType] = save_unsupported - #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it - try: - slice(0,1).__reduce__() - except TypeError: #can't pickle - - dispatch[slice] = save_unsupported - - #itertools objects do not pickle! + # itertools objects do not pickle! for v in itertools.__dict__.values(): if type(v) is type: dispatch[v] = save_unsupported - - def save_dict(self, obj): - """hack fix - If the dict is a global, deal with it in a special way - """ - #print 'saving', obj - if obj is __builtins__: - self.save_reduce(_get_module_builtins, (), obj=obj) - else: - pickle.Pickler.save_dict(self, obj) - dispatch[pickle.DictionaryType] = save_dict - - - def save_module(self, obj, pack=struct.pack): + def save_module(self, obj): """ Save a module as an import """ - #print 'try save import', obj.__name__ self.modules.add(obj) - self.save_reduce(subimport,(obj.__name__,), obj=obj) - dispatch[types.ModuleType] = save_module #new type + self.save_reduce(subimport, (obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module - def save_codeobject(self, obj, pack=struct.pack): + def save_codeobject(self, obj): """ Save a code object """ - #print 'try to save codeobj: ', obj - args = ( - obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, - obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars - ) + if PY3: + args = ( + obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, + obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, + obj.co_cellvars + ) + else: + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) self.save_reduce(types.CodeType, args, obj=obj) - dispatch[types.CodeType] = save_codeobject #new type + dispatch[types.CodeType] = save_codeobject - def save_function(self, obj, name=None, pack=struct.pack): + def save_function(self, obj, name=None): """ Registered with the dispatch to handle all function types. Determines what kind of function obj is (e.g. lambda, defined at @@ -205,12 +167,14 @@ def save_function(self, obj, name=None, pack=struct.pack): """ write = self.write - name = obj.__name__ + if name is None: + name = obj.__name__ modname = pickle.whichmodule(obj, name) - #print 'which gives %s %s %s' % (modname, obj, name) + # print('which gives %s %s %s' % (modname, obj, name)) try: themodule = sys.modules[modname] - except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + except KeyError: + # eval'd items such as namedtuple give invalid items for their function __module__ modname = '__main__' if modname == '__main__': @@ -221,37 +185,18 @@ def save_function(self, obj, name=None, pack=struct.pack): if getattr(themodule, name, None) is obj: return self.save_global(obj, name) - if not self.savedDjangoEnv: - #hack for django - if we detect the settings module, we transport it - django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') - if django_settings: - django_mod = sys.modules.get(django_settings) - if django_mod: - cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) - self.savedDjangoEnv = True - self.modules.add(django_mod) - write(pickle.MARK) - self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) - write(pickle.POP_MARK) - - # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule is None: - #Force server to import modules that have been imported in main - modList = None - if themodule is None and not self.savedForceImports: - mainmod = sys.modules['__main__'] - if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): - modList = list(mainmod.___pyc_forcedImports__) - self.savedForceImports = True - self.save_function_tuple(obj, modList) + if islambda(obj) or obj.__code__.co_filename == '' or themodule is None: + #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) + self.save_function_tuple(obj) return - else: # func is nested + else: + # func is nested klass = getattr(themodule, name, None) if klass is None or klass is not obj: - self.save_function_tuple(obj, [themodule]) + self.save_function_tuple(obj) return if obj.__dict__: @@ -266,7 +211,7 @@ def save_function(self, obj, name=None, pack=struct.pack): self.memoize(obj) dispatch[types.FunctionType] = save_function - def save_function_tuple(self, func, forced_imports): + def save_function_tuple(self, func): """ Pickles an actual func object. A func comprises: code, globals, defaults, closure, and dict. We @@ -281,19 +226,6 @@ def save_function_tuple(self, func, forced_imports): save = self.save write = self.write - # save the modules (if any) - if forced_imports: - write(pickle.MARK) - save(_modules_to_main) - #print 'forced imports are', forced_imports - - forced_names = map(lambda m: m.__name__, forced_imports) - save((forced_names,)) - - #save((forced_imports,)) - write(pickle.REDUCE) - write(pickle.POP_MARK) - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater @@ -318,6 +250,8 @@ def extract_code_globals(co): Find all globals names read or written to by codeblock co """ code = co.co_code + if not PY3: + code = [ord(c) for c in code] names = co.co_names out_names = set() @@ -327,18 +261,18 @@ def extract_code_globals(co): while i < n: op = code[i] - i = i+1 + i += 1 if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + oparg = code[i] + code[i+1] * 256 + extended_arg extended_arg = 0 - i = i+2 + i += 2 if op == EXTENDED_ARG: - extended_arg = oparg*65536L + extended_arg = oparg*65536 if op in GLOBAL_OPS: out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - if co.co_consts: # see if nested function have any global refs + # see if nested function have any global refs + if co.co_consts: for const in co.co_consts: if type(const) is types.CodeType: out_names |= CloudPickler.extract_code_globals(const) @@ -350,46 +284,28 @@ def extract_func_data(self, func): Turn the function into a tuple of data necessary to recreate it: code, globals, defaults, closure, dict """ - code = func.func_code + code = func.__code__ # extract all global ref's - func_global_refs = CloudPickler.extract_code_globals(code) + func_global_refs = self.extract_code_globals(code) # process all variables referenced by global environment f_globals = {} for var in func_global_refs: - #Some names, such as class functions are not global - we don't need them - if func.func_globals.has_key(var): - f_globals[var] = func.func_globals[var] + if var in func.__globals__: + f_globals[var] = func.__globals__[var] # defaults requires no processing - defaults = func.func_defaults - - def get_contents(cell): - try: - return cell.cell_contents - except ValueError, e: #cell is empty error on not yet assigned - raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') - + defaults = func.__defaults__ # process closure - if func.func_closure: - closure = map(get_contents, func.func_closure) - else: - closure = [] + closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] # save the dict - dct = func.func_dict - - if printSerialization: - outvars = ['code: ' + str(code) ] - outvars.append('globals: ' + str(f_globals)) - outvars.append('defaults: ' + str(defaults)) - outvars.append('closure: ' + str(closure)) - print 'function ', func, 'is extracted to: ', ', '.join(outvars) + dct = func.__dict__ - base_globals = self.globals_ref.get(id(func.func_globals), {}) - self.globals_ref[id(func.func_globals)] = base_globals + base_globals = self.globals_ref.get(id(func.__globals__), {}) + self.globals_ref[id(func.__globals__)] = base_globals return (code, f_globals, defaults, closure, dct, base_globals) @@ -400,8 +316,9 @@ def save_builtin_function(self, obj): dispatch[types.BuiltinFunctionType] = save_builtin_function def save_global(self, obj, name=None, pack=struct.pack): - write = self.write - memo = self.memo + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) if name is None: name = obj.__name__ @@ -410,98 +327,76 @@ def save_global(self, obj, name=None, pack=struct.pack): if modname is None: modname = pickle.whichmodule(obj, name) - try: - __import__(modname) - themodule = sys.modules[modname] - except (ImportError, KeyError, AttributeError): #should never occur - raise pickle.PicklingError( - "Can't pickle %r: Module %s cannot be found" % - (obj, modname)) - if modname == '__main__': themodule = None - - if themodule: + else: + __import__(modname) + themodule = sys.modules[modname] self.modules.add(themodule) - sendRef = True - typ = type(obj) - #print 'saving', obj, typ - try: - try: #Deal with case when getattribute fails with exceptions - klass = getattr(themodule, name) - except (AttributeError): - if modname == '__builtin__': #new.* are misrepeported - modname = 'new' - __import__(modname) - themodule = sys.modules[modname] - try: - klass = getattr(themodule, name) - except AttributeError, a: - # print themodule, name, obj, type(obj) - raise pickle.PicklingError("Can't pickle builtin %s" % obj) - else: - raise + if hasattr(themodule, name) and getattr(themodule, name) is obj: + return Pickler.save_global(self, obj, name) - except (ImportError, KeyError, AttributeError): - if typ == types.TypeType or typ == types.ClassType: - sendRef = False - else: #we can't deal with this - raise - else: - if klass is not obj and (typ == types.TypeType or typ == types.ClassType): - sendRef = False - if not sendRef: - #note: Third party types might crash this - add better checks! - d = dict(obj.__dict__) #copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties - d.pop('__dict__',None) - d.pop('__weakref__',None) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + d = dict(obj.__dict__) # copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): + # don't extract dict that are properties + d.pop('__dict__', None) + d.pop('__weakref__', None) # hack as __new__ is stored differently in the __dict__ new_override = d.get('__new__', None) if new_override: d['__new__'] = obj.__new__ - self.save_reduce(type(obj),(obj.__name__,obj.__bases__, - d),obj=obj) - #print 'internal reduce dask %s %s' % (obj, d) - return + self.save(_load_class) + self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) + d.pop('__doc__', None) + # handle property and staticmethod + dd = {} + for k, v in d.items(): + if isinstance(v, property): + k = ('property', k) + v = (v.fget, v.fset, v.fdel, v.__doc__) + elif isinstance(v, staticmethod) and hasattr(v, '__func__'): + k = ('staticmethod', k) + v = v.__func__ + elif isinstance(v, classmethod) and hasattr(v, '__func__'): + k = ('classmethod', k) + v = v.__func__ + dd[k] = v + self.save(dd) + self.write(pickle.TUPLE2) + self.write(pickle.REDUCE) - if self.proto >= 2: - code = _extension_registry.get((modname, name)) - if code: - assert code > 0 - if code <= 0xff: - write(pickle.EXT1 + chr(code)) - elif code <= 0xffff: - write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) - else: - write(pickle.EXT4 + pack("> sys.stderr, 'Cloud not import django settings %s:' % (name) - print_exec(sys.stderr) - if modified_env: - del os.environ['DJANGO_SETTINGS_MODULE'] - else: - #add project directory to sys,path: - if hasattr(module,'__file__'): - dirname = os.path.split(module.__file__)[0] + '/' - sys.path.append(dirname) # restores function attributes def _restore_attr(obj, attr): @@ -851,13 +655,16 @@ def _restore_attr(obj, attr): setattr(obj, key, val) return obj + def _get_module_builtins(): return pickle.__builtins__ + def print_exec(stream): ei = sys.exc_info() traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + def _modules_to_main(modList): """Force every module in modList to be placed into main""" if not modList: @@ -868,22 +675,16 @@ def _modules_to_main(modList): if type(modname) is str: try: mod = __import__(modname) - except Exception, i: #catch all... - sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ -A version mismatch is likely. Specific error was:\n' % modname) + except Exception as e: + sys.stderr.write('warning: could not import %s\n. ' + 'Your function may unexpectedly error due to this import failing;' + 'A version mismatch is likely. Specific error was:\n' % modname) print_exec(sys.stderr) else: - setattr(main,mod.__name__, mod) - else: - #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) - #In old version actual module was sent - setattr(main,modname.__name__, modname) + setattr(main, mod.__name__, mod) -#object generators: -def _build_xrange(start, step, len): - """Built xrange explicitly""" - return xrange(start, start + step*len, step) +#object generators: def _genpartial(func, args, kwds): if not args: args = () @@ -891,22 +692,26 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) + def _fill_function(func, globals, defaults, dict): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ - func.func_globals.update(globals) - func.func_defaults = defaults - func.func_dict = dict + func.__globals__.update(globals) + func.__defaults__ = defaults + func.__dict__ = dict return func + def _make_cell(value): - return (lambda: value).func_closure[0] + return (lambda: value).__closure__[0] + def _reconstruct_closure(values): return tuple([_make_cell(v) for v in values]) + def _make_skel_func(code, closures, base_globals = None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other @@ -922,46 +727,26 @@ def _make_skel_func(code, closures, base_globals = None): None, None, closure) +def _load_class(cls, d): + """ + Loads additional properties into class `cls`. + """ + for k, v in d.items(): + if isinstance(k, tuple): + typ, k = k + if typ == 'property': + v = property(*v) + elif typ == 'staticmethod': + v = staticmethod(v) + elif typ == 'classmethod': + v = classmethod(v) + setattr(cls, k, v) + return cls + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] - -def _generateImage(size, mode, str_rep): - """Generate image from string representation""" - import Image - i = Image.new(mode, size) - i.fromstring(str_rep) - return i - -def _lazyloadImage(fp): - import Image - fp.seek(0) #works in almost any case - return Image.open(fp) - -"""Timeseries""" -def _genTimeSeries(reduce_args, state): - import scikits.timeseries.tseries as ts - from numpy import ndarray - from numpy.ma import MaskedArray - - - time_series = ts._tsreconstruct(*reduce_args) - - #from setstate modified - (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state - #print 'regenerating %s' % dtyp - - MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) - _dates = time_series._dates - #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ - ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) - _dates.freq = frq - _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, - toobj=None, toord=None, tostr=None)) - # Update the _optinfo dictionary - time_series._optinfo.update(infodict) - return time_series - diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index dc7cd0bce56f..924da3eecf21 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -44,7 +44,7 @@ >>> conf.get("spark.executorEnv.VAR1") u'value1' ->>> print conf.toDebugString() +>>> print(conf.toDebugString()) spark.executorEnv.VAR1=value1 spark.executorEnv.VAR3=value3 spark.executorEnv.VAR4=value4 @@ -56,6 +56,13 @@ __all__ = ['SparkConf'] +import sys +import re + +if sys.version > '3': + unicode = str + __doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__) + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index bf1f61c8504e..1b2a52ad6411 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import os import shutil import sys @@ -30,11 +32,13 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler -from py4j.java_collections import ListConverter +if sys.version > '3': + xrange = range __all__ = ['SparkContext'] @@ -58,12 +62,13 @@ class SparkContext(object): _gateway = None _jvm = None - _writeToFile = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') + def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler): @@ -131,7 +136,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if sparkHome: self._conf.setSparkHome(sparkHome) if environment: - for key, value in environment.iteritems(): + for key, value in environment.items(): self._conf.setExecutorEnv(key, value) for key, value in DEFAULT_CONFIGS.items(): self._conf.setIfMissing(key, value) @@ -147,10 +152,19 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.master = self._conf.get("spark.master") self.appName = self._conf.get("spark.app.name") self.sparkHome = self._conf.get("spark.home", None) + + # Let YARN know it's a pyspark app, so it distributes needed libraries. + if self.master == "yarn-client": + self._conf.set("spark.yarn.isPython", "true") + for (k, v) in self._conf.getAll(): if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] self.environment[varName] = v + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + # disable randomness of hash of string in worker, if this is not + # launched by spark-submit + self.environment["PYTHONHASHSEED"] = "0" # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) @@ -164,6 +178,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') + self.pythonVer = "%d.%d" % sys.version_info[:2] # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -185,7 +200,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) @@ -218,7 +233,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: if (SparkContext._active_spark_context and @@ -259,6 +273,13 @@ def __exit__(self, type, value, trace): """ self.stop() + def setLogLevel(self, logLevel): + """ + Control our logLevel. This overrides any user-defined log settings. + Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + """ + self._jsc.setLogLevel(logLevel) + @classmethod def setSystemProperty(cls, key, value): """ @@ -275,6 +296,26 @@ def version(self): """ return self._jsc.version() + @property + @ignore_unicode_prefix + def applicationId(self): + """ + A unique identifier for the Spark application. + Its format depends on the scheduler implementation. + + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + + >>> sc.applicationId # doctest: +ELLIPSIS + u'local-...' + """ + return self._jsc.sc().applicationId() + + @property + def startTime(self): + """Return the epoch time when the Spark Context was started.""" + return self._jsc.startTime() + @property def defaultParallelism(self): """ @@ -303,6 +344,38 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None + def emptyRDD(self): + """ + Create an RDD that has no partitions or elements. + """ + return RDD(self._jsc.emptyRDD(), self, NoOpSerializer()) + + def range(self, start, end=None, step=1, numSlices=None): + """ + Create a new RDD of int containing elements from `start` to `end` + (exclusive), increased by `step` every element. Can be called the same + way as python's built-in range() function. If called with a single argument, + the argument is interpreted as `end`, and `start` is set to 0. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numSlices: the number of partitions of the new RDD + :return: An RDD of int + + >>> sc.range(5).collect() + [0, 1, 2, 3, 4] + >>> sc.range(2, 4).collect() + [2, 3] + >>> sc.range(1, 7, 2).collect() + [1, 3, 5] + """ + if end is None: + end = start + start = 0 + + return self.parallelize(xrange(start, end, step), numSlices) + def parallelize(self, c, numSlices=None): """ Distribute a local Python collection to form an RDD. Using xrange @@ -322,7 +395,7 @@ def parallelize(self, c, numSlices=None): start0 = c[0] def getStart(split): - return start0 + (split * size / numSlices) * step + return start0 + int((split * size / numSlices)) * step def f(split, iterator): return xrange(getStart(split), getStart(split + 1), step) @@ -356,6 +429,7 @@ def pickleFile(self, name, minPartitions=None): minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.objectFile(name, minPartitions), self) + @ignore_unicode_prefix def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all @@ -368,7 +442,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello world!") + ... _ = testFile.write("Hello world!") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello world!'] @@ -377,6 +451,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode)) + @ignore_unicode_prefix def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system @@ -410,9 +485,9 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): >>> dirPath = os.path.join(tempdir, "files") >>> os.mkdir(dirPath) >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1: - ... file1.write("1") + ... _ = file1.write("1") >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2: - ... file2.write("2") + ... _ = file2.write("2") >>> textFiles = sc.wholeTextFiles(dirPath) >>> sorted(textFiles.collect()) [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] @@ -455,7 +530,7 @@ def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: d = {} - for k, v in d.iteritems(): + for k, v in d.items(): jm[k] = v return jm @@ -607,6 +682,7 @@ def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) + @ignore_unicode_prefix def union(self, rdds): """ Build the union of a list of RDDs. @@ -617,7 +693,7 @@ def union(self, rdds): >>> path = os.path.join(tempdir, "union-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello") + ... _ = testFile.write("Hello") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello'] @@ -630,7 +706,6 @@ def union(self, rdds): rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self._gateway._gateway_client) return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): @@ -658,7 +733,7 @@ def accumulator(self, value, accum_param=None): elif isinstance(value, complex): accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM else: - raise Exception("No default accumulator param for type %s" % type(value)) + raise TypeError("No default accumulator param for type %s" % type(value)) SparkContext._next_accum_id += 1 return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) @@ -676,7 +751,7 @@ def addFile(self, path): >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: - ... testFile.write("100") + ... _ = testFile.write("100") >>> sc.addFile(path) >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: @@ -704,11 +779,13 @@ def addPyFile(self, path): """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix - - if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() def setCheckpointDir(self, dirName): """ @@ -743,7 +820,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): The application can use L{SparkContext.cancelJobGroup} to cancel all running jobs in this group. - >>> import thread, threading + >>> import threading >>> from time import sleep >>> result = "Not Set" >>> lock = threading.Lock() @@ -762,10 +839,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") >>> supress = lock.acquire() - >>> supress = thread.start_new_thread(start_job, (10,)) - >>> supress = thread.start_new_thread(stop_job, tuple()) + >>> supress = threading.Thread(target=start_job, args=(10,)).start() + >>> supress = threading.Thread(target=stop_job).start() >>> supress = lock.acquire() - >>> print result + >>> print(result) Cancelled If interruptOnCancel is set to true for the job group, then job cancellation will result @@ -808,6 +885,12 @@ def cancelAllJobs(self): """ self._jsc.sc().cancelAllJobs() + def statusTracker(self): + """ + Return :class:`StatusTracker` object + """ + return StatusTracker(self._jsc.statusTracker()) + def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ Executes the given partitionFunc on the specified set of partitions, @@ -825,14 +908,13 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ if partitions is None: partitions = range(rdd._jrdd.partitions().size()) - javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) - return list(mappedRDD._collect_iterator_through_file(it)) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index f09587f21170..7f06d4288c87 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -24,9 +24,10 @@ import traceback import time import gc -from errno import EINTR, ECHILD, EAGAIN +from errno import EINTR, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT + from pyspark.worker import main as worker_main from pyspark.serializers import read_int, write_int @@ -53,29 +54,21 @@ def worker(sock): # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. - infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) exit_code = 0 try: worker_main(infile, outfile) except SystemExit as exc: exit_code = compute_real_exit_code(exc.code) finally: - outfile.flush() + try: + outfile.flush() + except Exception: + pass return exit_code -# Cleanup zombie children -def cleanup_dead_children(): - try: - while True: - pid, _ = os.waitpid(0, os.WNOHANG) - if not pid: - break - except: - pass - - def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -85,8 +78,12 @@ def manager(): listen_sock.bind(('127.0.0.1', 0)) listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() - write_int(listen_port, sys.stdout) - sys.stdout.flush() + + # re-open stdin/stdout in 'wb' mode + stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4) + stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4) + write_int(listen_port, stdout_bin) + stdout_bin.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -98,6 +95,7 @@ def handle_sigterm(*args): shutdown(1) signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + signal.signal(SIGCHLD, SIG_IGN) reuse = os.environ.get("SPARK_REUSE_WORKER") @@ -112,12 +110,9 @@ def handle_sigterm(*args): else: raise - # cleanup in signal handler will cause deadlock - cleanup_dead_children() - if 0 in ready_fds: try: - worker_pid = read_int(sys.stdin) + worker_pid = read_int(stdin_bin) except EOFError: # Spark told us to exit by closing stdin shutdown(0) @@ -142,7 +137,7 @@ def handle_sigterm(*args): time.sleep(1) pid = os.fork() # error here will shutdown daemon else: - outfile = sock.makefile('w') + outfile = sock.makefile(mode='wb') write_int(e.errno, outfile) # Signal that the fork failed outfile.flush() outfile.close() @@ -154,7 +149,7 @@ def handle_sigterm(*args): listen_sock.close() try: # Acknowledge that the fork was successful - outfile = sock.makefile("w") + outfile = sock.makefile(mode="wb") write_int(os.getpid(), outfile) outfile.flush() outfile.close() diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index bc441f138f7f..b27e91a4cc25 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -627,51 +627,49 @@ def merge(iterables, key=None, reverse=False): if key is None: for order, it in enumerate(map(iter, iterables)): try: - next = it.next - h_append([next(), order * direction, next]) + h_append([next(it), order * direction, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - value, order, next = s = h[0] + value, order, it = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted + s[0] = next(it) # raises StopIteration when exhausted _heapreplace(h, s) # restore heap condition except StopIteration: _heappop(h) # remove empty iterator if h: # fast case when only a single iterator remains - value, order, next = h[0] + value, order, it = h[0] yield value - for value in next.__self__: + for value in it: yield value return for order, it in enumerate(map(iter, iterables)): try: - next = it.next - value = next() - h_append([key(value), order * direction, value, next]) + value = next(it) + h_append([key(value), order * direction, value, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - key_value, order, value, next = s = h[0] + key_value, order, value, it = s = h[0] yield value - value = next() + value = next(it) s[0] = key(value) s[2] = value _heapreplace(h, s) except StopIteration: _heappop(h) if h: - key_value, order, value, next = h[0] + key_value, order, value, it = h[0] yield value - for value in next.__self__: + for value in it: yield value @@ -885,6 +883,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": - import doctest - print(doctest.testmod()) + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a0a028446d5f..cd4c55f79f18 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -18,59 +18,80 @@ import atexit import os import sys +import select import signal import shlex +import socket import platform from subprocess import Popen, PIPE -from threading import Thread + +if sys.version >= '3': + xrange = range + from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_collections import ListConverter +from pyspark.serializers import read_int -def launch_gateway(): - SPARK_HOME = os.environ["SPARK_HOME"] - gateway_port = -1 +# patching ListConverter, or it will convert bytearray into Java ArrayList +def can_convert_list(self, obj): + return isinstance(obj, (list, tuple, xrange)) + +ListConverter.can_convert = can_convert_list + + +def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: + SPARK_HOME = os.environ["SPARK_HOME"] # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" - submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS") - submit_args = submit_args if submit_args is not None else "" - submit_args = shlex.split(submit_args) - command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] + submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + if os.environ.get("SPARK_TESTING"): + submit_args = ' '.join([ + "--conf spark.ui.enabled=false", + submit_args + ]) + command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) + + # Start a socket that will be used by PythonGatewayServer to communicate its port to us + callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + callback_socket.bind(('127.0.0.1', 0)) + callback_socket.listen(1) + callback_host, callback_port = callback_socket.getsockname() + env = dict(os.environ) + env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host + env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - env = dict(os.environ) - env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) - - try: - # Determine which ephemeral port the server started on: - gateway_port = proc.stdout.readline() - gateway_port = int(gateway_port) - except ValueError: - # Grab the remaining lines of stdout - (stdout, _) = proc.communicate() - exit_code = proc.poll() - error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n" - error_msg += "Warning: Expected GatewayServer to output a port, but found " - if gateway_port == "" and stdout == "": - error_msg += "no output.\n" - else: - error_msg += "the following:\n\n" - error_msg += "--------------------------------------------------------------\n" - error_msg += gateway_port + stdout - error_msg += "--------------------------------------------------------------\n" - raise Exception(error_msg) + proc = Popen(command, stdin=PIPE, env=env) + + gateway_port = None + # We use select() here in order to avoid blocking indefinitely if the subprocess dies + # before connecting + while gateway_port is None and proc.poll() is None: + timeout = 1 # (seconds) + readable, _, _ = select.select([callback_socket], [], [], timeout) + if callback_socket in readable: + gateway_connection = callback_socket.accept()[0] + # Determine which ephemeral port the server started on: + gateway_port = read_int(gateway_connection.makefile(mode="rb")) + gateway_connection.close() + callback_socket.close() + if gateway_port is None: + raise Exception("Java gateway process exited before sending the driver its port number") # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -88,23 +109,8 @@ def killChild(): Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)]) atexit.register(killChild) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream - - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() - # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/join.py b/python/pyspark/join.py index b4a844713745..94df3990164d 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -32,11 +32,12 @@ """ from pyspark.resultiterable import ResultIterable +from functools import reduce def _do_python_join(rdd, other, numPartitions, dispatch): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) + vs = rdd.mapValues(lambda v: (1, v)) + ws = other.mapValues(lambda v: (2, v)) return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__())) @@ -48,7 +49,7 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -62,7 +63,7 @@ def dispatch(seq): wbuf.append(v) if not vbuf: vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -76,7 +77,7 @@ def dispatch(seq): wbuf.append(v) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -98,14 +99,15 @@ def dispatch(seq): def python_cogroup(rdds, numPartitions): def make_mapper(i): - return lambda (k, v): (k, (i, v)) - vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + return lambda v: (i, v) + vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)] union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) rdd_len = len(vrdds) def dispatch(seq): - bufs = [[] for i in range(rdd_len)] - for (n, v) in seq: + bufs = [[] for _ in range(rdd_len)] + for n, v in seq: bufs[n].append(v) - return tuple(map(ResultIterable, bufs)) + return tuple(ResultIterable(vs) for vs in bufs) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 47fed80f42e1..327a11b14b5a 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -15,7 +15,6 @@ # limitations under the License. # -from pyspark.ml.param import * -from pyspark.ml.pipeline import * +from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel -__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] +__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6bd2aa8e4783..83f808efc3bf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,48 +15,798 @@ # limitations under the License. # -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel -from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ - HasRegParam +from pyspark.ml.param.shared import * +from pyspark.ml.regression import ( + RandomForestParams, DecisionTreeModel, TreeEnsembleModels) +from pyspark.mllib.common import inherit_doc -__all__ = ['LogisticRegression', 'LogisticRegressionModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', + 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', + 'NaiveBayesModel'] @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): """ Logistic regression. + Currently, this class only supports binary classification. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors - >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \ - Row(label=1.0, features=Vectors.dense(1.0)), \ - Row(label=0.0, features=Vectors.sparse(1, [], []))])) - >>> lr = LogisticRegression() \ - .setMaxIter(5) \ - .setRegParam(0.01) - >>> model = lr.fit(dataset) - >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))])) - >>> print model.transform(test0).head().prediction + >>> df = sc.parallelize([ + ... Row(label=1.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> model = lr.fit(df) + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + >>> result = model.transform(test0).head() + >>> result.prediction 0.0 - >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))])) - >>> print model.transform(test1).head().prediction + >>> result.probability + DenseVector([0.99..., 0.00...]) + >>> result.rawPrediction + DenseVector([8.22..., -8.22...]) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction 1.0 + >>> lr.setParams("vector") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.classification.LogisticRegression" + + # a placeholder to make it appear in the generated doc + elasticNetParam = \ + Param(Params._dummy(), "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + thresholds = Param(Params._dummy(), "thresholds", + "Thresholds in multi-class classification" + + " to adjust the probability of predicting each class." + + " Array must have length equal to the number of classes, with values >= 0." + + " The class with largest value p/t is predicted, where p is the original" + + " probability of that class and t is the class' threshold.") + threshold = Param(Params._dummy(), "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, thresholds=None, + probabilityCol="probability", rawPredictionCol="rawPrediction"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + threshold=0.5, thresholds=None, \ + probabilityCol="probability", rawPredictionCol="rawPrediction") + If the threshold and thresholds Params are both set, they must be equivalent. + """ + super(LogisticRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.LogisticRegression", self.uid) + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty + # is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = \ + Param(self, "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + #: param for threshold in binary classification, in range [0, 1]. + self.threshold = Param(self, "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") + #: param for thresholds or cutoffs in binary or multiclass classification + self.thresholds = \ + Param(self, "thresholds", + "Thresholds in multi-class classification" + + " to adjust the probability of predicting each class." + + " Array must have length equal to the number of classes, with values >= 0." + + " The class with largest value p/t is predicted, where p is the original" + + " probability of that class and t is the class' threshold.") + self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, + fitIntercept=True, threshold=0.5) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + self._checkThresholdConsistency() + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, thresholds=None, + probabilityCol="probability", rawPredictionCol="rawPrediction"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + threshold=0.5, thresholds=None, \ + probabilityCol="probability", rawPredictionCol="rawPrediction") + Sets params for logistic regression. + If the threshold and thresholds Params are both set, they must be equivalent. + """ + kwargs = self.setParams._input_kwargs + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. + """ + self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] + return self + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. + """ + self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] + return self + + def getThresholds(self): + """ + If :py:attr:`thresholds` is set, return its value. + Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. + """ + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) + + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) + class LogisticRegressionModel(JavaModel): """ Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + +class TreeClassifierParams(object): + """ + Private class to track supported impurity measures. + """ + supportedImpurities = ["entropy", "gini"] + + +class GBTParams(object): + """ + Private class to track supported GBT params. + """ + supportedLossTypes = ["logistic"] + + +@inherit_doc +class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, + HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` + learning algorithm for classification. + It supports both binary and multiclass labels, as well as both continuous and categorical + features. + + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") + >>> model = dt.fit(td) + >>> model.numNodes + 3 + >>> model.depth + 1 + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> result = model.transform(test0).head() + >>> result.prediction + 0.0 + >>> result.probability + DenseVector([1.0, 0.0]) + >>> result.rawPrediction + DenseVector([1.0, 0.0]) + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + """ + super(DecisionTreeClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid) + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + Sets params for the DecisionTreeClassifier. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return DecisionTreeClassificationModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + +@inherit_doc +class DecisionTreeClassificationModel(DecisionTreeModel): + """ + Model fitted by DecisionTreeClassifier. + """ + + +@inherit_doc +class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, + HasRawPredictionCol, HasProbabilityCol, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Random_forest Random Forest` + learning algorithm for classification. + It supports both binary and multiclass labels, as well as both continuous and categorical + features. + + >>> import numpy + >>> from numpy import allclose + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) + >>> model = rf.fit(td) + >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) + True + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> result = model.transform(test0).head() + >>> result.prediction + 0.0 + >>> numpy.argmax(result.probability) + 0 + >>> numpy.argmax(result.rawPrediction) + 0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + numTrees=20, featureSubsetStrategy="auto", seed=None): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ + numTrees=20, featureSubsetStrategy="auto", seed=None) + """ + super(RandomForestClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + #: param for Fraction of the training data used for learning each decision tree, + # in range (0, 1] + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: param for Number of trees to train (>= 1) + self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") + #: param for The number of features to consider for splits at each tree node + self.featureSubsetStrategy = \ + Param(self, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + impurity="gini", numTrees=20, featureSubsetStrategy="auto") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + impurity="gini", numTrees=20, featureSubsetStrategy="auto"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ + impurity="gini", numTrees=20, featureSubsetStrategy="auto") + Sets params for linear classification. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return RandomForestClassificationModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setNumTrees(self, value): + """ + Sets the value of :py:attr:`numTrees`. + """ + self._paramMap[self.numTrees] = value + return self + + def getNumTrees(self): + """ + Gets the value of numTrees or its default value. + """ + return self.getOrDefault(self.numTrees) + + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + self._paramMap[self.featureSubsetStrategy] = value + return self + + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + + +class RandomForestClassificationModel(TreeEnsembleModels): + """ + Model fitted by RandomForestClassifier. + """ + + +@inherit_doc +class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` + learning algorithm for classification. + It supports binary labels, as well as both continuous and categorical features. + Note: Multiclass labels are not currently supported. + + >>> from numpy import allclose + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed") + >>> model = gbt.fit(td) + >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) + True + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + lossType = Param(Params._dummy(), "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + + "contribution of each estimator") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", + maxIter=20, stepSize=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="logistic", maxIter=20, stepSize=0.1) + """ + super(GBTClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.GBTClassifier", self.uid) + #: param for Loss function which GBT tries to minimize (case-insensitive). + self.lossType = Param(self, "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + #: Fraction of the training data used for learning each decision tree, in range (0, 1]. + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of + # each estimator + self.stepSize = Param(self, "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator") + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="logistic", maxIter=20, stepSize=0.1) + Sets params for Gradient Boosted Tree Classification. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GBTClassificationModel(java_model) + + def setLossType(self, value): + """ + Sets the value of :py:attr:`lossType`. + """ + self._paramMap[self.lossType] = value + return self + + def getLossType(self): + """ + Gets the value of lossType or its default value. + """ + return self.getOrDefault(self.lossType) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + self._paramMap[self.stepSize] = value + return self + + def getStepSize(self): + """ + Gets the value of stepSize or its default value. + """ + return self.getOrDefault(self.stepSize) + + +class GBTClassificationModel(TreeEnsembleModels): + """ + Model fitted by GBTClassifier. + """ + + +@inherit_doc +class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, + HasRawPredictionCol): + """ + Naive Bayes Classifiers. + It supports both Multinomial and Bernoulli NB. Multinomial NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + can handle finitely supported discrete data. For example, by converting documents into + TF-IDF vectors, it can be used for document classification. By making every vector a + binary (0/1) data, it can also be used as Bernoulli NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + The input feature values must be nonnegative. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), + ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + >>> model = nb.fit(df) + >>> model.pi + DenseVector([-0.51..., -0.91...]) + >>> model.theta + DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() + >>> result = model.transform(test0).head() + >>> result.prediction + 1.0 + >>> result.probability + DenseVector([0.42..., 0.57...]) + >>> result.rawPrediction + DenseVector([-1.60..., -1.32...]) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, + modelType="multinomial"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ + modelType="multinomial") + """ + super(NaiveBayes, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.NaiveBayes", self.uid) + #: param for the smoothing parameter. + self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + #: param for the model type. + self.modelType = Param(self, "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) " + + "and bernoulli.") + self._setDefault(smoothing=1.0, modelType="multinomial") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, + modelType="multinomial"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ + modelType="multinomial") + Sets params for Naive Bayes. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return NaiveBayesModel(java_model) + + def setSmoothing(self, value): + """ + Sets the value of :py:attr:`smoothing`. + """ + self._paramMap[self.smoothing] = value + return self + + def getSmoothing(self): + """ + Gets the value of smoothing or its default value. + """ + return self.getOrDefault(self.smoothing) + + def setModelType(self, value): + """ + Sets the value of :py:attr:`modelType`. + """ + self._paramMap[self.modelType] = value + return self + + def getModelType(self): + """ + Gets the value of modelType or its default value. + """ + return self.getOrDefault(self.modelType) + + +class NaiveBayesModel(JavaModel): + """ + Model fitted by NaiveBayes. + """ + + @property + def pi(self): + """ + log of class priors. + """ + return self._call_java("pi") + + @property + def theta(self): + """ + log of class conditional probabilities. + """ + return self._call_java("theta") + if __name__ == "__main__": import doctest @@ -65,10 +815,10 @@ class LogisticRegressionModel(JavaModel): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sc = SparkContext("local[2]", "ml.classification tests") + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py new file mode 100644 index 000000000000..cb4c16e25a7a --- /dev/null +++ b/python/pyspark/ml/clustering.py @@ -0,0 +1,171 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * +from pyspark.mllib.common import inherit_doc + +__all__ = ['KMeans', 'KMeansModel'] + + +class KMeansModel(JavaModel): + """ + Model fitted by KMeans. + """ + + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy arrays.""" + return [c.toArray() for c in self._call_java("clusterCenters")] + + +@inherit_doc +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): + """ + K-means clustering with support for multiple parallel runs and a k-means++ like initialization + mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + they are executed together with joint passes over the data for efficiency. + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> kmeans = KMeans(k=2, seed=1) + >>> model = kmeans.fit(df) + >>> centers = model.clusterCenters() + >>> len(centers) + 2 + >>> transformed = model.transform(df).select("features", "prediction") + >>> rows = transformed.collect() + >>> rows[0].prediction == rows[1].prediction + True + >>> rows[2].prediction == rows[3].prediction + True + """ + + # a placeholder to make it appear in the generated doc + k = Param(Params._dummy(), "k", "number of clusters to create") + initMode = Param(Params._dummy(), "initMode", + "the initialization algorithm. This can be either \"random\" to " + + "choose random points as initial cluster centers, or \"k-means||\" " + + "to use a parallel variant of k-means++") + initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") + + @keyword_only + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + """ + super(KMeans, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) + self.k = Param(self, "k", "number of clusters to create") + self.initMode = Param(self, "initMode", + "the initialization algorithm. This can be either \"random\" to " + + "choose random points as initial cluster centers, or \"k-means||\" " + + "to use a parallel variant of k-means++") + self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") + self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + def _create_model(self, java_model): + return KMeansModel(java_model) + + @keyword_only + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + + Sets params for KMeans. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + + >>> algo = KMeans().setK(10) + >>> algo.getK() + 10 + """ + self._paramMap[self.k] = value + return self + + def getK(self): + """ + Gets the value of `k` + """ + return self.getOrDefault(self.k) + + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + + >>> algo = KMeans() + >>> algo.getInitMode() + 'k-means||' + >>> algo = algo.setInitMode("random") + >>> algo.getInitMode() + 'random' + """ + self._paramMap[self.initMode] = value + return self + + def getInitMode(self): + """ + Gets the value of `initMode` + """ + return self.getOrDefault(self.initMode) + + def setInitSteps(self, value): + """ + Sets the value of :py:attr:`initSteps`. + + >>> algo = KMeans().setInitSteps(10) + >>> algo.getInitSteps() + 10 + """ + self._paramMap[self.initSteps] = value + return self + + def getInitSteps(self): + """ + Gets the value of `initSteps` + """ + return self.getOrDefault(self.initSteps) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.clustering tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py new file mode 100644 index 000000000000..6b0a9ffde9f4 --- /dev/null +++ b/python/pyspark/ml/evaluation.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import abstractmethod, ABCMeta + +from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.param import Param, Params +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol +from pyspark.ml.util import keyword_only +from pyspark.mllib.common import inherit_doc + +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', + 'MulticlassClassificationEvaluator'] + + +@inherit_doc +class Evaluator(Params): + """ + Base class for evaluators that compute metrics from predictions. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _evaluate(self, dataset): + """ + Evaluates the output. + + :param dataset: a dataset that contains labels/observations and + predictions + :return: metric + """ + raise NotImplementedError() + + def evaluate(self, dataset, params=None): + """ + Evaluates the output with optional parameters. + + :param dataset: a dataset that contains labels/observations and + predictions + :param params: an optional param map that overrides embedded + params + :return: metric + """ + if params is None: + params = dict() + if isinstance(params, dict): + if params: + return self.copy(params)._evaluate(dataset) + else: + return self._evaluate(dataset) + else: + raise ValueError("Params must be a param map but got %s." % type(params)) + + +@inherit_doc +class JavaEvaluator(Evaluator, JavaWrapper): + """ + Base class for :py:class:`Evaluator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _evaluate(self, dataset): + """ + Evaluates the output. + :param dataset: a dataset that contains labels/observations and predictions. + :return: evaluation metric + """ + self._transfer_params_to_java() + return self._java_obj.evaluate(dataset._jdf) + + +@inherit_doc +class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): + """ + Evaluator for binary classification, which expects two input + columns: rawPrediction and label. + + >>> from pyspark.mllib.linalg import Vectors + >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]), + ... [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)]) + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + ... + >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw") + >>> evaluator.evaluate(dataset) + 0.70... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) + 0.83... + """ + + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)") + + @keyword_only + def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC"): + """ + __init__(self, rawPredictionCol="rawPrediction", labelCol="label", \ + metricName="areaUnderROC") + """ + super(BinaryClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) + #: param for metric name in evaluation (areaUnderROC|areaUnderPR) + self.metricName = Param(self, "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)") + self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC"): + """ + setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \ + metricName="areaUnderROC") + Sets params for binary classification evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +@inherit_doc +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Regression, which expects two input + columns: prediction and label. + + >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5), + ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + ... + >>> evaluator = RegressionEvaluator(predictionCol="raw") + >>> evaluator.evaluate(dataset) + 2.842... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) + 0.993... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) + 2.649... + """ + # Because we will maximize evaluation value (ref: `CrossValidator`), + # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + # we take and output the negative of this metric. + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + """ + super(RegressionEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) + #: param for metric name in evaluation (mse|rmse|r2|mae) + self.metricName = Param(self, "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="rmse") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + Sets params for regression evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +@inherit_doc +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Multiclass Classification, which expects two input + columns: prediction and label. + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) + ... + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") + >>> evaluator.evaluate(dataset) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + 0.66... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation " + "(f1|precision|recall|weightedPrecision|weightedRecall)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + """ + super(MulticlassClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) + # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) + self.metricName = Param(self, "metricName", + "metric name in evaluation" + " (f1|precision|recall|weightedPrecision|weightedRecall)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="f1") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + Sets params for multiclass classification evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.evaluation tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index e088acd0ca82..04b2b2ccc9e5 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,68 +15,1298 @@ # limitations under the License. # -from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures -from pyspark.ml.util import inherit_doc -from pyspark.ml.wrapper import JavaTransformer +import sys +if sys.version > '3': + basestring = str -__all__ = ['Tokenizer', 'HashingTF'] +from pyspark.rdd import ignore_unicode_prefix +from pyspark.ml.param.shared import * +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer +from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector + +__all__ = ['Binarizer', 'Bucketizer', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', + 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', + 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', + 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', + 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] + + +@inherit_doc +class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + Binarize a column of continuous features given a threshold. + + >>> df = sqlContext.createDataFrame([(0.5,)], ["values"]) + >>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features") + >>> binarizer.transform(df).head().features + 0.0 + >>> binarizer.setParams(outputCol="freqs").transform(df).head().freqs + 0.0 + >>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"} + >>> binarizer.transform(df, params).head().vector + 1.0 + """ + + # a placeholder to make it appear in the generated doc + threshold = Param(Params._dummy(), "threshold", + "threshold in binary classification prediction, in range [0, 1]") + + @keyword_only + def __init__(self, threshold=0.0, inputCol=None, outputCol=None): + """ + __init__(self, threshold=0.0, inputCol=None, outputCol=None) + """ + super(Binarizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) + self.threshold = Param(self, "threshold", + "threshold in binary classification prediction, in range [0, 1]") + self._setDefault(threshold=0.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, threshold=0.0, inputCol=None, outputCol=None): + """ + setParams(self, threshold=0.0, inputCol=None, outputCol=None) + Sets params for this Binarizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + self._paramMap[self.threshold] = value + return self + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.threshold) + + +@inherit_doc +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + Maps a column of continuous features to a column of feature buckets. + + >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], + ... inputCol="values", outputCol="buckets") + >>> bucketed = bucketizer.transform(df).collect() + >>> bucketed[0].buckets + 0.0 + >>> bucketed[1].buckets + 0.0 + >>> bucketed[2].buckets + 1.0 + >>> bucketed[3].buckets + 2.0 + >>> bucketizer.setParams(outputCol="b").transform(df).head().b + 0.0 + """ + + # a placeholder to make it appear in the generated doc + splits = \ + Param(Params._dummy(), "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, " + + "there are n buckets. A bucket defined by splits x,y holds values in the " + + "range [x,y) except the last bucket, which also includes y. The splits " + + "should be strictly increasing. Values at -inf, inf must be explicitly " + + "provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.") + + @keyword_only + def __init__(self, splits=None, inputCol=None, outputCol=None): + """ + __init__(self, splits=None, inputCol=None, outputCol=None) + """ + super(Bucketizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) + #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, + # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) + # except the last bucket, which also includes y. The splits should be strictly increasing. + # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, + # values outside the splits specified will be treated as errors. + self.splits = \ + Param(self, "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, " + + "there are n buckets. A bucket defined by splits x,y holds values in the " + + "range [x,y) except the last bucket, which also includes y. The splits " + + "should be strictly increasing. Values at -inf, inf must be explicitly " + + "provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, splits=None, inputCol=None, outputCol=None): + """ + setParams(self, splits=None, inputCol=None, outputCol=None) + Sets params for this Bucketizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setSplits(self, value): + """ + Sets the value of :py:attr:`splits`. + """ + self._paramMap[self.splits] = value + return self + + def getSplits(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.splits) + + +@inherit_doc +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): + """ + Outputs the Hadamard product (i.e., the element-wise product) of each input vector + with a provided "weight" vector. In other words, it scales each column of the dataset + by a scalar multiplier. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"]) + >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]), + ... inputCol="values", outputCol="eprod") + >>> ep.transform(df).head().eprod + DenseVector([2.0, 2.0, 9.0]) + >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod + DenseVector([4.0, 3.0, 15.0]) + """ + + # a placeholder to make it appear in the generated doc + scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + + "it must be MLlib Vector type.") + + @keyword_only + def __init__(self, scalingVec=None, inputCol=None, outputCol=None): + """ + __init__(self, scalingVec=None, inputCol=None, outputCol=None) + """ + super(ElementwiseProduct, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", + self.uid) + self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " + + "it must be MLlib Vector type.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, scalingVec=None, inputCol=None, outputCol=None): + """ + setParams(self, scalingVec=None, inputCol=None, outputCol=None) + Sets params for this ElementwiseProduct. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setScalingVec(self, value): + """ + Sets the value of :py:attr:`scalingVec`. + """ + self._paramMap[self.scalingVec] = value + return self + + def getScalingVec(self): + """ + Gets the value of scalingVec or its default value. + """ + return self.getOrDefault(self.scalingVec) + + +@inherit_doc +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): + """ + Maps a sequence of terms to their term frequencies using the + hashing trick. + + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["words"]) + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> hashingTF.transform(df).head().features + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) + >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) + >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} + >>> hashingTF.transform(df, params).head().vector + SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) + """ + + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) + """ + super(HashingTF, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) + self._setDefault(numFeatures=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) + Sets params for this HashingTF. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +@inherit_doc +class IDF(JavaEstimator, HasInputCol, HasOutputCol): + """ + Compute the Inverse Document Frequency (IDF) given a collection of documents. + + >>> from pyspark.mllib.linalg import DenseVector + >>> df = sqlContext.createDataFrame([(DenseVector([1.0, 2.0]),), + ... (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"]) + >>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf") + >>> idf.fit(df).transform(df).head().idf + DenseVector([0.0, 0.0]) + >>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs + DenseVector([0.0, 0.0]) + >>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"} + >>> idf.fit(df, params).transform(df).head().vector + DenseVector([0.2877, 0.0]) + """ + + # a placeholder to make it appear in the generated doc + minDocFreq = Param(Params._dummy(), "minDocFreq", + "minimum of documents in which a term should appear for filtering") + + @keyword_only + def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): + """ + __init__(self, minDocFreq=0, inputCol=None, outputCol=None) + """ + super(IDF, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) + self.minDocFreq = Param(self, "minDocFreq", + "minimum of documents in which a term should appear for filtering") + self._setDefault(minDocFreq=0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): + """ + setParams(self, minDocFreq=0, inputCol=None, outputCol=None) + Sets params for this IDF. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinDocFreq(self, value): + """ + Sets the value of :py:attr:`minDocFreq`. + """ + self._paramMap[self.minDocFreq] = value + return self + + def getMinDocFreq(self): + """ + Gets the value of minDocFreq or its default value. + """ + return self.getOrDefault(self.minDocFreq) + + def _create_model(self, java_model): + return IDFModel(java_model) + + +class IDFModel(JavaModel): + """ + Model fitted by IDF. + """ + + +@inherit_doc +@ignore_unicode_prefix +class NGram(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that converts the input array of strings into an array of n-grams. Null + values in the input array are ignored. + It returns an array of n-grams where each n-gram is represented by a space-separated string of + words. + When the input is empty, an empty array is returned. + When the input array length is less than n (number of elements per n-gram), no n-grams are + returned. + + >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + >>> # Change n-gram length + >>> ngram.setParams(n=4).transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Temporarily modify output column. + >>> ngram.transform(df, {ngram.outputCol: "output"}).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Must use keyword arguments to specify params. + >>> ngram.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + + @keyword_only + def __init__(self, n=2, inputCol=None, outputCol=None): + """ + __init__(self, n=2, inputCol=None, outputCol=None) + """ + super(NGram, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) + self.n = Param(self, "n", "number of elements per n-gram (>=1)") + self._setDefault(n=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, n=2, inputCol=None, outputCol=None): + """ + setParams(self, n=2, inputCol=None, outputCol=None) + Sets params for this NGram. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setN(self, value): + """ + Sets the value of :py:attr:`n`. + """ + self._paramMap[self.n] = value + return self + + def getN(self): + """ + Gets the value of n or its default value. + """ + return self.getOrDefault(self.n) + + +@inherit_doc +class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + Normalize a vector to have unit norm using the given p-norm. + + >>> from pyspark.mllib.linalg import Vectors + >>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0}) + >>> df = sqlContext.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"]) + >>> normalizer = Normalizer(p=2.0, inputCol="dense", outputCol="features") + >>> normalizer.transform(df).head().features + DenseVector([0.6, -0.8]) + >>> normalizer.setParams(inputCol="sparse", outputCol="freqs").transform(df).head().freqs + SparseVector(4, {1: 0.8, 3: 0.6}) + >>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"} + >>> normalizer.transform(df, params).head().vector + DenseVector([0.4286, -0.5714]) + """ + + # a placeholder to make it appear in the generated doc + p = Param(Params._dummy(), "p", "the p norm value.") + + @keyword_only + def __init__(self, p=2.0, inputCol=None, outputCol=None): + """ + __init__(self, p=2.0, inputCol=None, outputCol=None) + """ + super(Normalizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) + self.p = Param(self, "p", "the p norm value.") + self._setDefault(p=2.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, p=2.0, inputCol=None, outputCol=None): + """ + setParams(self, p=2.0, inputCol=None, outputCol=None) + Sets params for this Normalizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setP(self, value): + """ + Sets the value of :py:attr:`p`. + """ + self._paramMap[self.p] = value + return self + + def getP(self): + """ + Gets the value of p or its default value. + """ + return self.getOrDefault(self.p) + + +@inherit_doc +class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): + """ + A one-hot encoder that maps a column of category indices to a + column of binary vectors, with at most a single one-value per row + that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to + an output vector of `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via + :py:attr:`dropLast`) because it makes the vector entries sum up to + one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + Note that this is different from scikit-learn's OneHotEncoder, + which keeps all categories. + The output vectors are sparse. + + .. seealso:: + + :py:class:`StringIndexer` for converting categorical values into + category indices + + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> model = stringIndexer.fit(stringIndDf) + >>> td = model.transform(stringIndDf) + >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") + >>> encoder.transform(td).head().features + SparseVector(2, {0: 1.0}) + >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs + SparseVector(2, {0: 1.0}) + >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} + >>> encoder.transform(td, params).head().test + SparseVector(3, {0: 1.0}) + """ + + # a placeholder to make it appear in the generated doc + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") + + @keyword_only + def __init__(self, dropLast=True, inputCol=None, outputCol=None): + """ + __init__(self, includeFirst=True, inputCol=None, outputCol=None) + """ + super(OneHotEncoder, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) + self.dropLast = Param(self, "dropLast", "whether to drop the last category") + self._setDefault(dropLast=True) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, dropLast=True, inputCol=None, outputCol=None): + """ + setParams(self, dropLast=True, inputCol=None, outputCol=None) + Sets params for this OneHotEncoder. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + self._paramMap[self.dropLast] = value + return self + + def getDropLast(self): + """ + Gets the value of dropLast or its default value. + """ + return self.getOrDefault(self.dropLast) + + +@inherit_doc +class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): + """ + Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, + which is available at `http://en.wikipedia.org/wiki/Polynomial_expansion`, "In mathematics, an + expansion of a product of sums expresses it as a sum of products by using the fact that + multiplication distributes over addition". Take a 2-variable feature vector as an example: + `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"]) + >>> px = PolynomialExpansion(degree=2, inputCol="dense", outputCol="expanded") + >>> px.transform(df).head().expanded + DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) + >>> px.setParams(outputCol="test").transform(df).head().test + DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) + """ + + # a placeholder to make it appear in the generated doc + degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") + + @keyword_only + def __init__(self, degree=2, inputCol=None, outputCol=None): + """ + __init__(self, degree=2, inputCol=None, outputCol=None) + """ + super(PolynomialExpansion, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) + self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)") + self._setDefault(degree=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, degree=2, inputCol=None, outputCol=None): + """ + setParams(self, degree=2, inputCol=None, outputCol=None) + Sets params for this PolynomialExpansion. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setDegree(self, value): + """ + Sets the value of :py:attr:`degree`. + """ + self._paramMap[self.degree] = value + return self + + def getDegree(self): + """ + Gets the value of degree or its default value. + """ + return self.getOrDefault(self.degree) @inherit_doc +@ignore_unicode_prefix +class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + A regex based tokenizer that extracts tokens either by using the + provided regex pattern (in Java dialect) to split the text + (default) or repeatedly matching the regex (if gaps is false). + Optional parameters also allow filtering tokens using a minimal + length. + It returns an array of strings that can be empty. + + >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") + >>> reTokenizer.transform(df).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> # Change a parameter. + >>> reTokenizer.setParams(outputCol="tokens").transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> reTokenizer.transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> reTokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") + gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") + pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") + + @keyword_only + def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): + """ + __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) + """ + super(RegexTokenizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) + self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") + self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") + self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") + self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): + """ + setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) + Sets params for this RegexTokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinTokenLength(self, value): + """ + Sets the value of :py:attr:`minTokenLength`. + """ + self._paramMap[self.minTokenLength] = value + return self + + def getMinTokenLength(self): + """ + Gets the value of minTokenLength or its default value. + """ + return self.getOrDefault(self.minTokenLength) + + def setGaps(self, value): + """ + Sets the value of :py:attr:`gaps`. + """ + self._paramMap[self.gaps] = value + return self + + def getGaps(self): + """ + Gets the value of gaps or its default value. + """ + return self.getOrDefault(self.gaps) + + def setPattern(self, value): + """ + Sets the value of :py:attr:`pattern`. + """ + self._paramMap[self.pattern] = value + return self + + def getPattern(self): + """ + Gets the value of pattern or its default value. + """ + return self.getOrDefault(self.pattern) + + +@inherit_doc +class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + Standardizes features by removing the mean and scaling to unit variance using column summary + statistics on the samples in the training set. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled") + >>> model = standardScaler.fit(df) + >>> model.mean + DenseVector([1.0]) + >>> model.std + DenseVector([1.4142]) + >>> model.transform(df).collect()[1].scaled + DenseVector([1.4142]) + """ + + # a placeholder to make it appear in the generated doc + withMean = Param(Params._dummy(), "withMean", "Center data with mean") + withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") + + @keyword_only + def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): + """ + __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None) + """ + super(StandardScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) + self.withMean = Param(self, "withMean", "Center data with mean") + self.withStd = Param(self, "withStd", "Scale to unit standard deviation") + self._setDefault(withMean=False, withStd=True) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None): + """ + setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) + Sets params for this StandardScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setWithMean(self, value): + """ + Sets the value of :py:attr:`withMean`. + """ + self._paramMap[self.withMean] = value + return self + + def getWithMean(self): + """ + Gets the value of withMean or its default value. + """ + return self.getOrDefault(self.withMean) + + def setWithStd(self, value): + """ + Sets the value of :py:attr:`withStd`. + """ + self._paramMap[self.withStd] = value + return self + + def getWithStd(self): + """ + Gets the value of withStd or its default value. + """ + return self.getOrDefault(self.withStd) + + def _create_model(self, java_model): + return StandardScalerModel(java_model) + + +class StandardScalerModel(JavaModel): + """ + Model fitted by StandardScaler. + """ + + @property + def std(self): + """ + Standard deviation of the StandardScalerModel. + """ + return self._call_java("std") + + @property + def mean(self): + """ + Mean of the StandardScalerModel. + """ + return self._call_java("mean") + + +@inherit_doc +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): + """ + A label indexer that maps a string column of labels to an ML column of label indices. + If the input column is numeric, we cast it to string and index the string values. + The indices are in [0, numLabels), ordered by label frequencies. + So the most frequent label gets index 0. + + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> model = stringIndexer.fit(stringIndDf) + >>> td = model.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] + """ + + @keyword_only + def __init__(self, inputCol=None, outputCol=None): + """ + __init__(self, inputCol=None, outputCol=None) + """ + super(StringIndexer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None): + """ + setParams(self, inputCol=None, outputCol=None) + Sets params for this StringIndexer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return StringIndexerModel(java_model) + + +class StringIndexerModel(JavaModel): + """ + Model fitted by StringIndexer. + """ + + +@inherit_doc +@ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A tokenizer that converts the input string to lowercase and then splits it by white spaces. - >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")])) - >>> tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - >>> print tokenizer.transform(dataset).head() + >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> tokenizer.transform(df).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> # Change a parameter. + >>> tokenizer.setParams(outputCol="tokens").transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() + >>> tokenizer.transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> tokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.feature.Tokenizer" + @keyword_only + def __init__(self, inputCol=None, outputCol=None): + """ + __init__(self, inputCol=None, outputCol=None) + """ + super(Tokenizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None): + """ + setParams(self, inputCol="input", outputCol="output") + Sets params for this Tokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) @inherit_doc -class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): """ - Maps a sequence of terms to their term frequencies using the - hashing trick. + A feature transformer that merges multiple columns into a vector column. - >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])])) - >>> hashingTF = HashingTF() \ - .setNumFeatures(10) \ - .setInputCol("words") \ - .setOutputCol("features") - >>> print hashingTF.transform(dataset).head().features - (10,[7,8,9],[1.0,1.0,1.0]) - >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(dataset, params).head().vector - (5,[2,3,4],[1.0,1.0,1.0]) + >>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) + >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features") + >>> vecAssembler.transform(df).head().features + DenseVector([1.0, 0.0, 3.0]) + >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs + DenseVector([1.0, 0.0, 3.0]) + >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} + >>> vecAssembler.transform(df, params).head().vector + DenseVector([0.0, 1.0]) + """ + + @keyword_only + def __init__(self, inputCols=None, outputCol=None): + """ + __init__(self, inputCols=None, outputCol=None) + """ + super(VectorAssembler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCols=None, outputCol=None): + """ + setParams(self, inputCols=None, outputCol=None) + Sets params for this VectorAssembler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +@inherit_doc +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): """ + Class for indexing categorical feature columns in a dataset of [[Vector]]. + + This has 2 usage modes: + - Automatically identify categorical features (default behavior) + - This helps process a dataset of unknown vectors into a dataset with some continuous + features and some categorical features. The choice between continuous and categorical + is based upon a maxCategories parameter. + - Set maxCategories to the maximum number of categorical any categorical feature should + have. + - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + and feature 1 will be declared continuous. + - Index all features, if all features are categorical + - If maxCategories is set to be very large, then this will build an index of unique + values for all features. + - Warning: This can cause problems if features are continuous since this will collect ALL + unique values to the driver. + - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + If maxCategories >= 3, then both features will be declared categorical. + + This returns a model which can transform categorical features to use 0-based indices. + + Index stability: + - This is not guaranteed to choose the same category index across multiple runs. + - If a categorical feature includes value 0, then this is guaranteed to map value 0 to + index 0. This maintains vector sparsity. + - More stability may be added in the future. + + TODO: Future extensions: The following functionality is planned for the future: + - Preserve metadata in transform; if a feature's metadata is already present, + do not recompute. + - Specify certain features to not index, either via a parameter or via existing metadata. + - Add warning if a categorical feature has only 1 category. + - Add option for allowing unknown categories. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([-1.0, 0.0]),), + ... (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"]) + >>> indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed") + >>> model = indexer.fit(df) + >>> model.transform(df).head().indexed + DenseVector([1.0, 0.0]) + >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test + DenseVector([0.0, 1.0]) + >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"} + >>> model2 = indexer.fit(df, params) + >>> model2.transform(df).head().vector + DenseVector([1.0, 0.0]) + """ + + # a placeholder to make it appear in the generated doc + maxCategories = Param(Params._dummy(), "maxCategories", + "Threshold for the number of values a categorical feature can take " + + "(>= 2). If a feature is found to have > maxCategories values, then " + + "it is declared continuous.") + + @keyword_only + def __init__(self, maxCategories=20, inputCol=None, outputCol=None): + """ + __init__(self, maxCategories=20, inputCol=None, outputCol=None) + """ + super(VectorIndexer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) + self.maxCategories = Param(self, "maxCategories", + "Threshold for the number of values a categorical feature " + + "can take (>= 2). If a feature is found to have " + + "> maxCategories values, then it is declared continuous.") + self._setDefault(maxCategories=20) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, maxCategories=20, inputCol=None, outputCol=None): + """ + setParams(self, maxCategories=20, inputCol=None, outputCol=None) + Sets params for this VectorIndexer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMaxCategories(self, value): + """ + Sets the value of :py:attr:`maxCategories`. + """ + self._paramMap[self.maxCategories] = value + return self + + def getMaxCategories(self): + """ + Gets the value of maxCategories or its default value. + """ + return self.getOrDefault(self.maxCategories) - _java_class = "org.apache.spark.ml.feature.HashingTF" + def _create_model(self, java_model): + return VectorIndexerModel(java_model) + + +class VectorIndexerModel(JavaModel): + """ + Model fitted by VectorIndexer. + """ + + +@inherit_doc +@ignore_unicode_prefix +class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): + """ + Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further + natural language processing or machine learning process. + + >>> sent = ("a b " * 100 + "a c " * 10).split(" ") + >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) + >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> model.getVectors().show() + +----+--------------------+ + |word| vector| + +----+--------------------+ + | a|[-0.3511952459812...| + | b|[0.29077222943305...| + | c|[0.02315592765808...| + +----+--------------------+ + ... + >>> model.findSynonyms("a", 2).show() + +----+-------------------+ + |word| similarity| + +----+-------------------+ + | b|0.29255685145799626| + | c|-0.5414068302988307| + +----+-------------------+ + ... + >>> model.transform(doc).head().model + DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) + """ + + # a placeholder to make it appear in the generated doc + vectorSize = Param(Params._dummy(), "vectorSize", + "the dimension of codes after transforming from words") + numPartitions = Param(Params._dummy(), "numPartitions", + "number of partitions for sentences of words") + minCount = Param(Params._dummy(), "minCount", + "the minimum number of times a token must appear to be included in the " + + "word2vec model's vocabulary") + + @keyword_only + def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, + seed=None, inputCol=None, outputCol=None): + """ + __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \ + seed=None, inputCol=None, outputCol=None) + """ + super(Word2Vec, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) + self.vectorSize = Param(self, "vectorSize", + "the dimension of codes after transforming from words") + self.numPartitions = Param(self, "numPartitions", + "number of partitions for sentences of words") + self.minCount = Param(self, "minCount", + "the minimum number of times a token must appear to be included " + + "in the word2vec model's vocabulary") + self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, + seed=None) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, + seed=None, inputCol=None, outputCol=None): + """ + setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \ + inputCol=None, outputCol=None) + Sets params for this Word2Vec. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setVectorSize(self, value): + """ + Sets the value of :py:attr:`vectorSize`. + """ + self._paramMap[self.vectorSize] = value + return self + + def getVectorSize(self): + """ + Gets the value of vectorSize or its default value. + """ + return self.getOrDefault(self.vectorSize) + + def setNumPartitions(self, value): + """ + Sets the value of :py:attr:`numPartitions`. + """ + self._paramMap[self.numPartitions] = value + return self + + def getNumPartitions(self): + """ + Gets the value of numPartitions or its default value. + """ + return self.getOrDefault(self.numPartitions) + + def setMinCount(self, value): + """ + Sets the value of :py:attr:`minCount`. + """ + self._paramMap[self.minCount] = value + return self + + def getMinCount(self): + """ + Gets the value of minCount or its default value. + """ + return self.getOrDefault(self.minCount) + + def _create_model(self, java_model): + return Word2VecModel(java_model) + + +class Word2VecModel(JavaModel): + """ + Model fitted by Word2Vec. + """ + + def getVectors(self): + """ + Returns the vector representation of the words as a dataframe + with two fields, word and vector. + """ + return self._call_java("getVectors") + + def findSynonyms(self, word, num): + """ + Find "num" number of words closest in similarity to "word". + word can be a string or vector representation. + Returns a dataframe with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + return self._call_java("findSynonyms", word, num) + + +@inherit_doc +class PCA(JavaEstimator, HasInputCol, HasOutputCol): + """ + PCA trains a model to project vectors to a low-dimensional space using PCA. + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + ... (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + ... (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + >>> df = sqlContext.createDataFrame(data,["features"]) + >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features") + >>> model = pca.fit(df) + >>> model.transform(df).collect()[0].pca_features + DenseVector([1.648..., -4.013...]) + """ + + # a placeholder to make it appear in the generated doc + k = Param(Params._dummy(), "k", "the number of principal components") + + @keyword_only + def __init__(self, k=None, inputCol=None, outputCol=None): + """ + __init__(self, k=None, inputCol=None, outputCol=None) + """ + super(PCA, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) + self.k = Param(self, "k", "the number of principal components") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, k=None, inputCol=None, outputCol=None): + """ + setParams(self, k=None, inputCol=None, outputCol=None) + Set params for this PCA. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + self._paramMap[self.k] = value + return self + + def getK(self): + """ + Gets the value of k or its default value. + """ + return self.getOrDefault(self.k) + + def _create_model(self, java_model): + return PCAModel(java_model) + + +class PCAModel(JavaModel): + """ + Model fitted by PCA. + """ + + +@inherit_doc +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): + """ + .. note:: Experimental + + Implements the transforms required for fitting a dataset against an + R model formula. Currently we support a limited subset of the R + operators, including '~', '+', '-', and '.'. Also see the R formula + docs: + http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + + >>> df = sqlContext.createDataFrame([ + ... (1.0, 1.0, "a"), + ... (0.0, 2.0, "b"), + ... (0.0, 0.0, "a") + ... ], ["y", "x", "s"]) + >>> rf = RFormula(formula="y ~ x + s") + >>> rf.fit(df).transform(df).show() + +---+---+---+---------+-----+ + | y| x| s| features|label| + +---+---+---+---------+-----+ + |1.0|1.0| a|[1.0,1.0]| 1.0| + |0.0|2.0| b|[2.0,0.0]| 0.0| + |0.0|0.0| a|[0.0,1.0]| 0.0| + +---+---+---+---------+-----+ + ... + >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show() + +---+---+---+--------+-----+ + | y| x| s|features|label| + +---+---+---+--------+-----+ + |1.0|1.0| a| [1.0]| 1.0| + |0.0|2.0| b| [2.0]| 0.0| + |0.0|0.0| a| [0.0]| 0.0| + +---+---+---+--------+-----+ + ... + """ + + # a placeholder to make it appear in the generated doc + formula = Param(Params._dummy(), "formula", "R model formula") + + @keyword_only + def __init__(self, formula=None, featuresCol="features", labelCol="label"): + """ + __init__(self, formula=None, featuresCol="features", labelCol="label") + """ + super(RFormula, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self.formula = Param(self, "formula", "R model formula") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, formula=None, featuresCol="features", labelCol="label"): + """ + setParams(self, formula=None, featuresCol="features", labelCol="label") + Sets params for RFormula. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setFormula(self, value): + """ + Sets the value of :py:attr:`formula`. + """ + self._paramMap[self.formula] = value + return self + + def getFormula(self): + """ + Gets the value of :py:attr:`formula`. + """ + return self.getOrDefault(self.formula) + + def _create_model(self, java_model): + return RFormulaModel(java_model) + + +class RFormulaModel(JavaModel): + """ + Model fitted by :py:class:`RFormula`. + """ if __name__ == "__main__": import doctest from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import Row, SQLContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) + globs['sqlContext'] = sqlContext + testData = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="b"), + Row(id=2, label="c"), Row(id=3, label="a"), + Row(id=4, label="a"), Row(id=5, label="c")], 2) + globs['stringIndDf'] = sqlContext.createDataFrame(testData) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: exit(-1) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5566792cead4..eeeac49b2198 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -16,6 +16,7 @@ # from abc import ABCMeta +import copy from pyspark.ml.util import Identifiable @@ -25,23 +26,30 @@ class Param(object): """ - A param with self-contained documentation and optionally default value. + A param with self-contained documentation. """ - def __init__(self, parent, name, doc, defaultValue=None): + def __init__(self, parent, name, doc): if not isinstance(parent, Identifiable): - raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) - self.parent = parent + raise TypeError("Parent must be an Identifiable but got type %s." % type(parent)) + self.parent = parent.uid self.name = str(name) self.doc = str(doc) - self.defaultValue = defaultValue def __str__(self): - return str(self.parent) + "-" + self.name + return str(self.parent) + "__" + self.name def __repr__(self): - return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ - (self.parent, self.name, self.doc, self.defaultValue) + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + if isinstance(other, Param): + return self.parent == other.parent and self.name == other.name + else: + return False class Params(Identifiable): @@ -54,29 +62,195 @@ class Params(Identifiable): def __init__(self): super(Params, self).__init__() - #: embedded param map - self.paramMap = {} + #: internal param map for user-supplied values param map + self._paramMap = {} + + #: internal param map for default values + self._defaultParamMap = {} + + #: value returned by :py:func:`params` + self._params = None @property def params(self): """ - Returns all params. The default implementation uses - :py:func:`dir` to get all attributes of type + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ - return filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"]) + if self._params is None: + self._params = list(filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"])) + return self._params + + def explainParam(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. + """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self._defaultParamMap: + values.append("default: %s" % self._defaultParamMap[param]) + if param in self._paramMap: + values.append("current: %s" % self._paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self.explainParam(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self._paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. + """ + param = self._resolveParam(param) + return param in self._defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has + a default value. + """ + return self.isSet(param) or self.hasDefault(param) + + def hasParam(self, paramName): + """ + Tests whether this instance contains a param with a given + (string) name. + """ + param = self._resolveParam(paramName) + return param in self.params + + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if neither is set. + """ + param = self._resolveParam(param) + if param in self._paramMap: + return self._paramMap[param] + else: + return self._defaultParamMap[param] - def _merge_params(self, params): - paramMap = self.paramMap.copy() - paramMap.update(params) + def extractParamMap(self, extra=None): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extra. + :param extra: extra param values + :return: merged param map + """ + if extra is None: + extra = dict() + paramMap = self._defaultParamMap.copy() + paramMap.update(self._paramMap) + paramMap.update(extra) return paramMap + def copy(self, extra=None): + """ + Creates a copy of this instance with the same uid and some + extra params. The default implementation creates a + shallow copy using :py:func:`copy.copy`, and then copies the + embedded and extra parameters over and returns the copy. + Subclasses should override this method if the default approach + is not sufficient. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() + that = copy.copy(self) + that._paramMap = self.extractParamMap(extra) + return that + + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if not (self.uid == param.parent and self.hasParam(param.name)): + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + @staticmethod def _dummy(): """ - Returns a dummy Params instance used as a placeholder to generate docs. + Returns a dummy Params instance used as a placeholder to + generate docs. """ dummy = Params() dummy.uid = "undefined" return dummy + + def _set(self, **kwargs): + """ + Sets user-supplied params. + """ + for param, value in kwargs.items(): + self._paramMap[getattr(self, param)] = value + return self + + def _setDefault(self, **kwargs): + """ + Sets default params. + """ + for param, value in kwargs.items(): + self._defaultParamMap[getattr(self, param)] = value + return self + + def _copyValues(self, to, extra=None): + """ + Copies param values from this instance to another instance for + params shared by them. + :param to: the target instance + :param extra: extra params to be copied + :return: the target instance with param values copied + """ + if extra is None: + extra = dict() + paramMap = self.extractParamMap(extra) + for p in self.params: + if p in paramMap and to.hasParam(p.name): + to._set(**{p.name: paramMap[p]}) + return to diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py deleted file mode 100644 index 5eb81106f116..000000000000 --- a/python/pyspark/ml/param/_gen_shared_params.py +++ /dev/null @@ -1,98 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -header = """# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#""" - - -def _gen_param_code(name, doc, defaultValue): - """ - Generates Python code for a shared param class. - - :param name: param name - :param doc: param doc - :param defaultValue: string representation of the param - :return: code string - """ - # TODO: How to correctly inherit instance attributes? - template = '''class Has$Name(Params): - """ - Params with $name. - """ - - # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) - - def __init__(self): - super(Has$Name, self).__init__() - #: param for $doc - self.$name = Param(self, "$name", "$doc", $defaultValue) - - def set$Name(self, value): - """ - Sets the value of :py:attr:`$name`. - """ - self.paramMap[self.$name] = value - return self - - def get$Name(self): - """ - Gets the value of $name or its default value. - """ - if self.$name in self.paramMap: - return self.paramMap[self.$name] - else: - return self.$name.defaultValue''' - - upperCamelName = name[0].upper() + name[1:] - return template \ - .replace("$name", name) \ - .replace("$Name", upperCamelName) \ - .replace("$doc", doc) \ - .replace("$defaultValue", defaultValue) - -if __name__ == "__main__": - print header - print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" - print "from pyspark.ml.param import Param, Params\n\n" - shared = [ - ("maxIter", "max number of iterations", "100"), - ("regParam", "regularization constant", "0.1"), - ("featuresCol", "features column name", "'features'"), - ("labelCol", "label column name", "'label'"), - ("predictionCol", "prediction column name", "'prediction'"), - ("inputCol", "input column name", "'input'"), - ("outputCol", "output column name", "'output'"), - ("numFeatures", "number of features", "1 << 18")] - code = [] - for name, doc, defaultValue in shared: - code.append(_gen_param_code(name, doc, defaultValue)) - print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py new file mode 100644 index 000000000000..69efc424ec4e --- /dev/null +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +header = """# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#""" + +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py + + +def _gen_param_header(name, doc, defaultValueStr): + """ + Generates the header part for shared variables + + :param name: param name + :param doc: param doc + """ + template = '''class Has$Name(Params): + """ + Mixin for param $name: $doc. + """ + + # a placeholder to make it appear in the generated doc + $name = Param(Params._dummy(), "$name", "$doc") + + def __init__(self): + super(Has$Name, self).__init__() + #: param for $doc + self.$name = Param(self, "$name", "$doc")''' + if defaultValueStr is not None: + template += ''' + self._setDefault($name=$defaultValueStr)''' + + Name = name[0].upper() + name[1:] + return template \ + .replace("$name", name) \ + .replace("$Name", Name) \ + .replace("$doc", doc) \ + .replace("$defaultValueStr", str(defaultValueStr)) + + +def _gen_param_code(name, doc, defaultValueStr): + """ + Generates Python code for a shared param class. + + :param name: param name + :param doc: param doc + :param defaultValueStr: string representation of the default value + :return: code string + """ + # TODO: How to correctly inherit instance attributes? + template = ''' + def set$Name(self, value): + """ + Sets the value of :py:attr:`$name`. + """ + self._paramMap[self.$name] = value + return self + + def get$Name(self): + """ + Gets the value of $name or its default value. + """ + return self.getOrDefault(self.$name)''' + + Name = name[0].upper() + name[1:] + return template \ + .replace("$name", name) \ + .replace("$Name", Name) \ + .replace("$doc", doc) \ + .replace("$defaultValueStr", str(defaultValueStr)) + +if __name__ == "__main__": + print(header) + print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n") + print("from pyspark.ml.param import Param, Params\n\n") + shared = [ + ("maxIter", "max number of iterations (>= 0)", None), + ("regParam", "regularization parameter (>= 0)", None), + ("featuresCol", "features column name", "'features'"), + ("labelCol", "label column name", "'label'"), + ("predictionCol", "prediction column name", "'prediction'"), + ("probabilityCol", "Column name for predicted class conditional probabilities. " + + "Note: Not all models output well-calibrated probability estimates! These probabilities " + + "should be treated as confidences, not precise probabilities.", "'probability'"), + ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"), + ("inputCol", "input column name", None), + ("inputCols", "input column names", None), + ("outputCol", "output column name", "self.uid + '__output'"), + ("numFeatures", "number of features", None), + ("checkpointInterval", "checkpoint interval (>= 1)", None), + ("seed", "random seed", "hash(type(self).__name__)"), + ("tol", "the convergence tolerance for iterative algorithms", None), + ("stepSize", "Step size to be used for each iteration of optimization.", None)] + code = [] + for name, doc, defaultValueStr in shared: + param_code = _gen_param_header(name, doc, defaultValueStr) + code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr)) + + decisionTreeParams = [ + ("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " + + "depth 1 means 1 internal node + 2 leaf nodes."), + ("maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature."), + ("minInstancesPerNode", "Minimum number of instances each child must have after split. " + + "If a split causes the left or right child to have fewer than minInstancesPerNode, the " + + "split will be discarded as invalid. Should be >= 1."), + ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."), + ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), + ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + + "Caching can speed up training of deeper trees.")] + + decisionTreeCode = '''class DecisionTreeParams(Params): + """ + Mixin for Decision Tree parameters. + """ + + # a placeholder to make it appear in the generated doc + $dummyPlaceHolders + + def __init__(self): + super(DecisionTreeParams, self).__init__() + $realParams''' + dtParamMethods = "" + dummyPlaceholders = "" + realParams = "" + paramTemplate = """$name = Param($owner, "$name", "$doc")""" + for name, doc in decisionTreeParams: + variable = paramTemplate.replace("$name", name).replace("$doc", doc) + dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " + realParams += "#: param for " + doc + "\n " + realParams += "self." + variable.replace("$owner", "self") + "\n " + dtParamMethods += _gen_param_code(name, doc, None) + "\n" + code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + .replace("$realParams", realParams) + dtParamMethods) + print("\n\n\n".join(code)) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 586822f2de42..595124726366 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -15,246 +15,527 @@ # limitations under the License. # -# DO NOT MODIFY. The code is generated by _gen_shared_params.py. +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. from pyspark.ml.param import Param, Params class HasMaxIter(Params): """ - Params with maxIter. + Mixin for param maxIter: max number of iterations (>= 0). """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0)") def __init__(self): super(HasMaxIter, self).__init__() - #: param for max number of iterations - self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + #: param for max number of iterations (>= 0) + self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)") def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self.paramMap[self.maxIter] = value + self._paramMap[self.maxIter] = value return self def getMaxIter(self): """ Gets the value of maxIter or its default value. """ - if self.maxIter in self.paramMap: - return self.paramMap[self.maxIter] - else: - return self.maxIter.defaultValue + return self.getOrDefault(self.maxIter) class HasRegParam(Params): """ - Params with regParam. + Mixin for param regParam: regularization parameter (>= 0). """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0)") def __init__(self): super(HasRegParam, self).__init__() - #: param for regularization constant - self.regParam = Param(self, "regParam", "regularization constant", 0.1) + #: param for regularization parameter (>= 0) + self.regParam = Param(self, "regParam", "regularization parameter (>= 0)") def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self.paramMap[self.regParam] = value + self._paramMap[self.regParam] = value return self def getRegParam(self): """ Gets the value of regParam or its default value. """ - if self.regParam in self.paramMap: - return self.paramMap[self.regParam] - else: - return self.regParam.defaultValue + return self.getOrDefault(self.regParam) class HasFeaturesCol(Params): """ - Params with featuresCol. + Mixin for param featuresCol: features column name. """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + self.featuresCol = Param(self, "featuresCol", "features column name") + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. """ - self.paramMap[self.featuresCol] = value + self._paramMap[self.featuresCol] = value return self def getFeaturesCol(self): """ Gets the value of featuresCol or its default value. """ - if self.featuresCol in self.paramMap: - return self.paramMap[self.featuresCol] - else: - return self.featuresCol.defaultValue + return self.getOrDefault(self.featuresCol) class HasLabelCol(Params): """ - Params with labelCol. + Mixin for param labelCol: label column name. """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + labelCol = Param(Params._dummy(), "labelCol", "label column name") def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name", 'label') + self.labelCol = Param(self, "labelCol", "label column name") + self._setDefault(labelCol='label') def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self.paramMap[self.labelCol] = value + self._paramMap[self.labelCol] = value return self def getLabelCol(self): """ Gets the value of labelCol or its default value. """ - if self.labelCol in self.paramMap: - return self.paramMap[self.labelCol] - else: - return self.labelCol.defaultValue + return self.getOrDefault(self.labelCol) class HasPredictionCol(Params): """ - Params with predictionCol. + Mixin for param predictionCol: prediction column name. """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + self.predictionCol = Param(self, "predictionCol", "prediction column name") + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self.paramMap[self.predictionCol] = value + self._paramMap[self.predictionCol] = value return self def getPredictionCol(self): """ Gets the value of predictionCol or its default value. """ - if self.predictionCol in self.paramMap: - return self.paramMap[self.predictionCol] - else: - return self.predictionCol.defaultValue + return self.getOrDefault(self.predictionCol) + + +class HasProbabilityCol(Params): + """ + Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + """ + + # a placeholder to make it appear in the generated doc + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + + def __init__(self): + super(HasProbabilityCol, self).__init__() + #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. + self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + self._setDefault(probabilityCol='probability') + + def setProbabilityCol(self, value): + """ + Sets the value of :py:attr:`probabilityCol`. + """ + self._paramMap[self.probabilityCol] = value + return self + + def getProbabilityCol(self): + """ + Gets the value of probabilityCol or its default value. + """ + return self.getOrDefault(self.probabilityCol) + + +class HasRawPredictionCol(Params): + """ + Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. + """ + + # a placeholder to make it appear in the generated doc + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + def __init__(self): + super(HasRawPredictionCol, self).__init__() + #: param for raw prediction (a.k.a. confidence) column name + self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + self._setDefault(rawPredictionCol='rawPrediction') + + def setRawPredictionCol(self, value): + """ + Sets the value of :py:attr:`rawPredictionCol`. + """ + self._paramMap[self.rawPredictionCol] = value + return self + + def getRawPredictionCol(self): + """ + Gets the value of rawPredictionCol or its default value. + """ + return self.getOrDefault(self.rawPredictionCol) class HasInputCol(Params): """ - Params with inputCol. + Mixin for param inputCol: input column name. """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + inputCol = Param(Params._dummy(), "inputCol", "input column name") def __init__(self): super(HasInputCol, self).__init__() #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name", 'input') + self.inputCol = Param(self, "inputCol", "input column name") def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self.paramMap[self.inputCol] = value + self._paramMap[self.inputCol] = value return self def getInputCol(self): """ Gets the value of inputCol or its default value. """ - if self.inputCol in self.paramMap: - return self.paramMap[self.inputCol] - else: - return self.inputCol.defaultValue + return self.getOrDefault(self.inputCol) + + +class HasInputCols(Params): + """ + Mixin for param inputCols: input column names. + """ + + # a placeholder to make it appear in the generated doc + inputCols = Param(Params._dummy(), "inputCols", "input column names") + + def __init__(self): + super(HasInputCols, self).__init__() + #: param for input column names + self.inputCols = Param(self, "inputCols", "input column names") + + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + self._paramMap[self.inputCols] = value + return self + + def getInputCols(self): + """ + Gets the value of inputCols or its default value. + """ + return self.getOrDefault(self.inputCols) class HasOutputCol(Params): """ - Params with outputCol. + Mixin for param outputCol: output column name. """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + outputCol = Param(Params._dummy(), "outputCol", "output column name") def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name", 'output') + self.outputCol = Param(self, "outputCol", "output column name") + self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self.paramMap[self.outputCol] = value + self._paramMap[self.outputCol] = value return self def getOutputCol(self): """ Gets the value of outputCol or its default value. """ - if self.outputCol in self.paramMap: - return self.paramMap[self.outputCol] - else: - return self.outputCol.defaultValue + return self.getOrDefault(self.outputCol) class HasNumFeatures(Params): """ - Params with numFeatures. + Mixin for param numFeatures: number of features. """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + self.numFeatures = Param(self, "numFeatures", "number of features") def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. """ - self.paramMap[self.numFeatures] = value + self._paramMap[self.numFeatures] = value return self def getNumFeatures(self): """ Gets the value of numFeatures or its default value. """ - if self.numFeatures in self.paramMap: - return self.paramMap[self.numFeatures] - else: - return self.numFeatures.defaultValue + return self.getOrDefault(self.numFeatures) + + +class HasCheckpointInterval(Params): + """ + Mixin for param checkpointInterval: checkpoint interval (>= 1). + """ + + # a placeholder to make it appear in the generated doc + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1)") + + def __init__(self): + super(HasCheckpointInterval, self).__init__() + #: param for checkpoint interval (>= 1) + self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)") + + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + self._paramMap[self.checkpointInterval] = value + return self + + def getCheckpointInterval(self): + """ + Gets the value of checkpointInterval or its default value. + """ + return self.getOrDefault(self.checkpointInterval) + + +class HasSeed(Params): + """ + Mixin for param seed: random seed. + """ + + # a placeholder to make it appear in the generated doc + seed = Param(Params._dummy(), "seed", "random seed") + + def __init__(self): + super(HasSeed, self).__init__() + #: param for random seed + self.seed = Param(self, "seed", "random seed") + self._setDefault(seed=hash(type(self).__name__)) + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + self._paramMap[self.seed] = value + return self + + def getSeed(self): + """ + Gets the value of seed or its default value. + """ + return self.getOrDefault(self.seed) + + +class HasTol(Params): + """ + Mixin for param tol: the convergence tolerance for iterative algorithms. + """ + + # a placeholder to make it appear in the generated doc + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms") + + def __init__(self): + super(HasTol, self).__init__() + #: param for the convergence tolerance for iterative algorithms + self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms") + + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + self._paramMap[self.tol] = value + return self + + def getTol(self): + """ + Gets the value of tol or its default value. + """ + return self.getOrDefault(self.tol) + + +class HasStepSize(Params): + """ + Mixin for param stepSize: Step size to be used for each iteration of optimization.. + """ + + # a placeholder to make it appear in the generated doc + stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.") + + def __init__(self): + super(HasStepSize, self).__init__() + #: param for Step size to be used for each iteration of optimization. + self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.") + + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + self._paramMap[self.stepSize] = value + return self + + def getStepSize(self): + """ + Gets the value of stepSize or its default value. + """ + return self.getOrDefault(self.stepSize) + + +class DecisionTreeParams(Params): + """ + Mixin for Decision Tree parameters. + """ + + # a placeholder to make it appear in the generated doc + maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") + minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") + minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") + maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + + + def __init__(self): + super(DecisionTreeParams, self).__init__() + #: param for Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + self.maxDepth = Param(self, "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + #: param for Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature. + self.maxBins = Param(self, "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") + #: param for Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1. + self.minInstancesPerNode = Param(self, "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") + #: param for Minimum information gain for a split to be considered at a tree node. + self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") + #: param for Maximum memory in MB allocated to histogram aggregation. + self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") + #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. + self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + + def setMaxDepth(self, value): + """ + Sets the value of :py:attr:`maxDepth`. + """ + self._paramMap[self.maxDepth] = value + return self + + def getMaxDepth(self): + """ + Gets the value of maxDepth or its default value. + """ + return self.getOrDefault(self.maxDepth) + + def setMaxBins(self, value): + """ + Sets the value of :py:attr:`maxBins`. + """ + self._paramMap[self.maxBins] = value + return self + + def getMaxBins(self): + """ + Gets the value of maxBins or its default value. + """ + return self.getOrDefault(self.maxBins) + + def setMinInstancesPerNode(self, value): + """ + Sets the value of :py:attr:`minInstancesPerNode`. + """ + self._paramMap[self.minInstancesPerNode] = value + return self + + def getMinInstancesPerNode(self): + """ + Gets the value of minInstancesPerNode or its default value. + """ + return self.getOrDefault(self.minInstancesPerNode) + + def setMinInfoGain(self, value): + """ + Sets the value of :py:attr:`minInfoGain`. + """ + self._paramMap[self.minInfoGain] = value + return self + + def getMinInfoGain(self): + """ + Gets the value of minInfoGain or its default value. + """ + return self.getOrDefault(self.minInfoGain) + + def setMaxMemoryInMB(self, value): + """ + Sets the value of :py:attr:`maxMemoryInMB`. + """ + self._paramMap[self.maxMemoryInMB] = value + return self + + def getMaxMemoryInMB(self): + """ + Gets the value of maxMemoryInMB or its default value. + """ + return self.getOrDefault(self.maxMemoryInMB) + + def setCacheNodeIds(self, value): + """ + Sets the value of :py:attr:`cacheNodeIds`. + """ + self._paramMap[self.cacheNodeIds] = value + return self + + def getCacheNodeIds(self): + """ + Gets the value of cacheNodeIds or its default value. + """ + return self.getOrDefault(self.cacheNodeIds) + diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2d239f8c802a..13cf2b0f7bbd 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -18,10 +18,8 @@ from abc import ABCMeta, abstractmethod from pyspark.ml.param import Param, Params -from pyspark.ml.util import inherit_doc - - -__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] +from pyspark.ml.util import keyword_only +from pyspark.mllib.common import inherit_doc @inherit_doc @@ -33,18 +31,42 @@ class Estimator(Params): __metaclass__ = ABCMeta @abstractmethod - def fit(self, dataset, params={}): + def _fit(self, dataset): """ - Fits a model to the input dataset with optional parameters. + Fits a model to the input dataset. This is called by the + default implementation of fit. :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` - :param params: an optional param map that overwrites embedded - params + :py:class:`pyspark.sql.DataFrame` :returns: fitted model """ raise NotImplementedError() + def fit(self, dataset, params=None): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. If a list/tuple of param maps is given, + this calls fit on each param map and returns a + list of models. + :returns: fitted model(s) + """ + if params is None: + params = dict() + if isinstance(params, (list, tuple)): + return [self.fit(dataset, paramMap) for paramMap in params] + elif isinstance(params, dict): + if params: + return self.copy(params)._fit(dataset) + else: + return self._fit(dataset) + else: + raise ValueError("Params must be either a param map or a list/tuple of param maps, " + "but got %s." % type(params)) + @inherit_doc class Transformer(Params): @@ -56,18 +78,45 @@ class Transformer(Params): __metaclass__ = ABCMeta @abstractmethod - def transform(self, dataset, params={}): + def _transform(self, dataset): """ Transforms the input dataset with optional parameters. :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` - :param params: an optional param map that overwrites embedded - params + :py:class:`pyspark.sql.DataFrame` :returns: transformed dataset """ raise NotImplementedError() + def transform(self, dataset, params=None): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. + :returns: transformed dataset + """ + if params is None: + params = dict() + if isinstance(params, dict): + if params: + return self.copy(params,)._transform(dataset) + else: + return self._transform(dataset) + else: + raise ValueError("Params must be either a param map but got %s." % type(params)) + + +@inherit_doc +class Model(Transformer): + """ + Abstract class for models that are fitted by estimators. + """ + + __metaclass__ = ABCMeta + @inherit_doc class Pipeline(Estimator): @@ -89,10 +138,18 @@ class Pipeline(Estimator): identity transformer. """ - def __init__(self): + @keyword_only + def __init__(self, stages=None): + """ + __init__(self, stages=None) + """ + if stages is None: + stages = [] super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) def setStages(self, value): """ @@ -100,23 +157,33 @@ def setStages(self, value): :param value: a list of transformers or estimators :return: the pipeline instance """ - self.paramMap[self.stages] = value + self._paramMap[self.stages] = value return self def getStages(self): """ Get pipeline stages. """ - if self.stages in self.paramMap: - return self.paramMap[self.stages] + if self.stages in self._paramMap: + return self._paramMap[self.stages] + + @keyword_only + def setParams(self, stages=None): + """ + setParams(self, stages=None) + Sets params for Pipeline. + """ + if stages is None: + stages = [] + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) - def fit(self, dataset, params={}): - paramMap = self._merge_params(params) - stages = paramMap[self.stages] + def _fit(self, dataset): + stages = self.getStages() for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): - raise ValueError( - "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) + raise TypeError( + "Cannot recognize a pipeline stage of type %s." % type(stage)) indexOfLastEstimator = -1 for i, stage in enumerate(stages): if isinstance(stage, Estimator): @@ -126,29 +193,41 @@ def fit(self, dataset, params={}): if i <= indexOfLastEstimator: if isinstance(stage, Transformer): transformers.append(stage) - dataset = stage.transform(dataset, paramMap) + dataset = stage.transform(dataset) else: # must be an Estimator - model = stage.fit(dataset, paramMap) + model = stage.fit(dataset) transformers.append(model) if i < indexOfLastEstimator: - dataset = model.transform(dataset, paramMap) + dataset = model.transform(dataset) else: transformers.append(stage) return PipelineModel(transformers) + def copy(self, extra=None): + if extra is None: + extra = dict() + that = Params.copy(self, extra) + stages = [stage.copy(extra) for stage in that.getStages()] + return that.setStages(stages) + @inherit_doc -class PipelineModel(Transformer): +class PipelineModel(Model): """ Represents a compiled pipeline with transformers and fitted models. """ - def __init__(self, transformers): + def __init__(self, stages): super(PipelineModel, self).__init__() - self.transformers = transformers + self.stages = stages - def transform(self, dataset, params={}): - paramMap = self._merge_params(params) - for t in self.transformers: - dataset = t.transform(dataset, paramMap) + def _transform(self, dataset): + for t in self.stages: + dataset = t.transform(dataset) return dataset + + def copy(self, extra=None): + if extra is None: + extra = dict() + stages = [stage.copy(extra) for stage in self.stages] + return PipelineModel(stages) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py new file mode 100644 index 000000000000..b06099ac0aee --- /dev/null +++ b/python/pyspark/ml/recommendation.py @@ -0,0 +1,306 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * +from pyspark.mllib.common import inherit_doc + + +__all__ = ['ALS', 'ALSModel'] + + +@inherit_doc +class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed): + """ + Alternating Least Squares (ALS) matrix factorization. + + ALS attempts to estimate the ratings matrix `R` as the product of + two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically + these approximations are called 'factor' matrices. The general + approach is iterative. During each iteration, one of the factor + matrices is held constant, while the other is solved for using least + squares. The newly-solved factor matrix is then held constant while + solving for the other factor matrix. + + This is a blocked implementation of the ALS factorization algorithm + that groups the two sets of factors (referred to as "users" and + "products") into blocks and reduces communication by only sending + one copy of each user vector to each product block on each + iteration, and only for the product blocks that need that user's + feature vector. This is achieved by pre-computing some information + about the ratings matrix to determine the "out-links" of each user + (which blocks of products it will contribute to) and "in-link" + information for each product (which of the feature vectors it + receives from each user block it will depend on). This allows us to + send only an array of feature vectors between each user block and + product block, and have the product block find the users' ratings + and update the products based on these messages. + + For implicit preference data, the algorithm used is based on + "Collaborative Filtering for Implicit Feedback Datasets", available + at `http://dx.doi.org/10.1109/ICDM.2008.22`, adapted for the blocked + approach used here. + + Essentially instead of finding the low-rank approximations to the + rating matrix `R`, this finds the approximations for a preference + matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0. + The ratings then act as 'confidence' values related to strength of + indicated user preferences rather than explicit ratings given to + items. + + >>> df = sqlContext.createDataFrame( + ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ... ["user", "item", "rating"]) + >>> als = ALS(rank=10, maxIter=5) + >>> model = als.fit(df) + >>> model.rank + 10 + >>> model.userFactors.orderBy("id").collect() + [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)] + >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) + >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) + >>> predictions[0] + Row(user=0, item=2, prediction=0.39...) + >>> predictions[1] + Row(user=1, item=0, prediction=3.19...) + >>> predictions[2] + Row(user=2, item=0, prediction=-1.15...) + """ + + # a placeholder to make it appear in the generated doc + rank = Param(Params._dummy(), "rank", "rank of the factorization") + numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") + numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks") + implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference") + alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference") + userCol = Param(Params._dummy(), "userCol", "column name for user ids") + itemCol = Param(Params._dummy(), "itemCol", "column name for item ids") + ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings") + nonnegative = Param(Params._dummy(), "nonnegative", + "whether to use nonnegative constraint for least squares") + + @keyword_only + def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, + ratingCol="rating", nonnegative=False, checkpointInterval=10): + """ + __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ + implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ + ratingCol="rating", nonnegative=false, checkpointInterval=10) + """ + super(ALS, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) + self.rank = Param(self, "rank", "rank of the factorization") + self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks") + self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks") + self.implicitPrefs = Param(self, "implicitPrefs", "whether to use implicit preference") + self.alpha = Param(self, "alpha", "alpha for implicit preference") + self.userCol = Param(self, "userCol", "column name for user ids") + self.itemCol = Param(self, "itemCol", "column name for item ids") + self.ratingCol = Param(self, "ratingCol", "column name for ratings") + self.nonnegative = Param(self, "nonnegative", + "whether to use nonnegative constraint for least squares") + self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, + ratingCol="rating", nonnegative=False, checkpointInterval=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, + ratingCol="rating", nonnegative=False, checkpointInterval=10): + """ + setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ + ratingCol="rating", nonnegative=False, checkpointInterval=10) + Sets params for ALS. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return ALSModel(java_model) + + def setRank(self, value): + """ + Sets the value of :py:attr:`rank`. + """ + self._paramMap[self.rank] = value + return self + + def getRank(self): + """ + Gets the value of rank or its default value. + """ + return self.getOrDefault(self.rank) + + def setNumUserBlocks(self, value): + """ + Sets the value of :py:attr:`numUserBlocks`. + """ + self._paramMap[self.numUserBlocks] = value + return self + + def getNumUserBlocks(self): + """ + Gets the value of numUserBlocks or its default value. + """ + return self.getOrDefault(self.numUserBlocks) + + def setNumItemBlocks(self, value): + """ + Sets the value of :py:attr:`numItemBlocks`. + """ + self._paramMap[self.numItemBlocks] = value + return self + + def getNumItemBlocks(self): + """ + Gets the value of numItemBlocks or its default value. + """ + return self.getOrDefault(self.numItemBlocks) + + def setNumBlocks(self, value): + """ + Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. + """ + self._paramMap[self.numUserBlocks] = value + self._paramMap[self.numItemBlocks] = value + + def setImplicitPrefs(self, value): + """ + Sets the value of :py:attr:`implicitPrefs`. + """ + self._paramMap[self.implicitPrefs] = value + return self + + def getImplicitPrefs(self): + """ + Gets the value of implicitPrefs or its default value. + """ + return self.getOrDefault(self.implicitPrefs) + + def setAlpha(self, value): + """ + Sets the value of :py:attr:`alpha`. + """ + self._paramMap[self.alpha] = value + return self + + def getAlpha(self): + """ + Gets the value of alpha or its default value. + """ + return self.getOrDefault(self.alpha) + + def setUserCol(self, value): + """ + Sets the value of :py:attr:`userCol`. + """ + self._paramMap[self.userCol] = value + return self + + def getUserCol(self): + """ + Gets the value of userCol or its default value. + """ + return self.getOrDefault(self.userCol) + + def setItemCol(self, value): + """ + Sets the value of :py:attr:`itemCol`. + """ + self._paramMap[self.itemCol] = value + return self + + def getItemCol(self): + """ + Gets the value of itemCol or its default value. + """ + return self.getOrDefault(self.itemCol) + + def setRatingCol(self, value): + """ + Sets the value of :py:attr:`ratingCol`. + """ + self._paramMap[self.ratingCol] = value + return self + + def getRatingCol(self): + """ + Gets the value of ratingCol or its default value. + """ + return self.getOrDefault(self.ratingCol) + + def setNonnegative(self, value): + """ + Sets the value of :py:attr:`nonnegative`. + """ + self._paramMap[self.nonnegative] = value + return self + + def getNonnegative(self): + """ + Gets the value of nonnegative or its default value. + """ + return self.getOrDefault(self.nonnegative) + + +class ALSModel(JavaModel): + """ + Model fitted by ALS. + """ + + @property + def rank(self): + """rank of the matrix factorization model""" + return self._call_java("rank") + + @property + def userFactors(self): + """ + a DataFrame that stores user factors in two columns: `id` and + `features` + """ + return self._call_java("userFactors") + + @property + def itemFactors(self): + """ + a DataFrame that stores item factors in two columns: `id` and + `features` + """ + return self._call_java("itemFactors") + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.recommendation tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py new file mode 100644 index 000000000000..44f60a769566 --- /dev/null +++ b/python/pyspark/ml/regression.py @@ -0,0 +1,581 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * +from pyspark.mllib.common import inherit_doc + + +__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', + 'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel', + 'RandomForestRegressor', 'RandomForestRegressionModel'] + + +@inherit_doc +class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + HasRegParam, HasTol): + """ + Linear regression. + + The learning objective is to minimize the squared error, with regularization. + The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ + + This support multiple types of regularization: + - none (a.k.a. ordinary least squares) + - L2 (ridge regression) + - L1 (Lasso) + - L2 + L1 (elastic net) + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0) + >>> model = lr.fit(df) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + >>> lr.setParams("vector") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + elasticNetParam = \ + Param(Params._dummy(), "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + """ + super(LinearRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.LinearRegression", self.uid) + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty + # is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = \ + Param(self, "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + + "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + Sets params for linear regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return LinearRegressionModel(java_model) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + +class LinearRegressionModel(JavaModel): + """ + Model fitted by LinearRegression. + """ + + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + +class TreeRegressorParams(object): + """ + Private class to track supported impurity measures. + """ + supportedImpurities = ["variance"] + + +class RandomForestParams(object): + """ + Private class to track supported random forest parameters. + """ + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + +class GBTParams(object): + """ + Private class to track supported GBT params. + """ + supportedLossTypes = ["squared", "absolute"] + + +@inherit_doc +class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` + learning algorithm for regression. + It supports both continuous and categorical features. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> dt = DecisionTreeRegressor(maxDepth=2) + >>> model = dt.fit(df) + >>> model.depth + 1 + >>> model.numNodes + 3 + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + """ + super(DecisionTreeRegressor, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid) + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + Sets params for the DecisionTreeRegressor. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return DecisionTreeRegressionModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + +@inherit_doc +class DecisionTreeModel(JavaModel): + + @property + def numNodes(self): + """Return number of nodes of the decision tree.""" + return self._call_java("numNodes") + + @property + def depth(self): + """Return depth of the decision tree.""" + return self._call_java("depth") + + def __repr__(self): + return self._call_java("toString") + + +@inherit_doc +class TreeEnsembleModels(JavaModel): + + @property + def treeWeights(self): + """Return the weights for each tree""" + return list(self._call_java("javaTreeWeights")) + + def __repr__(self): + return self._call_java("toString") + + +@inherit_doc +class DecisionTreeRegressionModel(DecisionTreeModel): + """ + Model fitted by DecisionTreeRegressor. + """ + + +@inherit_doc +class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Random_forest Random Forest` + learning algorithm for regression. + It supports both continuous and categorical features. + + >>> from numpy import allclose + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) + >>> model = rf.fit(df) + >>> allclose(model.treeWeights, [1.0, 1.0]) + True + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 0.5 + """ + + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", + numTrees=20, featureSubsetStrategy="auto", seed=None): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", numTrees=20, \ + featureSubsetStrategy="auto", seed=None) + """ + super(RandomForestRegressor, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.RandomForestRegressor", self.uid) + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) + #: param for Fraction of the training data used for learning each decision tree, + # in range (0, 1] + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: param for Number of trees to train (>= 1) + self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") + #: param for The number of features to consider for splits at each tree node + self.featureSubsetStrategy = \ + Param(self, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + impurity="variance", numTrees=20, featureSubsetStrategy="auto") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + impurity="variance", numTrees=20, featureSubsetStrategy="auto"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ + impurity="variance", numTrees=20, featureSubsetStrategy="auto") + Sets params for linear regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return RandomForestRegressionModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setNumTrees(self, value): + """ + Sets the value of :py:attr:`numTrees`. + """ + self._paramMap[self.numTrees] = value + return self + + def getNumTrees(self): + """ + Gets the value of numTrees or its default value. + """ + return self.getOrDefault(self.numTrees) + + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + self._paramMap[self.featureSubsetStrategy] = value + return self + + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + + +class RandomForestRegressionModel(TreeEnsembleModels): + """ + Model fitted by RandomForestRegressor. + """ + + +@inherit_doc +class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` + learning algorithm for regression. + It supports both continuous and categorical features. + + >>> from numpy import allclose + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> gbt = GBTRegressor(maxIter=5, maxDepth=2) + >>> model = gbt.fit(df) + >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) + True + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + lossType = Param(Params._dummy(), "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + + "contribution of each estimator") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", + maxIter=20, stepSize=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="squared", maxIter=20, stepSize=0.1) + """ + super(GBTRegressor, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) + #: param for Loss function which GBT tries to minimize (case-insensitive). + self.lossType = Param(self, "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + #: Fraction of the training data used for learning each decision tree, in range (0, 1]. + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of + # each estimator + self.stepSize = Param(self, "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator") + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="squared", maxIter=20, stepSize=0.1) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="squared", maxIter=20, stepSize=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="squared", maxIter=20, stepSize=0.1) + Sets params for Gradient Boosted Tree Regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GBTRegressionModel(java_model) + + def setLossType(self, value): + """ + Sets the value of :py:attr:`lossType`. + """ + self._paramMap[self.lossType] = value + return self + + def getLossType(self): + """ + Gets the value of lossType or its default value. + """ + return self.getOrDefault(self.lossType) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + self._paramMap[self.stepSize] = value + return self + + def getStepSize(self): + """ + Gets the value of stepSize or its default value. + """ + return self.getOrDefault(self.stepSize) + + +class GBTRegressionModel(TreeEnsembleModels): + """ + Model fitted by GBTRegressor. + """ + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.regression tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b627c2b4e930..c151d21fd661 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,9 +31,13 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame -from pyspark.ml.param import Param -from pyspark.ml.pipeline import Transformer, Estimator, Pipeline +from pyspark.sql import DataFrame, SQLContext +from pyspark.ml.param import Param, Params +from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.util import keyword_only +from pyspark.ml import Estimator, Model, Pipeline, Transformer +from pyspark.ml.feature import * +from pyspark.mllib.linalg import DenseVector class MockDataset(DataFrame): @@ -42,44 +46,43 @@ def __init__(self): self.index = 0 -class MockTransformer(Transformer): +class HasFake(Params): + + def __init__(self): + super(HasFake, self).__init__() + self.fake = Param(self, "fake", "fake param") + + def getFake(self): + return self.getOrDefault(self.fake) + + +class MockTransformer(Transformer, HasFake): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake", None) self.dataset_index = None - self.fake_param_value = None - def transform(self, dataset, params={}): + def _transform(self, dataset): self.dataset_index = dataset.index - if self.fake in params: - self.fake_param_value = params[self.fake] dataset.index += 1 return dataset -class MockEstimator(Estimator): +class MockEstimator(Estimator, HasFake): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake", None) self.dataset_index = None - self.fake_param_value = None - self.model = None - def fit(self, dataset, params={}): + def _fit(self, dataset): self.dataset_index = dataset.index - if self.fake in params: - self.fake_param_value = params[self.fake] model = MockModel() - self.model = model + self._copyValues(model) return model -class MockModel(MockTransformer, Transformer): - - def __init__(self): - super(MockModel, self).__init__() +class MockModel(MockTransformer, Model, HasFake): + pass class PipelineTests(PySparkTestCase): @@ -90,19 +93,17 @@ def test_pipeline(self): transformer1 = MockTransformer() estimator2 = MockEstimator() transformer3 = MockTransformer() - pipeline = Pipeline() \ - .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - self.assertEqual(0, estimator0.dataset_index) - self.assertEqual(0, estimator0.fake_param_value) - model0 = estimator0.model + model0, transformer1, model2, transformer3 = pipeline_model.stages self.assertEqual(0, model0.dataset_index) + self.assertEqual(0, model0.getFake()) self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.fake_param_value) - self.assertEqual(2, estimator2.dataset_index) - model2 = estimator2.model - self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " - "not be called during fit.") + self.assertEqual(1, transformer1.getFake()) + self.assertEqual(2, dataset.index) + self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") + self.assertIsNone(transformer3.dataset_index, + "The last transformer shouldn't be called in fit.") dataset = pipeline_model.transform(dataset) self.assertEqual(2, model0.dataset_index) self.assertEqual(3, transformer1.dataset_index) @@ -111,5 +112,157 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) +class TestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(OtherTestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0)") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + seed = testParams.seed + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter, seed]) + + self.assertTrue(testParams.hasParam(maxIter)) + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertTrue(testParams.hasParam(inputCol)) + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + # Since the default is normally random, set it to a known number for debug str + testParams._setDefault(seed=41) + testParams.setSeed(43) + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", + "seed: random seed (default: 41, current: 43)"])) + + def test_hasseed(self): + noSeedSpecd = TestParams() + withSeedSpecd = TestParams(seed=42) + other = OtherTestParams() + # Check that we no longer use 42 as the magic number + self.assertNotEqual(noSeedSpecd.getSeed(), 42) + origSeed = noSeedSpecd.getSeed() + # Check that we only compute the seed once + self.assertEqual(noSeedSpecd.getSeed(), origSeed) + # Check that a specified seed is honored + self.assertEqual(withSeedSpecd.getSeed(), 42) + # Check that a different class has a different seed + self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + + +class FeatureTests(PySparkTestCase): + + def test_binarizer(self): + b0 = Binarizer() + self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) + self.assertTrue(all([~b0.isSet(p) for p in b0.params])) + self.assertTrue(b0.hasDefault(b0.threshold)) + self.assertEqual(b0.getThreshold(), 0.0) + b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) + self.assertTrue(all([b0.isSet(p) for p in b0.params])) + self.assertEqual(b0.getThreshold(), 1.0) + self.assertEqual(b0.getInputCol(), "input") + self.assertEqual(b0.getOutputCol(), "output") + + b0c = b0.copy({b0.threshold: 2.0}) + self.assertEqual(b0c.uid, b0.uid) + self.assertListEqual(b0c.params, b0.params) + self.assertEqual(b0c.getThreshold(), 2.0) + + b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") + self.assertNotEqual(b1.uid, b0.uid) + self.assertEqual(b1.getThreshold(), 2.0) + self.assertEqual(b1.getInputCol(), "input") + self.assertEqual(b1.getOutputCol(), "output") + + def test_idf(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (DenseVector([1.0, 2.0]),), + (DenseVector([0.0, 1.0]),), + (DenseVector([3.0, 0.2]),)], ["tf"]) + idf0 = IDF(inputCol="tf") + self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) + idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) + self.assertEqual(idf0m.uid, idf0.uid, + "Model should inherit the UID from its parent estimator.") + output = idf0m.transform(dataset) + self.assertIsNotNone(output.head().idf) + + def test_ngram(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + ([["a", "b", "c", "d", "e"]])], ["input"]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py new file mode 100644 index 000000000000..dcfee6a3170a --- /dev/null +++ b/python/pyspark/ml/tuning.py @@ -0,0 +1,284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import itertools +import numpy as np + +from pyspark.ml.param import Params, Param +from pyspark.ml import Estimator, Model +from pyspark.ml.util import keyword_only +from pyspark.sql.functions import rand + +__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel'] + + +class ParamGridBuilder(object): + r""" + Builder for a param grid used in grid search-based model selection. + + >>> from pyspark.ml.classification import LogisticRegression + >>> lr = LogisticRegression() + >>> output = ParamGridBuilder() \ + ... .baseOn({lr.labelCol: 'l'}) \ + ... .baseOn([lr.predictionCol, 'p']) \ + ... .addGrid(lr.regParam, [1.0, 2.0]) \ + ... .addGrid(lr.maxIter, [1, 5]) \ + ... .build() + >>> expected = [ + ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, + ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, + ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, + ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] + >>> len(output) == len(expected) + True + >>> all([m in expected for m in output]) + True + """ + + def __init__(self): + self._param_grid = {} + + def addGrid(self, param, values): + """ + Sets the given parameters in this grid to fixed values. + """ + self._param_grid[param] = values + + return self + + def baseOn(self, *args): + """ + Sets the given parameters in this grid to fixed values. + Accepts either a parameter dictionary or a list of (parameter, value) pairs. + """ + if isinstance(args[0], dict): + self.baseOn(*args[0].items()) + else: + for (param, value) in args: + self.addGrid(param, [value]) + + return self + + def build(self): + """ + Builds and returns all combinations of parameters specified + by the param grid. + """ + keys = self._param_grid.keys() + grid_values = self._param_grid.values() + return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] + + +class CrossValidator(Estimator): + """ + K-fold cross validation. + + >>> from pyspark.ml.classification import LogisticRegression + >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator + >>> from pyspark.mllib.linalg import Vectors + >>> dataset = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, + ... ["features", "label"]) + >>> lr = LogisticRegression() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + >>> evaluator = BinaryClassificationEvaluator() + >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... + """ + + # a placeholder to make it appear in the generated doc + estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") + + # a placeholder to make it appear in the generated doc + estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") + + # a placeholder to make it appear in the generated doc + evaluator = Param( + Params._dummy(), "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") + + # a placeholder to make it appear in the generated doc + numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") + + @keyword_only + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + """ + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + """ + super(CrossValidator, self).__init__() + #: param for estimator to be cross-validated + self.estimator = Param(self, "estimator", "estimator to be cross-validated") + #: param for estimator param maps + self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps") + #: param for the evaluator used to select hyper-parameters that + #: maximize the cross-validated metric + self.evaluator = Param( + self, "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") + #: param for number of folds for cross validation + self.numFolds = Param(self, "numFolds", "number of folds for cross validation") + self._setDefault(numFolds=3) + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + @keyword_only + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + """ + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + Sets params for cross validator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setEstimator(self, value): + """ + Sets the value of :py:attr:`estimator`. + """ + self._paramMap[self.estimator] = value + return self + + def getEstimator(self): + """ + Gets the value of estimator or its default value. + """ + return self.getOrDefault(self.estimator) + + def setEstimatorParamMaps(self, value): + """ + Sets the value of :py:attr:`estimatorParamMaps`. + """ + self._paramMap[self.estimatorParamMaps] = value + return self + + def getEstimatorParamMaps(self): + """ + Gets the value of estimatorParamMaps or its default value. + """ + return self.getOrDefault(self.estimatorParamMaps) + + def setEvaluator(self, value): + """ + Sets the value of :py:attr:`evaluator`. + """ + self._paramMap[self.evaluator] = value + return self + + def getEvaluator(self): + """ + Gets the value of evaluator or its default value. + """ + return self.getOrDefault(self.evaluator) + + def setNumFolds(self, value): + """ + Sets the value of :py:attr:`numFolds`. + """ + self._paramMap[self.numFolds] = value + return self + + def getNumFolds(self): + """ + Gets the value of numFolds or its default value. + """ + return self.getOrDefault(self.numFolds) + + def _fit(self, dataset): + est = self.getOrDefault(self.estimator) + epm = self.getOrDefault(self.estimatorParamMaps) + numModels = len(epm) + eva = self.getOrDefault(self.evaluator) + nFolds = self.getOrDefault(self.numFolds) + h = 1.0 / nFolds + randCol = self.uid + "_rand" + df = dataset.select("*", rand(0).alias(randCol)) + metrics = np.zeros(numModels) + for i in range(nFolds): + validateLB = i * h + validateUB = (i + 1) * h + condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) + validation = df.filter(condition) + train = df.filter(~condition) + for j in range(numModels): + model = est.fit(train, epm[j]) + # TODO: duplicate evaluator to take extra params from input + metric = eva.evaluate(model.transform(validation, epm[j])) + metrics[j] += metric + bestIndex = np.argmax(metrics) + bestModel = est.fit(dataset, epm[bestIndex]) + return CrossValidatorModel(bestModel) + + def copy(self, extra=None): + if extra is None: + extra = dict() + newCV = Params.copy(self, extra) + if self.isSet(self.estimator): + newCV.setEstimator(self.getEstimator().copy(extra)) + # estimatorParamMaps remain the same + if self.isSet(self.evaluator): + newCV.setEvaluator(self.getEvaluator().copy(extra)) + return newCV + + +class CrossValidatorModel(Model): + """ + Model from k-fold cross validation. + """ + + def __init__(self, bestModel): + super(CrossValidatorModel, self).__init__() + #: best model from cross validation + self.bestModel = bestModel + + def _transform(self, dataset): + return self.bestModel.transform(dataset) + + def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies the underlying bestModel, + creates a deep copy of the embedded paramMap, and + copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() + return CrossValidatorModel(self.bestModel.copy(extra)) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.tuning tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index b1caa84b6306..cee9d67b0532 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,21 +15,22 @@ # limitations under the License. # +from functools import wraps import uuid -def inherit_doc(cls): - for name, func in vars(cls).items(): - # only inherit docstring for public functions - if name.startswith("_"): - continue - if not func.__doc__: - for parent in cls.__bases__: - parent_func = getattr(parent, name, None) - if parent_func and getattr(parent_func, "__doc__", None): - func.__doc__ = parent_func.__doc__ - break - return cls +def keyword_only(func): + """ + A decorator that forces keyword arguments in the wrapped method + and saves actual input keyword arguments in `_input_kwargs`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise TypeError("Method %s forces keyword arguments." % func.__name__) + wrapper._input_kwargs = kwargs + return func(*args, **kwargs) + return wrapper class Identifiable(object): @@ -38,9 +39,16 @@ class Identifiable(object): """ def __init__(self): - #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. - self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + #: A unique id for the object. + self.uid = self._randomUID() def __repr__(self): return self.uid + + @classmethod + def _randomUID(cls): + """ + Generate a unique id for the object. The default implementation + concatenates the class name, "_", and 12 random hex chars. + """ + return cls.__name__ + "_" + uuid.uuid4().hex[12:] diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 9e12ddc3d9b8..253705bde913 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -20,8 +20,8 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer -from pyspark.ml.util import inherit_doc +from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -45,43 +45,61 @@ class JavaWrapper(Params): __metaclass__ = ABCMeta - #: Fully-qualified class name of the wrapped Java component. - _java_class = None + #: The wrapped Java companion object. Subclasses should initialize + #: it properly. The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + _java_obj = None - def _java_obj(self): + @staticmethod + def _new_java_obj(java_class, *args): """ - Returns or creates a Java object. + Construct a new Java object. """ + sc = SparkContext._active_spark_context java_obj = _jvm() - for name in self._java_class.split("."): + for name in java_class.split("."): java_obj = getattr(java_obj, name) - return java_obj() + java_args = [_py2java(sc, arg) for arg in args] + return java_obj(*java_args) - def _transfer_params_to_java(self, params, java_obj): + def _make_java_param_pair(self, param, value): """ - Transforms the embedded params and additional params to the - input Java object. - :param params: additional params (overwriting embedded values) - :param java_obj: Java object to receive the params + Makes a Java parm pair. + """ + sc = SparkContext._active_spark_context + param = self._resolveParam(param) + java_param = self._java_obj.getParam(param.name) + java_value = _py2java(sc, value) + return java_param.w(java_value) + + def _transfer_params_to_java(self): + """ + Transforms the embedded params to the companion Java object. """ - paramMap = self._merge_params(params) + paramMap = self.extractParamMap() for param in self.params: if param in paramMap: - java_obj.set(param.name, paramMap[param]) + pair = self._make_java_param_pair(param, paramMap[param]) + self._java_obj.set(pair) + + def _transfer_params_from_java(self): + """ + Transforms the embedded params from the companion Java object. + """ + sc = SparkContext._active_spark_context + for param in self.params: + if self._java_obj.hasParam(param.name): + java_param = self._java_obj.getParam(param.name) + value = _java2py(sc, self._java_obj.getOrDefault(java_param)) + self._paramMap[param] = value - def _empty_java_param_map(self): + @staticmethod + def _empty_java_param_map(): """ Returns an empty Java ParamMap reference. """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _create_java_param_map(self, params, java_obj): - paramMap = self._empty_java_param_map() - for param, value in params.items(): - if param.parent is self: - paramMap.put(java_obj.getParam(param.name), value) - return paramMap - @inherit_doc class JavaEstimator(Estimator, JavaWrapper): @@ -96,22 +114,21 @@ def _create_model(self, java_model): """ Creates a model from the input Java model reference. """ - return JavaModel(java_model) + raise NotImplementedError() - def _fit_java(self, dataset, params={}): + def _fit_java(self, dataset): """ Fits a Java model to the input dataset. :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` + :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) :return: fitted Java model """ - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + self._transfer_params_to_java() + return self._java_obj.fit(dataset._jdf) - def fit(self, dataset, params={}): - java_model = self._fit_java(dataset, params) + def _fit(self, dataset): + java_model = self._fit_java(dataset) return self._create_model(java_model) @@ -124,26 +141,49 @@ class JavaTransformer(Transformer, JavaWrapper): __metaclass__ = ABCMeta - def transform(self, dataset, params={}): - java_obj = self._java_obj() - self._transfer_params_to_java({}, java_obj) - java_param_map = self._create_java_param_map(params, java_obj) - return DataFrame(java_obj.transform(dataset._jdf, java_param_map), - dataset.sql_ctx) + def _transform(self, dataset): + self._transfer_params_to_java() + return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) @inherit_doc -class JavaModel(JavaTransformer): +class JavaModel(Model, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala - implementations. + implementations. Subclasses should inherit this class before + param mix-ins, because this sets the UID from the Java model. """ __metaclass__ = ABCMeta def __init__(self, java_model): - super(JavaTransformer, self).__init__() - self._java_model = java_model + """ + Initialize this instance with a Java model object. + Subclasses should call this constructor, initialize params, + and then call _transformer_params_from_java. + """ + super(JavaModel, self).__init__() + self._java_obj = java_model + self.uid = java_model.uid() - def _java_obj(self): - return self._java_model + def copy(self, extra=None): + """ + Creates a copy of this instance with the same uid and some + extra params. This implementation first calls Params.copy and + then make a copy of the companion Java model with extra params. + So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() + that = super(JavaModel, self).copy(extra) + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() + return that + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index c3217620e3c4..acba3a717d21 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -18,18 +18,15 @@ """ Python bindings for MLlib. """ +from __future__ import absolute_import -# MLlib currently needs and NumPy 1.4+, so complain if lower +# MLlib currently needs NumPy 1.4+, so complain if lower import numpy -if numpy.version.version < '1.4': + +ver = [int(x) for x in numpy.version.version.split('.')[:2]] +if ver < [1, 4]: raise Exception("MLlib requires NumPy 1.4+") -__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', +__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -import rand as random -random.__name__ = 'random' -random.RandomRDDs.__module__ = __name__ + '.random' -sys.modules[__name__ + '.random'] = random diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 00e2e76711e8..8f27c446a66e 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,59 +21,91 @@ from numpy import array from pyspark import RDD -from pyspark.mllib.common import callMLlibFunc -from pyspark.mllib.linalg import SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper +from pyspark.streaming import DStream +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.regression import ( + LabeledPoint, LinearModel, _regression_train_wrapper, + StreamingLinearAlgorithm) +from pyspark.mllib.util import Saveable, Loader, inherit_doc __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', - 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes', + 'StreamingLogisticRegressionWithSGD'] -class LinearBinaryClassificationModel(LinearModel): +class LinearClassificationModel(LinearModel): """ - Represents a linear binary classification model that predicts to whether an - example is positive (1.0) or negative (0.0). + A private abstract class representing a multiclass classification + model. The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): - super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + super(LinearClassificationModel, self).__init__(weights, intercept) self._threshold = None def setThreshold(self, value): """ .. note:: Experimental - Sets the threshold that separates positive predictions from negative - predictions. An example with prediction score greater than or equal - to this threshold is identified as an positive, and negative otherwise. + Sets the threshold that separates positive predictions from + negative predictions. An example with prediction score greater + than or equal to this threshold is identified as an positive, + and negative otherwise. It is used for binary classification + only. """ self._threshold = value + @property + def threshold(self): + """ + .. note:: Experimental + + Returns the threshold (if any) used for converting raw + prediction scores into 0/1 predictions. It is used for + binary classification only. + """ + return self._threshold + def clearThreshold(self): """ .. note:: Experimental - Clears the threshold so that `predict` will output raw prediction scores. + Clears the threshold so that `predict` will output raw + prediction scores. It is used for binary classification only. """ self._threshold = None def predict(self, test): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ raise NotImplementedError -class LogisticRegressionModel(LinearBinaryClassificationModel): +class LogisticRegressionModel(LinearClassificationModel): - """A linear binary classification model derived from logistic regression. + """ + Classification model trained using Multinomial/Binary Logistic + Regression. + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. (Only used + in Binary Logistic Regression. In Multinomial Logistic + Regression, the intercepts will not be a single value, + so the intercepts will be part of the weights.) + :param numFeatures: the dimension of the features. + :param numClasses: the number of possible outcomes for k classes + classification problem in Multinomial Logistic Regression. + By default, it is binary logistic regression so numClasses + will be set to 2. >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -82,7 +114,7 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): [1, 0] >>> lrm.clearThreshold() >>> lrm.predict([0.0, 1.0]) - 0.123... + 0.279... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), @@ -90,7 +122,7 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> lrm.predict(array([0.0, 1.0])) 1 >>> lrm.predict(array([1.0, 0.0])) @@ -99,50 +131,133 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): 1 >>> lrm.predict(SparseVector(2, {0: 1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LogisticRegressionModel.load(sc, path) + >>> sameModel.predict(array([0.0, 1.0])) + 1 + >>> sameModel.predict(SparseVector(2, {0: 1.0})) + 0 + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except: + ... pass + >>> multi_class_data = [ + ... LabeledPoint(0.0, [0.0, 1.0, 0.0]), + ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), + ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) + ... ] + >>> data = sc.parallelize(multi_class_data) + >>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3) + >>> mcm.predict([0.0, 0.5, 0.0]) + 0 + >>> mcm.predict([0.8, 0.0, 0.0]) + 1 + >>> mcm.predict([0.0, 0.0, 0.3]) + 2 """ - def __init__(self, weights, intercept): + def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) + self._numFeatures = int(numFeatures) + self._numClasses = int(numClasses) self._threshold = 0.5 + if self._numClasses == 2: + self._dataWithBiasSize = None + self._weightsMatrix = None + else: + self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1) + self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1, + self._dataWithBiasSize) + + @property + def numFeatures(self): + return self._numFeatures + + @property + def numClasses(self): + return self._numClasses def predict(self, x): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - margin = self.weights.dot(x) + self._intercept - if margin > 0: - prob = 1 / (1 + exp(-margin)) + if self.numClasses == 2: + margin = self.weights.dot(x) + self._intercept + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 else: - exp_margin = exp(margin) - prob = exp_margin / (1 + exp_margin) - if self._threshold is None: - return prob - else: - return 1 if prob > self._threshold else 0 + best_class = 0 + max_margin = 0.0 + if x.size + 1 == self._dataWithBiasSize: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i][0:x.size]) + \ + self._weightsMatrix[i][x.size] + if margin > max_margin: + max_margin = margin + best_class = i + 1 + else: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i]) + if margin > max_margin: + max_margin = margin + best_class = i + 1 + return best_class + + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( + _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + numFeatures = java_model.numFeatures() + numClasses = java_model.numClasses() + threshold = java_model.getThreshold().get() + model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses) + model.setThreshold(threshold) + return model class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.01, regType="l2", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False, + validateData=True): """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.01). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.01). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -151,15 +266,19 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -168,16 +287,19 @@ class LogisticRegressionWithLBFGS(object): @classmethod def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4): + intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.01). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.01). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -186,20 +308,27 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param corrections: The number of corrections used in the LBFGS - update (default: 10). - :param tolerance: The convergence tolerance of iterations for - L-BFGS (default: 1e-4). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param corrections: The number of corrections used in the + LBFGS update (default: 10). + :param tolerance: The convergence tolerance of iterations + for L-BFGS (default: 1e-4). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before + training. (default: True) + :param numClasses: The number of classes (i.e., outcomes) a + label can take in Multinomial Logistic + Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -207,15 +336,27 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, - float(regParam), str(regType), bool(intercept), int(corrections), - float(tolerance)) - + float(regParam), regType, bool(intercept), int(corrections), + float(tolerance), bool(validateData), int(numClasses)) + + if initialWeights is None: + if numClasses == 2: + initialWeights = [0.0] * len(data.first().features) + else: + if intercept: + initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1) + else: + initialWeights = [0.0] * len(data.first().features) * (numClasses - 1) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearBinaryClassificationModel): +class SVMModel(LinearClassificationModel): - """A support vector machine. + """ + Model for Support Vector Machines (SVMs). + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -223,14 +364,14 @@ class SVMModel(LinearBinaryClassificationModel): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(data)) + >>> svm = SVMWithSGD.train(sc.parallelize(data), iterations=10) >>> svm.predict([1.0]) 1 >>> svm.predict(sc.parallelize([[1.0]])).collect() [1] >>> svm.clearThreshold() >>> svm.predict(array([1.0])) - 1.25... + 1.44... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: -1.0})), @@ -238,11 +379,24 @@ class SVMModel(LinearBinaryClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data)) + >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> svm.predict(SparseVector(2, {1: 1.0})) 1 >>> svm.predict(SparseVector(2, {0: -1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> svm.save(sc, path) + >>> sameModel = SVMModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {1: 1.0})) + 1 + >>> sameModel.predict(SparseVector(2, {0: -1.0})) + 0 + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except: + ... pass """ def __init__(self, weights, intercept): super(SVMModel, self).__init__(weights, intercept) @@ -250,8 +404,8 @@ def __init__(self, weights, intercept): def predict(self, x): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -263,25 +417,45 @@ def predict(self, x): else: return 1 if margin > self._threshold else 0 + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + threshold = java_model.getThreshold().get() + model = SVMModel(weights, intercept) + model.setThreshold(threshold) + return model + class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): + miniBatchFraction=1.0, initialWeights=None, regType="l2", + intercept=False, validateData=True): """ Train a support vector machine on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param regParam: The regularizer parameter (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regType: The type of regularizer used for training - our model. + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -290,27 +464,34 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) -class NaiveBayesModel(object): +@inherit_doc +class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. - Contains two parameters: - - pi: vector of logs of class priors (dimension C) - - theta: matrix of logs of class conditional probabilities (CxD) + :param labels: list of labels. + :param pi: log of class priors, whose dimension is C, + number of labels. + :param theta: log of class conditional probabilities, whose + dimension is C-by-D, where D is number of features. >>> data = [ ... LabeledPoint(0.0, [0.0, 0.0]), @@ -334,6 +515,17 @@ class NaiveBayesModel(object): 0.0 >>> model.predict(SparseVector(2, {0: 1.0})) 1.0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = NaiveBayesModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self, labels, pi, theta): @@ -342,36 +534,109 @@ def __init__(self, labels, pi, theta): self.theta = theta def predict(self, x): - """Return the most likely class for a data vector or an RDD of vectors""" + """ + Return the most likely class for a data vector + or an RDD of vectors + """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] + def save(self, sc, path): + java_labels = _py2java(sc, self.labels.tolist()) + java_pi = _py2java(sc, self.pi.tolist()) + java_theta = _py2java(sc, self.theta.tolist()) + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel( + java_labels, java_pi, java_theta) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( + sc._jsc.sc(), path) + # Can not unpickle array.array from Pyrolite in Python3 with "bytes" + py_labels = _java2py(sc, java_model.labels(), "latin1") + py_pi = _java2py(sc, java_model.pi(), "latin1") + py_theta = _java2py(sc, java_model.theta(), "latin1") + return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta)) + class NaiveBayes(object): @classmethod def train(cls, data, lambda_=1.0): """ - Train a Naive Bayes model given an RDD of (label, features) vectors. + Train a Naive Bayes model given an RDD of (label, features) + vectors. - This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which can - handle all kinds of discrete data. For example, by converting - documents into TF-IDF vectors, it can be used for document - classification. By making every vector a 0-1 vector, it can also be - used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). + This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which + can handle all kinds of discrete data. For example, by + converting documents into TF-IDF vectors, it can be used for + document classification. By making every vector a 0-1 vector, + it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). + The input feature values must be nonnegative. :param data: RDD of LabeledPoint. - :param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter (default: 1.0). """ first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("`data` should be an RDD of LabeledPoint") - labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) + labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) +@inherit_doc +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LogisticRegression with SGD on a batch of data. + + The weights obtained at the end of training a stream are used as initial + weights for the next batch. + + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Number of iterations run for each batch of data. + :param miniBatchFraction: Fraction of data on which SGD is run for each + iteration. + :param regParam: L2 Regularization parameter. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + self.stepSize = stepSize + self.numIterations = numIterations + self.regParam = regParam + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLogisticRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn. + """ + initialWeights = _convert_to_vector(initialWeights) + + # LogisticRegressionWithSGD does only binary classification. + self._model = LogisticRegressionModel( + initialWeights, 0, initialWeights.size, 2) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LogisticRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LogisticRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index f6b97abb1723..900ade248c38 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,28 +15,51 @@ # limitations under the License. # -from numpy import array +import sys +import array as pyarray + +if sys.version > '3': + xrange = range + basestring = str + +from math import exp, log + +from numpy import array, random, tile + +from collections import namedtuple -from pyspark import RDD from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc -from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector +from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector +from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable +from pyspark.streaming import DStream -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', + 'PowerIterationClusteringModel', 'PowerIterationClustering', + 'StreamingKMeans', 'StreamingKMeansModel', + 'LDA', 'LDAModel'] -class KMeansModel(object): +@inherit_doc +class KMeansModel(Saveable, Loader): """A clustering model derived from the k-means method. >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) >>> model = KMeans.train( - ... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") + ... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random", + ... seed=50, initializationSteps=5, epsilon=1e-4) >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0])) True >>> model.predict(array([8.0, 9.0])) == model.predict(array([9.0, 8.0])) True + >>> model.k + 2 + >>> model.computeCost(sc.parallelize(data)) + 2.0000000000000004 >>> model = KMeans.train(sc.parallelize(data), 2) >>> sparse_data = [ ... SparseVector(3, {1: 1.0}), @@ -44,7 +67,8 @@ class KMeansModel(object): ... SparseVector(3, {2: 1.0}), ... SparseVector(3, {2: 1.1}) ... ] - >>> model = KMeans.train(sc.parallelize(sparse_data), 2, initializationMode="k-means||") + >>> model = KMeans.train(sc.parallelize(sparse_data), 2, initializationMode="k-means||", + ... seed=50, initializationSteps=5, epsilon=1e-4) >>> model.predict(array([0., 1., 0.])) == model.predict(array([0, 1.1, 0.])) True >>> model.predict(array([0., 0., 1.])) == model.predict(array([0, 0, 1.1])) @@ -53,8 +77,19 @@ class KMeansModel(object): True >>> model.predict(sparse_data[2]) == model.predict(sparse_data[3]) True - >>> type(model.clusterCenters) - + >>> isinstance(model.clusterCenters, list) + True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = KMeansModel.load(sc, path) + >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0]) + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self, centers): @@ -65,10 +100,18 @@ def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return self.centers + @property + def k(self): + """Total number of clusters.""" + return len(self.centers) + def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 best_distance = float("inf") + if isinstance(x, RDD): + return x.map(self.predict) + x = _convert_to_vector(x) for i in xrange(len(self.centers)): distance = x.squared_distance(self.centers[i]) @@ -77,21 +120,50 @@ def predict(self, x): best_distance = distance return best + def computeCost(self, rdd): + """ + Return the K-means cost (sum of squared distances of points to + their nearest center) for this model on the given data. + """ + cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), + [_convert_to_vector(c) for c in self.centers]) + return cost + + def save(self, sc, path): + java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers]) + java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path) + return KMeansModel(_java2py(sc, java_model.clusterCenters())) + class KMeans(object): @classmethod - def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", + seed=None, initializationSteps=5, epsilon=1e-4): """Train a k-means clustering model.""" model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, initializationMode, seed) + runs, initializationMode, seed, initializationSteps, epsilon) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) -class GaussianMixtureModel(object): +@inherit_doc +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental - """A clustering model derived from the Gaussian Mixture Model method. + A clustering model derived from the Gaussian Mixture Model method. + + >>> from pyspark.mllib.linalg import Vectors, DenseMatrix + >>> from numpy.testing import assert_equal + >>> from shutil import rmtree + >>> import os, tempfile >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, @@ -105,13 +177,33 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True - >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220, - ... -5.2211, -5.0602, 4.7118, - ... 6.8989, 3.4592, 4.6322, - ... 5.7048, 4.6567, 5.5026, - ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) + + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = GaussianMixtureModel.load(sc, path) + >>> assert_equal(model.weights, sameModel.weights) + >>> mus, sigmas = list( + ... zip(*[(g.mu, g.sigma) for g in model.gaussians])) + >>> sameMus, sameSigmas = list( + ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians])) + >>> mus == sameMus + True + >>> sigmas == sameSigmas + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + + >>> data = array([-5.1971, -2.5359, -3.8220, + ... -5.2211, -5.0602, 4.7118, + ... 6.8989, 3.4592, 4.6322, + ... 5.7048, 4.6567, 5.5026, + ... 4.5605, 5.2043, 6.2734]) + >>> clusterdata_2 = sc.parallelize(data.reshape(5,3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, - ... maxIterations=150, seed=10) + ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() >>> labels[0]==labels[1]==labels[2] True @@ -119,10 +211,28 @@ class GaussianMixtureModel(object): True """ - def __init__(self, weights, gaussians): - self.weights = weights - self.gaussians = gaussians - self.k = len(self.weights) + @property + def weights(self): + """ + Weights for each Gaussian distribution in the mixture, where weights[i] is + the weight for Gaussian i, and weights.sum == 1. + """ + return array(self.call("weights")) + + @property + def gaussians(self): + """ + Array of MultivariateGaussian where gaussians[i] represents + the Multivariate Gaussian (Normal) Distribution for Gaussian i. + """ + return [ + MultivariateGaussian(gaussian[0], gaussian[1]) + for gaussian in zip(*self.call("gaussians"))] + + @property + def k(self): + """Number of gaussians in mixture.""" + return len(self.weights) def predict(self, x): """ @@ -135,6 +245,9 @@ def predict(self, x): if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) def predictSoft(self, x): """ @@ -146,33 +259,450 @@ def predictSoft(self, x): if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - self.weights, means, sigmas) - return membership_matrix + _convert_to_vector(self.weights), means, sigmas) + return membership_matrix.map(lambda x: pyarray.array('d', x)) + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) + + @classmethod + def load(cls, sc, path): + """Load the GaussianMixtureModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.GaussianMixtureModelWrapper(model) + return cls(wrapper) class GaussianMixture(object): """ - Estimate model parameters with the expectation-maximization algorithm. + .. note:: Experimental + + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. :param data: RDD of data points :param k: Number of components :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 :param maxIterations: Number of iterations. Default to 100 :param seed: Random Seed + :param initialModel: GaussianMixtureModel for initializing learning """ @classmethod - def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None): + def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): """Train a Gaussian Mixture clustering model.""" - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", - rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed) - mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] - return GaussianMixtureModel(weight, mvg_obj) + initialModelWeights = None + initialModelMu = None + initialModelSigma = None + if initialModel is not None: + if initialModel.k != k: + raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" + % (initialModel.k, k)) + initialModelWeights = initialModel.weights + initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] + initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] + java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) + return GaussianMixtureModel(java_model) + + +class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental + + Model produced by [[PowerIterationClustering]]. + + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0), + ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), + ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0), + ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)] + >>> rdd = sc.parallelize(data, 2) + >>> model = PowerIterationClustering.train(rdd, 2, 100) + >>> model.k + 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = PowerIterationClusteringModel.load(sc, path) + >>> sameModel.k + 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + """ + + @property + def k(self): + """ + Returns the number of clusters. + """ + return self.call("k") + + def assignments(self): + """ + Returns the cluster assignments of this model. + """ + return self.call("getAssignments").map( + lambda x: (PowerIterationClustering.Assignment(*x))) + + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) + return PowerIterationClusteringModel(wrapper) + + +class PowerIterationClustering(object): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm + developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. + From the abstract: PIC finds a very low-dimensional embedding of a + dataset using truncated power iteration on a normalized pair-wise + similarity matrix of the data. + """ + + @classmethod + def train(cls, rdd, k, maxIterations=100, initMode="random"): + """ + :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. + The similarity s,,ij,, must be nonnegative. + This is a symmetric matrix and hence s,,ij,, = s,,ji,,. + For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. + Tuples with i = j are ignored, because we assume + s,,ij,, = 0.0. + :param k: Number of clusters. + :param maxIterations: Maximum number of iterations of the + PIC algorithm. + :param initMode: Initialization mode. + """ + model = callMLlibFunc("trainPowerIterationClusteringModel", + rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) + return PowerIterationClusteringModel(model) + + class Assignment(namedtuple("Assignment", ["id", "cluster"])): + """ + Represents an (id, cluster) tuple. + """ + + +class StreamingKMeansModel(KMeansModel): + """ + .. note:: Experimental + + Clustering model which can perform an online update of the centroids. + + The update formula for each centroid is given by + + * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t) + * n_t+1 = n_t * a + m_t + + where + + * c_t: Centroid at the n_th iteration. + * n_t: Number of samples (or) weights associated with the centroid + at the n_th iteration. + * x_t: Centroid of the new data closest to c_t. + * m_t: Number of samples (or) weights of the new data closest to c_t + * c_t+1: New centroid. + * n_t+1: New number of weights. + * a: Decay Factor, which gives the forgetfulness. + + Note that if a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. + + :param clusterCenters: Initial cluster centers. + :param clusterWeights: List of weights assigned to each cluster. + + >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] + >>> initWeights = [1.0, 1.0] + >>> stkm = StreamingKMeansModel(initCenters, initWeights) + >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1], + ... [0.9, 0.9], [1.1, 1.1]]) + >>> stkm = stkm.update(data, 1.0, u"batches") + >>> stkm.centers + array([[ 0., 0.], + [ 1., 1.]]) + >>> stkm.predict([-0.1, -0.1]) + 0 + >>> stkm.predict([0.9, 0.9]) + 1 + >>> stkm.clusterWeights + [3.0, 3.0] + >>> decayFactor = 0.0 + >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])]) + >>> stkm = stkm.update(data, 0.0, u"batches") + >>> stkm.centers + array([[ 0.2, 0.2], + [ 1.5, 1.5]]) + >>> stkm.clusterWeights + [1.0, 1.0] + >>> stkm.predict([0.2, 0.2]) + 0 + >>> stkm.predict([1.5, 1.5]) + 1 + """ + def __init__(self, clusterCenters, clusterWeights): + super(StreamingKMeansModel, self).__init__(centers=clusterCenters) + self._clusterWeights = list(clusterWeights) + + @property + def clusterWeights(self): + """Return the cluster weights.""" + return self._clusterWeights + + @ignore_unicode_prefix + def update(self, data, decayFactor, timeUnit): + """Update the centroids, according to data + + :param data: Should be a RDD that represents the new data. + :param decayFactor: forgetfulness of the previous centroids. + :param timeUnit: Can be "batches" or "points". If points, then the + decay factor is raised to the power of number of new + points and if batches, it is used as it is. + """ + if not isinstance(data, RDD): + raise TypeError("Data should be of an RDD, got %s." % type(data)) + data = data.map(_convert_to_vector) + decayFactor = float(decayFactor) + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + vectorCenters = [_convert_to_vector(center) for center in self.centers] + updatedModel = callMLlibFunc( + "updateStreamingKMeansModel", vectorCenters, self._clusterWeights, + data, decayFactor, timeUnit) + self.centers = array(updatedModel[0]) + self._clusterWeights = list(updatedModel[1]) + return self + + +class StreamingKMeans(object): + """ + .. note:: Experimental + + Provides methods to set k, decayFactor, timeUnit to configure the + KMeans algorithm for fitting and predicting on incoming dstreams. + More details on how the centroids are updated are provided under the + docs of StreamingKMeansModel. + + :param k: int, number of clusters + :param decayFactor: float, forgetfulness of the previous centroids. + :param timeUnit: can be "batches" or "points". If points, then the + decayfactor is raised to the power of no. of new points. + """ + def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): + self._k = k + self._decayFactor = decayFactor + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + self._timeUnit = timeUnit + self._model = None + + def latestModel(self): + """Return the latest model""" + return self._model + + def _validate(self, dstream): + if self._model is None: + raise ValueError( + "Initial centers should be set either by setInitialCenters " + "or setRandomCenters.") + if not isinstance(dstream, DStream): + raise TypeError( + "Expected dstream to be of type DStream, " + "got type %s" % type(dstream)) + + def setK(self, k): + """Set number of clusters.""" + self._k = k + return self + + def setDecayFactor(self, decayFactor): + """Set decay factor.""" + self._decayFactor = decayFactor + return self + + def setHalfLife(self, halfLife, timeUnit): + """ + Set number of batches after which the centroids of that + particular batch has half the weightage. + """ + self._timeUnit = timeUnit + self._decayFactor = exp(log(0.5) / halfLife) + return self + + def setInitialCenters(self, centers, weights): + """ + Set initial centers. Should be set before calling trainOn. + """ + self._model = StreamingKMeansModel(centers, weights) + return self + + def setRandomCenters(self, dim, weight, seed): + """ + Set the initial centres to be random samples from + a gaussian population with constant weights. + """ + rng = random.RandomState(seed) + clusterCenters = rng.randn(self._k, dim) + clusterWeights = tile(weight, self._k) + self._model = StreamingKMeansModel(clusterCenters, clusterWeights) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + self._model.update(rdd, self._decayFactor, self._timeUnit) + + dstream.foreachRDD(update) + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + Returns a transformed dstream object + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + Returns a transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +class LDAModel(JavaModelWrapper): + + """ A clustering model derived from the LDA method. + + Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + Terminology + - "word" = "term": an element of the vocabulary + - "token": instance of a term appearing in a document + - "topic": multinomial distribution over words representing some concept + References: + - Original LDA paper (journal version): + Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + + >>> from pyspark.mllib.linalg import Vectors + >>> from numpy.testing import assert_almost_equal, assert_equal + >>> data = [ + ... [1, Vectors.dense([0.0, 1.0])], + ... [2, SparseVector(2, {0: 1.0})], + ... ] + >>> rdd = sc.parallelize(data) + >>> model = LDA.train(rdd, k=2) + >>> model.vocabSize() + 2 + >>> topics = model.topicsMatrix() + >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) + >>> assert_almost_equal(topics, topics_expect, 1) + + >>> import os, tempfile + >>> from shutil import rmtree + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = LDAModel.load(sc, path) + >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix()) + >>> sameModel.vocabSize() == model.vocabSize() + True + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + """ + + def topicsMatrix(self): + """Inferred topics, where each topic is represented by a distribution over terms.""" + return self.call("topicsMatrix").toArray() + + def vocabSize(self): + """Vocabulary size (number of terms or terms in the vocabulary)""" + return self.call("vocabSize") + + def save(self, sc, path): + """Save the LDAModel on to disk. + + :param sc: SparkContext + :param path: str, path to where the model needs to be stored. + """ + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + """Load the LDAModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( + sc._jsc.sc(), path) + return cls(java_model) + + +class LDA(object): + + @classmethod + def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, + topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): + """Train a LDA model. + + :param rdd: RDD of data points + :param k: Number of clusters you want + :param maxIterations: Number of iterations. Default to 20 + :param docConcentration: Concentration parameter (commonly named "alpha") + for the prior placed on documents' distributions over topics ("theta"). + :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") + for the prior placed on topics' distributions over terms. + :param seed: Random Seed + :param checkpointInterval: Period (in iterations) between checkpoints. + :param optimizer: LDAOptimizer used to perform the actual calculation. + Currently "em", "online" are supported. Default to "em". + """ + model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, + docConcentration, topicConcentration, seed, + checkpointInterval, optimizer) + return LDAModel(model) def _test(): import doctest - globs = globals().copy() + import pyspark.mllib.clustering + globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 3c5ee66cd8b6..a439a488de5c 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -15,6 +15,11 @@ # limitations under the License. # +import sys +if sys.version >= '3': + long = int + unicode = str + import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject @@ -22,7 +27,7 @@ from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - +from pyspark.sql import DataFrame, SQLContext # Hack for support float('inf') in Py4j _old_smart_decode = py4j.protocol.smart_decode @@ -36,7 +41,7 @@ def _new_smart_decode(obj): if isinstance(obj, float): - s = unicode(obj) + s = str(obj) return _float_str_mapping.get(s, s) return _old_smart_decode(obj) @@ -68,21 +73,23 @@ def _py2java(sc, obj): """ Convert Python object into Java """ if isinstance(obj, RDD): obj = _to_java_object_rdd(obj) + elif isinstance(obj, DataFrame): + obj = obj._jdf elif isinstance(obj, SparkContext): obj = obj._jsc - elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)): - obj = ListConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, list): + obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass - elif isinstance(obj, (int, long, float, bool, basestring)): + elif isinstance(obj, (int, long, float, bool, bytes, unicode)): pass else: - bytes = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.SerDe.loads(bytes) + data = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(data) return obj -def _java2py(sc, r): +def _java2py(sc, r, encoding="bytes"): if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD @@ -94,6 +101,9 @@ def _java2py(sc, r): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) + if clsName == 'DataFrame': + return DataFrame(r, SQLContext(sc)) + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): @@ -102,8 +112,8 @@ def _java2py(sc, r): except Py4JJavaError: pass # not pickable - if isinstance(r, bytearray): - r = PickleSerializer().loads(str(r)) + if isinstance(r, (bytearray, bytes)): + r = PickleSerializer().loads(bytes(r), encoding=encoding) return r @@ -134,3 +144,20 @@ def __del__(self): def call(self, name, *a): """Call method of java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) + + +def inherit_doc(cls): + """ + A decorator that makes a class inherit documentation from its parents. + """ + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py new file mode 100644 index 000000000000..4398ca86f2ec --- /dev/null +++ b/python/pyspark/mllib/evaluation.py @@ -0,0 +1,489 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.sql import SQLContext +from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType + +__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics', + 'MulticlassMetrics', 'RankingMetrics'] + + +class BinaryClassificationMetrics(JavaModelWrapper): + """ + Evaluator for binary classification. + + :param scoreAndLabels: an RDD of (score, label) pairs + + >>> scoreAndLabels = sc.parallelize([ + ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) + >>> metrics = BinaryClassificationMetrics(scoreAndLabels) + >>> metrics.areaUnderROC + 0.70... + >>> metrics.areaUnderPR + 0.83... + >>> metrics.unpersist() + """ + + def __init__(self, scoreAndLabels): + sc = scoreAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ + StructField("score", DoubleType(), nullable=False), + StructField("label", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics + java_model = java_class(df._jdf) + super(BinaryClassificationMetrics, self).__init__(java_model) + + @property + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + """ + return self.call("areaUnderROC") + + @property + def areaUnderPR(self): + """ + Computes the area under the precision-recall curve. + """ + return self.call("areaUnderPR") + + def unpersist(self): + """ + Unpersists intermediate RDDs used in the computation. + """ + self.call("unpersist") + + +class RegressionMetrics(JavaModelWrapper): + """ + Evaluator for regression. + + :param predictionAndObservations: an RDD of (prediction, + observation) pairs. + + >>> predictionAndObservations = sc.parallelize([ + ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) + >>> metrics = RegressionMetrics(predictionAndObservations) + >>> metrics.explainedVariance + 8.859... + >>> metrics.meanAbsoluteError + 0.5... + >>> metrics.meanSquaredError + 0.37... + >>> metrics.rootMeanSquaredError + 0.61... + >>> metrics.r2 + 0.94... + """ + + def __init__(self, predictionAndObservations): + sc = predictionAndObservations.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ + StructField("prediction", DoubleType(), nullable=False), + StructField("observation", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics + java_model = java_class(df._jdf) + super(RegressionMetrics, self).__init__(java_model) + + @property + def explainedVariance(self): + """ + Returns the explained variance regression score. + explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + """ + return self.call("explainedVariance") + + @property + def meanAbsoluteError(self): + """ + Returns the mean absolute error, which is a risk function corresponding to the + expected value of the absolute error loss or l1-norm loss. + """ + return self.call("meanAbsoluteError") + + @property + def meanSquaredError(self): + """ + Returns the mean squared error, which is a risk function corresponding to the + expected value of the squared error loss or quadratic loss. + """ + return self.call("meanSquaredError") + + @property + def rootMeanSquaredError(self): + """ + Returns the root mean squared error, which is defined as the square root of + the mean squared error. + """ + return self.call("rootMeanSquaredError") + + @property + def r2(self): + """ + Returns R^2^, the coefficient of determination. + """ + return self.call("r2") + + +class MulticlassMetrics(JavaModelWrapper): + """ + Evaluator for multiclass classification. + + :param predictionAndLabels an RDD of (prediction, label) pairs. + + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) + >>> metrics = MulticlassMetrics(predictionAndLabels) + >>> metrics.confusionMatrix().toArray() + array([[ 2., 1., 1.], + [ 1., 3., 0.], + [ 0., 0., 1.]]) + >>> metrics.falsePositiveRate(0.0) + 0.2... + >>> metrics.precision(1.0) + 0.75... + >>> metrics.recall(2.0) + 1.0... + >>> metrics.fMeasure(0.0, 2.0) + 0.52... + >>> metrics.precision() + 0.66... + >>> metrics.recall() + 0.66... + >>> metrics.weightedFalsePositiveRate + 0.19... + >>> metrics.weightedPrecision + 0.68... + >>> metrics.weightedRecall + 0.66... + >>> metrics.weightedFMeasure() + 0.66... + >>> metrics.weightedFMeasure(2.0) + 0.65... + """ + + def __init__(self, predictionAndLabels): + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ + StructField("prediction", DoubleType(), nullable=False), + StructField("label", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics + java_model = java_class(df._jdf) + super(MulticlassMetrics, self).__init__(java_model) + + def confusionMatrix(self): + """ + Returns confusion matrix: predicted classes are in columns, + they are ordered by class label ascending, as in "labels". + """ + return self.call("confusionMatrix") + + def truePositiveRate(self, label): + """ + Returns true positive rate for a given label (category). + """ + return self.call("truePositiveRate", label) + + def falsePositiveRate(self, label): + """ + Returns false positive rate for a given label (category). + """ + return self.call("falsePositiveRate", label) + + def precision(self, label=None): + """ + Returns precision or precision for a given label (category) if specified. + """ + if label is None: + return self.call("precision") + else: + return self.call("precision", float(label)) + + def recall(self, label=None): + """ + Returns recall or recall for a given label (category) if specified. + """ + if label is None: + return self.call("recall") + else: + return self.call("recall", float(label)) + + def fMeasure(self, label=None, beta=None): + """ + Returns f-measure or f-measure for a given label (category) if specified. + """ + if beta is None: + if label is None: + return self.call("fMeasure") + else: + return self.call("fMeasure", label) + else: + if label is None: + raise Exception("If the beta parameter is specified, label can not be none") + else: + return self.call("fMeasure", label, beta) + + @property + def weightedTruePositiveRate(self): + """ + Returns weighted true positive rate. + (equals to precision, recall and f-measure) + """ + return self.call("weightedTruePositiveRate") + + @property + def weightedFalsePositiveRate(self): + """ + Returns weighted false positive rate. + """ + return self.call("weightedFalsePositiveRate") + + @property + def weightedRecall(self): + """ + Returns weighted averaged recall. + (equals to precision, recall and f-measure) + """ + return self.call("weightedRecall") + + @property + def weightedPrecision(self): + """ + Returns weighted averaged precision. + """ + return self.call("weightedPrecision") + + def weightedFMeasure(self, beta=None): + """ + Returns weighted averaged f-measure. + """ + if beta is None: + return self.call("weightedFMeasure") + else: + return self.call("weightedFMeasure", beta) + + +class RankingMetrics(JavaModelWrapper): + """ + Evaluator for ranking algorithms. + + :param predictionAndLabels: an RDD of (predicted ranking, + ground truth set) pairs. + + >>> predictionAndLabels = sc.parallelize([ + ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), + ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), + ... ([1, 2, 3, 4, 5], [])]) + >>> metrics = RankingMetrics(predictionAndLabels) + >>> metrics.precisionAt(1) + 0.33... + >>> metrics.precisionAt(5) + 0.26... + >>> metrics.precisionAt(15) + 0.17... + >>> metrics.meanAveragePrecision + 0.35... + >>> metrics.ndcgAt(3) + 0.33... + >>> metrics.ndcgAt(10) + 0.48... + + """ + + def __init__(self, predictionAndLabels): + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, + schema=sql_ctx._inferSchema(predictionAndLabels)) + java_model = callMLlibFunc("newRankingMetrics", df._jdf) + super(RankingMetrics, self).__init__(java_model) + + def precisionAt(self, k): + """ + Compute the average precision of all the queries, truncated at ranking position k. + + If for a query, the ranking algorithm returns n (n < k) results, the precision value + will be computed as #(relevant items retrieved) / k. This formula also applies when + the size of the ground truth set is less than k. + + If a query has an empty ground truth set, zero will be used as precision together + with a log warning. + """ + return self.call("precisionAt", int(k)) + + @property + def meanAveragePrecision(self): + """ + Returns the mean average precision (MAP) of all the queries. + If a query has an empty ground truth set, the average precision will be zero and + a log warining is generated. + """ + return self.call("meanAveragePrecision") + + def ndcgAt(self, k): + """ + Compute the average NDCG value of all the queries, truncated at ranking position k. + The discounted cumulative gain at position k is computed as: + sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), + and the NDCG is obtained by dividing the DCG value on the ground truth set. + In the current implementation, the relevance value is binary. + If a query has an empty ground truth set, zero will be used as NDCG together with + a log warning. + """ + return self.call("ndcgAt", int(k)) + + +class MultilabelMetrics(JavaModelWrapper): + """ + Evaluator for multilabel classification. + + :param predictionAndLabels: an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with + unique elements. + + >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), + ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), + ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) + >>> metrics = MultilabelMetrics(predictionAndLabels) + >>> metrics.precision(0.0) + 1.0 + >>> metrics.recall(1.0) + 0.66... + >>> metrics.f1Measure(2.0) + 0.5 + >>> metrics.precision() + 0.66... + >>> metrics.recall() + 0.64... + >>> metrics.f1Measure() + 0.63... + >>> metrics.microPrecision + 0.72... + >>> metrics.microRecall + 0.66... + >>> metrics.microF1Measure + 0.69... + >>> metrics.hammingLoss + 0.33... + >>> metrics.subsetAccuracy + 0.28... + >>> metrics.accuracy + 0.54... + """ + + def __init__(self, predictionAndLabels): + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, + schema=sql_ctx._inferSchema(predictionAndLabels)) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics + java_model = java_class(df._jdf) + super(MultilabelMetrics, self).__init__(java_model) + + def precision(self, label=None): + """ + Returns precision or precision for a given label (category) if specified. + """ + if label is None: + return self.call("precision") + else: + return self.call("precision", float(label)) + + def recall(self, label=None): + """ + Returns recall or recall for a given label (category) if specified. + """ + if label is None: + return self.call("recall") + else: + return self.call("recall", float(label)) + + def f1Measure(self, label=None): + """ + Returns f1Measure or f1Measure for a given label (category) if specified. + """ + if label is None: + return self.call("f1Measure") + else: + return self.call("f1Measure", float(label)) + + @property + def microPrecision(self): + """ + Returns micro-averaged label-based precision. + (equals to micro-averaged document-based precision) + """ + return self.call("microPrecision") + + @property + def microRecall(self): + """ + Returns micro-averaged label-based recall. + (equals to micro-averaged document-based recall) + """ + return self.call("microRecall") + + @property + def microF1Measure(self): + """ + Returns micro-averaged label-based f1-measure. + (equals to micro-averaged document-based f1-measure) + """ + return self.call("microF1Measure") + + @property + def hammingLoss(self): + """ + Returns Hamming-loss. + """ + return self.call("hammingLoss") + + @property + def subsetAccuracy(self): + """ + Returns subset accuracy. + (for equal sets of labels) + """ + return self.call("subsetAccuracy") + + @property + def accuracy(self): + """ + Returns accuracy. + """ + return self.call("accuracy") + + +def _test(): + import doctest + from pyspark import SparkContext + import pyspark.mllib.evaluation + globs = pyspark.mllib.evaluation.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 10df6288065b..f921e3ad1a31 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -23,15 +23,24 @@ import sys import warnings import random +import binascii +if sys.version >= '3': + basestring = str + unicode = str from py4j.protocol import Py4JJavaError -from pyspark import RDD, SparkContext +from pyspark import SparkContext +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector +from pyspark.mllib.linalg import ( + Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', - 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] + 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', + 'ChiSqSelector', 'ChiSqSelectorModel', 'ElementwiseProduct'] class VectorTransformer(object): @@ -58,7 +67,10 @@ class Normalizer(VectorTransformer): For any 1 <= `p` < float('inf'), normalizes samples using sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm. - For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + For `p` = float('inf'), max(abs(vector)) will be used as norm for + normalization. + + :param p: Normalization in L^p^ space, p = 2 by default. >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) @@ -74,9 +86,6 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ def __init__(self, p=2.0): - """ - :param p: Normalization in L^p^ space, p = 2 by default. - """ assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) @@ -86,7 +95,7 @@ def transform(self, vector): :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" @@ -103,6 +112,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ def transform(self, vector): + """ + Applies transformation on a vector or an RDD[Vector]. + + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param vector: Vector or RDD of Vector to be transformed. + """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -120,12 +138,33 @@ def transform(self, vector): """ Applies standardization transformation on a vector. + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + :param vector: Vector or RDD of Vector to be standardized. - :return: Standardized vector. If the variance of a column is zero, - it will return default `0.0` for the column with zero variance. + :return: Standardized vector. If the variance of a column is + zero, it will return default `0.0` for the column with + zero variance. """ return JavaVectorTransformer.transform(self, vector) + def setWithMean(self, withMean): + """ + Setter of the boolean which decides + whether it uses mean or not + """ + self.call("setWithMean", withMean) + return self + + def setWithStd(self, withStd): + """ + Setter of the boolean which decides + whether it uses std or not + """ + self.call("setWithStd", withStd) + return self + class StandardScaler(object): """ @@ -135,6 +174,13 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -145,13 +191,6 @@ class StandardScaler(object): DenseVector([0.7071, -0.7071, 0.7071]) """ def __init__(self, withMean=False, withStd=True): - """ - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so this - does not work on sparse input and will raise an exception. - :param withStd: True by default. Scales the data to unit standard - deviation. - """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean @@ -159,10 +198,11 @@ def __init__(self, withMean=False, withStd=True): def fit(self, dataset): """ - Computes the mean and variance and stores as a model to be used for later scaling. + Computes the mean and variance and stores as a model to be used + for later scaling. - :param data: The data used to compute the mean and variance to build - the transformation model. + :param dataset: The data used to compute the mean and variance + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -170,23 +210,110 @@ def fit(self, dataset): return StandardScalerModel(jmodel) +class ChiSqSelectorModel(JavaVectorTransformer): + """ + .. note:: Experimental + + Represents a Chi Squared selector model. + """ + def transform(self, vector): + """ + Applies transformation on a vector. + + :param vector: Vector or RDD of Vector to be transformed. + :return: transformed vector. + """ + return JavaVectorTransformer.transform(self, vector) + + +class ChiSqSelector(object): + """ + .. note:: Experimental + + Creates a ChiSquared feature selector. + + :param numTopFeatures: number of features that selector will select. + + >>> data = [ + ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), + ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), + ... LabeledPoint(1.0, [0.0, 9.0, 8.0]), + ... LabeledPoint(2.0, [8.0, 9.0, 5.0]) + ... ] + >>> model = ChiSqSelector(1).fit(sc.parallelize(data)) + >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) + SparseVector(1, {0: 6.0}) + >>> model.transform(DenseVector([8.0, 9.0, 5.0])) + DenseVector([5.0]) + """ + def __init__(self, numTopFeatures): + self.numTopFeatures = int(numTopFeatures) + + def fit(self, data): + """ + Returns a ChiSquared feature selector. + + :param data: an `RDD[LabeledPoint]` containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. + """ + jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) + return ChiSqSelectorModel(jmodel) + + +class PCAModel(JavaVectorTransformer): + """ + Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + """ + + +class PCA(object): + """ + A feature transformer that projects vectors to a low-dimensional space using PCA. + + >>> data = [Vectors.sparse(5, [(1, 1.0), (3, 7.0)]), + ... Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]), + ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] + >>> model = PCA(2).fit(sc.parallelize(data)) + >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray() + >>> pcArray[0] + 1.648... + >>> pcArray[1] + -4.013... + """ + def __init__(self, k): + """ + :param k: number of principal components. + """ + self.k = int(k) + + def fit(self, data): + """ + Computes a [[PCAModel]] that contains the principal components of the input vectors. + :param data: source vectors + """ + jmodel = callMLlibFunc("fitPCA", self.k, data) + return PCAModel(jmodel) + + class HashingTF(object): """ .. note:: Experimental - Maps a sequence of terms to their term frequencies using the hashing trick. + Maps a sequence of terms to their term frequencies using the hashing + trick. Note: the terms must be hashable (can not be dict/set/list...). + :param numFeatures: number of features (default: 2^20) + >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) - SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0}) + SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): - """ - :param numFeatures: number of features (default: 2^20) - """ self.numFeatures = numFeatures def indexOf(self, term): @@ -195,8 +322,9 @@ def indexOf(self, term): def transform(self, document): """ - Transforms the input document (list of terms) to term frequency vectors, - or transform the RDD of document to RDD of term frequency vectors. + Transforms the input document (list of terms) to term frequency + vectors, or transform the RDD of document to RDD of term + frequency vectors. """ if isinstance(document, RDD): return document.map(self.transform) @@ -220,15 +348,22 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. - :param x: an RDD of term frequency vectors or a term frequency vector + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param x: an RDD of term frequency vectors or a term frequency + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if isinstance(x, RDD): - return JavaVectorTransformer.transform(self, x) - - x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) + def idf(self): + """ + Returns the current IDF vector. + """ + return self.call('idf') + class IDF(object): """ @@ -241,9 +376,12 @@ class IDF(object): of documents that contain term `t`. This implementation supports filtering out terms which do not appear - in a minimum number of documents (controlled by the variable `minDocFreq`). - For terms that are not in at least `minDocFreq` documents, the IDF is - found as 0, resulting in TF-IDFs of 0. + in a minimum number of documents (controlled by the variable + `minDocFreq`). For terms that are not in at least `minDocFreq` + documents, the IDF is found as 0, resulting in TF-IDFs of 0. + + :param minDocFreq: minimum of documents in which a term + should appear for filtering >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), @@ -265,10 +403,6 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): - """ - :param minDocFreq: minimum of documents in which a term - should appear for filtering - """ self.minDocFreq = minDocFreq def fit(self, dataset): @@ -283,7 +417,7 @@ def fit(self, dataset): return IDFModel(jmodel) -class Word2VecModel(JavaVectorTransformer): +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model """ @@ -316,7 +450,20 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + + @classmethod + def load(cls, sc, path): + jmodel = sc._jvm.org.apache.spark.mllib.feature \ + .Word2VecModel.load(sc._jsc.sc(), path) + return Word2VecModel(jmodel) + +@ignore_unicode_prefix class Word2Vec(object): """ Word2Vec creates vector representation of words in a text corpus. @@ -325,20 +472,21 @@ class Word2Vec(object): The vector representation can be used as features in natural language processing and machine learning algorithms. - We used skip-gram model in our implementation and hierarchical softmax - method to train the model. The variable names in the implementation - matches the original C implementation. + We used skip-gram model in our implementation and hierarchical + softmax method to train the model. The variable names in the + implementation matches the original C implementation. - For original C implementation, see https://code.google.com/p/word2vec/ + For original C implementation, + see https://code.google.com/p/word2vec/ For research papers, see Efficient Estimation of Word Representations in Vector Space - and - Distributed Representations of Words and Phrases and their Compositionality. + and Distributed Representations of Words and Phrases and their + Compositionality. >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) - >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] @@ -347,6 +495,18 @@ class Word2Vec(object): >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] [u'b', u'c'] + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = Word2VecModel.load(sc, path) + >>> model.transform("a") == sameModel.transform("a") + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self): """ @@ -356,7 +516,8 @@ def __init__(self): self.learningRate = 0.025 self.numPartitions = 1 self.numIterations = 1 - self.seed = random.randint(0, sys.maxint) + self.seed = random.randint(0, sys.maxsize) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -374,15 +535,16 @@ def setLearningRate(self, learningRate): def setNumPartitions(self, numPartitions): """ - Sets number of partitions (default: 1). Use a small number for accuracy. + Sets number of partitions (default: 1). Use a small number for + accuracy. """ self.numPartitions = numPartitions return self def setNumIterations(self, numIterations): """ - Sets number of iterations (default: 1), which should be smaller than or equal to number of - partitions. + Sets number of iterations (default: 1), which should be smaller + than or equal to number of partitions. """ self.numIterations = numIterations return self @@ -394,6 +556,14 @@ def setSeed(self, seed): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -403,12 +573,45 @@ def fit(self, data): """ if not isinstance(data, RDD): raise TypeError("data should be an RDD of list of string") - jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), int(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) +class ElementwiseProduct(VectorTransformer): + """ + .. note:: Experimental + + Scales each column of the vector, with the supplied weight vector. + i.e the elementwise product. + + >>> weight = Vectors.dense([1.0, 2.0, 3.0]) + >>> eprod = ElementwiseProduct(weight) + >>> a = Vectors.dense([2.0, 1.0, 3.0]) + >>> eprod.transform(a) + DenseVector([2.0, 2.0, 9.0]) + >>> b = Vectors.dense([9.0, 3.0, 4.0]) + >>> rdd = sc.parallelize([a, b]) + >>> eprod.transform(rdd).collect() + [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + """ + def __init__(self, scalingVector): + self.scalingVector = _convert_to_vector(scalingVector) + + def transform(self, vector): + """ + Computes the Hadamard product of the vector. + """ + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + + else: + vector = _convert_to_vector(vector) + return callMLlibFunc("elementwiseProductVector", self.scalingVector, vector) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 000000000000..bdc4a132b1b1 --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy +from numpy import array +from collections import namedtuple + +from pyspark import SparkContext +from pyspark.rdd import ignore_unicode_prefix +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth', 'FPGrowthModel'] + + +@inherit_doc +@ignore_unicode_prefix +class FPGrowthModel(JavaModelWrapper): + + """ + .. note:: Experimental + + A FP-Growth model for mining frequent itemsets + using the Parallel FP-Growth algorithm. + + >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + >>> rdd = sc.parallelize(data, 2) + >>> model = FPGrowth.train(rdd, 0.6, 2) + >>> sorted(model.freqItemsets().collect()) + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + """ + + def freqItemsets(self): + """ + Returns the frequent itemsets of this model. + """ + return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) + + +class FPGrowth(object): + """ + .. note:: Experimental + + A Parallel FP-growth algorithm to mine frequent itemsets. + """ + + @classmethod + def train(cls, data, minSupport=0.3, numPartitions=-1): + """ + Computes an FP-Growth model that contains frequent itemsets. + + :param data: The input data set, each element contains a + transaction. + :param minSupport: The minimal support level (default: `0.3`). + :param numPartitions: The number of partitions used by + parallel FP-growth (default: same as input data). + """ + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) + return FPGrowthModel(model) + + class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): + """ + Represents an (items, freq) tuple. + """ + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py deleted file mode 100644 index 7f21190ed8c2..000000000000 --- a/python/pyspark/mllib/linalg.py +++ /dev/null @@ -1,657 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -MLlib utilities for linear algebra. For dense vectors, MLlib -uses the NumPy C{array} type, so you can simply pass NumPy arrays -around. For sparse vectors, users can construct a L{SparseVector} -object from MLlib or pass SciPy C{scipy.sparse} column vectors if -SciPy is available in their environment. -""" - -import sys -import array -import copy_reg - -import numpy as np - -from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ - IntegerType, ByteType - - -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] - - -if sys.version_info[:2] == (2, 7): - # speed up pickling array in Python 2.7 - def fast_pickle_array(ar): - return array.array, (ar.typecode, ar.tostring()) - copy_reg.pickle(array.array, fast_pickle_array) - - -# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, -# such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. - -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy in environment, but that's okay - _have_scipy = False - - -def _convert_to_vector(l): - if isinstance(l, Vector): - return l - elif type(l) in (array.array, np.array, np.ndarray, list, tuple): - return DenseVector(l) - elif _have_scipy and scipy.sparse.issparse(l): - assert l.shape[1] == 1, "Expected column vector" - csc = l.tocsc() - return SparseVector(l.shape[0], csc.indices, csc.data) - else: - raise TypeError("Cannot convert type %s into Vector" % type(l)) - - -def _vector_size(v): - """ - Returns the size of the vector. - - >>> _vector_size([1., 2., 3.]) - 3 - >>> _vector_size((1., 2., 3.)) - 3 - >>> _vector_size(array.array('d', [1., 2., 3.])) - 3 - >>> _vector_size(np.zeros(3)) - 3 - >>> _vector_size(np.zeros((3, 1))) - 3 - >>> _vector_size(np.zeros((1, 3))) - Traceback (most recent call last): - ... - ValueError: Cannot treat an ndarray of shape (1, 3) as a vector - """ - if isinstance(v, Vector): - return len(v) - elif type(v) in (array.array, list, tuple): - return len(v) - elif type(v) == np.ndarray: - if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): - return len(v) - else: - raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape)) - elif _have_scipy and scipy.sparse.issparse(v): - assert v.shape[1] == 1, "Expected column vector" - return v.shape[0] - else: - raise TypeError("Cannot treat type %s as a vector" % type(v)) - - -def _format_float(f, digits=4): - s = str(round(f, digits)) - if '.' in s: - s = s[:s.index('.') + 1 + digits] - return s - - -class VectorUDT(UserDefinedType): - """ - SQL user-defined type (UDT) for Vector. - """ - - @classmethod - def sqlType(cls): - return StructType([ - StructField("type", ByteType(), False), - StructField("size", IntegerType(), True), - StructField("indices", ArrayType(IntegerType(), False), True), - StructField("values", ArrayType(DoubleType(), False), True)]) - - @classmethod - def module(cls): - return "pyspark.mllib.linalg" - - @classmethod - def scalaUDT(cls): - return "org.apache.spark.mllib.linalg.VectorUDT" - - def serialize(self, obj): - if isinstance(obj, SparseVector): - indices = [int(i) for i in obj.indices] - values = [float(v) for v in obj.values] - return (0, obj.size, indices, values) - elif isinstance(obj, DenseVector): - values = [float(v) for v in obj] - return (1, None, None, values) - else: - raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) - - def deserialize(self, datum): - assert len(datum) == 4, \ - "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) - tpe = datum[0] - if tpe == 0: - return SparseVector(datum[1], datum[2], datum[3]) - elif tpe == 1: - return DenseVector(datum[3]) - else: - raise ValueError("do not recognize type %r" % tpe) - - -class Vector(object): - - __UDT__ = VectorUDT() - - """ - Abstract class for DenseVector and SparseVector - """ - def toArray(self): - """ - Convert the vector into an numpy.ndarray - :return: numpy.ndarray - """ - raise NotImplementedError - - -class DenseVector(Vector): - """ - A dense vector represented by a value array. - """ - def __init__(self, ar): - if isinstance(ar, basestring): - ar = np.frombuffer(ar, dtype=np.float64) - elif not isinstance(ar, np.ndarray): - ar = np.array(ar, dtype=np.float64) - if ar.dtype != np.float64: - ar = ar.astype(np.float64) - self.array = ar - - def __reduce__(self): - return DenseVector, (self.array.tostring(),) - - def dot(self, other): - """ - Compute the dot product of two Vectors. We support - (Numpy array, list, SparseVector, or SciPy sparse) - and a target NumPy array that is either 1- or 2-dimensional. - Equivalent to calling numpy.dot of the two vectors. - - >>> dense = DenseVector(array.array('d', [1., 2.])) - >>> dense.dot(dense) - 5.0 - >>> dense.dot(SparseVector(2, [0, 1], [2., 1.])) - 4.0 - >>> dense.dot(range(1, 3)) - 5.0 - >>> dense.dot(np.array(range(1, 3))) - 5.0 - >>> dense.dot([1.,]) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - >>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F')) - array([ 5., 11.]) - >>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F')) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - """ - if type(other) == np.ndarray: - if other.ndim > 1: - assert len(self) == other.shape[0], "dimension mismatch" - return np.dot(self.array, other) - elif _have_scipy and scipy.sparse.issparse(other): - assert len(self) == other.shape[0], "dimension mismatch" - return other.transpose().dot(self.toArray()) - else: - assert len(self) == _vector_size(other), "dimension mismatch" - if isinstance(other, SparseVector): - return other.dot(self) - elif isinstance(other, Vector): - return np.dot(self.toArray(), other.toArray()) - else: - return np.dot(self.toArray(), other) - - def squared_distance(self, other): - """ - Squared distance of two Vectors. - - >>> dense1 = DenseVector(array.array('d', [1., 2.])) - >>> dense1.squared_distance(dense1) - 0.0 - >>> dense2 = np.array([2., 1.]) - >>> dense1.squared_distance(dense2) - 2.0 - >>> dense3 = [2., 1.] - >>> dense1.squared_distance(dense3) - 2.0 - >>> sparse1 = SparseVector(2, [0, 1], [2., 1.]) - >>> dense1.squared_distance(sparse1) - 2.0 - >>> dense1.squared_distance([1.,]) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - >>> dense1.squared_distance(SparseVector(1, [0,], [1.,])) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - """ - assert len(self) == _vector_size(other), "dimension mismatch" - if isinstance(other, SparseVector): - return other.squared_distance(self) - elif _have_scipy and scipy.sparse.issparse(other): - return _convert_to_vector(other).squared_distance(self) - - if isinstance(other, Vector): - other = other.toArray() - elif not isinstance(other, np.ndarray): - other = np.array(other) - diff = self.toArray() - other - return np.dot(diff, diff) - - def toArray(self): - return self.array - - def __getitem__(self, item): - return self.array[item] - - def __len__(self): - return len(self.array) - - def __str__(self): - return "[" + ",".join([str(v) for v in self.array]) + "]" - - def __repr__(self): - return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) - - def __eq__(self, other): - return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) - - def __ne__(self, other): - return not self == other - - def __getattr__(self, item): - return getattr(self.array, item) - - -class SparseVector(Vector): - """ - A simple sparse vector class for passing data to MLlib. Users may - alternatively pass SciPy's {scipy.sparse} data types. - """ - def __init__(self, size, *args): - """ - Create a sparse vector, using either a dictionary, a list of - (index, value) pairs, or two separate arrays of indices and - values (sorted by index). - - :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, - or two sorted lists containing indices and values. - - >>> print SparseVector(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) - """ - self.size = int(size) - assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" - if len(args) == 1: - pairs = args[0] - if type(pairs) == dict: - pairs = pairs.items() - pairs = sorted(pairs) - self.indices = np.array([p[0] for p in pairs], dtype=np.int32) - self.values = np.array([p[1] for p in pairs], dtype=np.float64) - else: - if isinstance(args[0], basestring): - assert isinstance(args[1], str), "values should be string too" - if args[0]: - self.indices = np.frombuffer(args[0], np.int32) - self.values = np.frombuffer(args[1], np.float64) - else: - # np.frombuffer() doesn't work well with empty string in older version - self.indices = np.array([], dtype=np.int32) - self.values = np.array([], dtype=np.float64) - else: - self.indices = np.array(args[0], dtype=np.int32) - self.values = np.array(args[1], dtype=np.float64) - assert len(self.indices) == len(self.values), "index and value arrays not same length" - for i in xrange(len(self.indices) - 1): - if self.indices[i] >= self.indices[i + 1]: - raise TypeError("indices array must be sorted") - - def __reduce__(self): - return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring())) - - def dot(self, other): - """ - Dot product with a SparseVector or 1- or 2-dimensional Numpy array. - - >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) - >>> a.dot(a) - 25.0 - >>> a.dot(array.array('d', [1., 2., 3., 4.])) - 22.0 - >>> b = SparseVector(4, [2, 4], [1.0, 2.0]) - >>> a.dot(b) - 0.0 - >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) - array([ 22., 22.]) - >>> a.dot([1., 2., 3.]) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - >>> a.dot(np.array([1., 2.])) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - >>> a.dot(DenseVector([1., 2.])) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - >>> a.dot(np.zeros((3, 2))) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - """ - if type(other) == np.ndarray: - if other.ndim == 2: - results = [self.dot(other[:, i]) for i in xrange(other.shape[1])] - return np.array(results) - elif other.ndim > 2: - raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim) - - assert len(self) == _vector_size(other), "dimension mismatch" - - if type(other) in (np.ndarray, array.array, DenseVector): - result = 0.0 - for i in xrange(len(self.indices)): - result += self.values[i] * other[self.indices[i]] - return result - - elif type(other) is SparseVector: - result = 0.0 - i, j = 0, 0 - while i < len(self.indices) and j < len(other.indices): - if self.indices[i] == other.indices[j]: - result += self.values[i] * other.values[j] - i += 1 - j += 1 - elif self.indices[i] < other.indices[j]: - i += 1 - else: - j += 1 - return result - - else: - return self.dot(_convert_to_vector(other)) - - def squared_distance(self, other): - """ - Squared distance from a SparseVector or 1-dimensional NumPy array. - - >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) - >>> a.squared_distance(a) - 0.0 - >>> a.squared_distance(array.array('d', [1., 2., 3., 4.])) - 11.0 - >>> a.squared_distance(np.array([1., 2., 3., 4.])) - 11.0 - >>> b = SparseVector(4, [2, 4], [1.0, 2.0]) - >>> a.squared_distance(b) - 30.0 - >>> b.squared_distance(a) - 30.0 - >>> b.squared_distance([1., 2.]) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - >>> b.squared_distance(SparseVector(3, [1,], [1.0,])) - Traceback (most recent call last): - ... - AssertionError: dimension mismatch - """ - assert len(self) == _vector_size(other), "dimension mismatch" - if type(other) in (list, array.array, DenseVector, np.array, np.ndarray): - if type(other) is np.array and other.ndim != 1: - raise Exception("Cannot call squared_distance with %d-dimensional array" % - other.ndim) - result = 0.0 - j = 0 # index into our own array - for i in xrange(len(other)): - if j < len(self.indices) and self.indices[j] == i: - diff = self.values[j] - other[i] - result += diff * diff - j += 1 - else: - result += other[i] * other[i] - return result - - elif type(other) is SparseVector: - result = 0.0 - i, j = 0, 0 - while i < len(self.indices) and j < len(other.indices): - if self.indices[i] == other.indices[j]: - diff = self.values[i] - other.values[j] - result += diff * diff - i += 1 - j += 1 - elif self.indices[i] < other.indices[j]: - result += self.values[i] * self.values[i] - i += 1 - else: - result += other.values[j] * other.values[j] - j += 1 - while i < len(self.indices): - result += self.values[i] * self.values[i] - i += 1 - while j < len(other.indices): - result += other.values[j] * other.values[j] - j += 1 - return result - else: - return self.squared_distance(_convert_to_vector(other)) - - def toArray(self): - """ - Returns a copy of this SparseVector as a 1-dimensional NumPy array. - """ - arr = np.zeros((self.size,), dtype=np.float64) - arr[self.indices] = self.values - return arr - - def __len__(self): - return self.size - - def __str__(self): - inds = "[" + ",".join([str(i) for i in self.indices]) + "]" - vals = "[" + ",".join([str(v) for v in self.values]) + "]" - return "(" + ",".join((str(self.size), inds, vals)) + ")" - - def __repr__(self): - inds = self.indices - vals = self.values - entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) - for i in xrange(len(inds))]) - return "SparseVector({0}, {{{1}}})".format(self.size, entries) - - def __eq__(self, other): - """ - Test SparseVectors for equality. - - >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v1 == v2 - True - >>> v1 != v2 - False - """ - return (isinstance(other, self.__class__) - and other.size == self.size - and np.array_equal(other.indices, self.indices) - and np.array_equal(other.values, self.values)) - - def __getitem__(self, index): - inds = self.indices - vals = self.values - if not isinstance(index, int): - raise ValueError( - "Indices must be of type integer, got type %s" % type(index)) - if index < 0: - index += self.size - if index >= self.size or index < 0: - raise ValueError("Index %d out of bounds." % index) - - insert_index = np.searchsorted(inds, index) - row_ind = inds[insert_index] - if row_ind == index: - return vals[insert_index] - return 0. - - def __ne__(self, other): - return not self.__eq__(other) - - -class Vectors(object): - - """ - Factory methods for working with vectors. Note that dense vectors - are simply represented as NumPy array objects, so there is no need - to covert them for use in MLlib. For sparse vectors, the factory - methods in this class create an MLlib-compatible type, or users - can pass in SciPy's C{scipy.sparse} column vectors. - """ - - @staticmethod - def sparse(size, *args): - """ - Create a sparse vector, using either a dictionary, a list of - (index, value) pairs, or two separate arrays of indices and - values (sorted by index). - - :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, - or two sorted lists containing indices and values. - - >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) - """ - return SparseVector(size, *args) - - @staticmethod - def dense(elements): - """ - Create a dense vector of 64-bit floats from a Python list. Always - returns a NumPy array. - - >>> Vectors.dense([1, 2, 3]) - DenseVector([1.0, 2.0, 3.0]) - """ - return DenseVector(elements) - - @staticmethod - def stringify(vector): - """ - Converts a vector into a string, which can be recognized by - Vectors.parse(). - - >>> Vectors.stringify(Vectors.sparse(2, [1], [1.0])) - '(2,[1],[1.0])' - >>> Vectors.stringify(Vectors.dense([0.0, 1.0])) - '[0.0,1.0]' - """ - return str(vector) - - -class Matrix(object): - """ - Represents a local matrix. - """ - - def __init__(self, numRows, numCols): - self.numRows = numRows - self.numCols = numCols - - def toArray(self): - """ - Returns its elements in a NumPy ndarray. - """ - raise NotImplementedError - - -class DenseMatrix(Matrix): - """ - Column-major dense matrix. - """ - def __init__(self, numRows, numCols, values): - Matrix.__init__(self, numRows, numCols) - if isinstance(values, basestring): - values = np.frombuffer(values, dtype=np.float64) - elif not isinstance(values, np.ndarray): - values = np.array(values, dtype=np.float64) - assert len(values) == numRows * numCols - if values.dtype != np.float64: - values.astype(np.float64) - self.values = values - - def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) - - def toArray(self): - """ - Return an numpy.ndarray - - >>> m = DenseMatrix(2, 2, range(4)) - >>> m.toArray() - array([[ 0., 2.], - [ 1., 3.]]) - """ - return self.values.reshape((self.numRows, self.numCols), order='F') - - def __eq__(self, other): - return (isinstance(other, DenseMatrix) and - self.numRows == other.numRows and - self.numCols == other.numCols and - all(self.values == other.values)) - - -class Matrices(object): - @staticmethod - def dense(numRows, numCols, values): - """ - Create a DenseMatrix - """ - return DenseMatrix(numRows, numCols, values) - - -def _test(): - import doctest - (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) - if failure_count: - exit(-1) - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py new file mode 100644 index 000000000000..334dc8e38bb8 --- /dev/null +++ b/python/pyspark/mllib/linalg/__init__.py @@ -0,0 +1,1162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +MLlib utilities for linear algebra. For dense vectors, MLlib +uses the NumPy C{array} type, so you can simply pass NumPy arrays +around. For sparse vectors, users can construct a L{SparseVector} +object from MLlib or pass SciPy C{scipy.sparse} column vectors if +SciPy is available in their environment. +""" + +import sys +import array + +if sys.version >= '3': + basestring = str + xrange = range + import copyreg as copy_reg + long = int +else: + from itertools import izip as zip + import copy_reg + +import numpy as np + +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, BooleanType + + +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', + 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] + + +if sys.version_info[:2] == (2, 7): + # speed up pickling array in Python 2.7 + def fast_pickle_array(ar): + return array.array, (ar.typecode, ar.tostring()) + copy_reg.pickle(array.array, fast_pickle_array) + + +# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, +# such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. + +try: + import scipy.sparse + _have_scipy = True +except: + # No SciPy in environment, but that's okay + _have_scipy = False + + +def _convert_to_vector(l): + if isinstance(l, Vector): + return l + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): + return DenseVector(l) + elif _have_scipy and scipy.sparse.issparse(l): + assert l.shape[1] == 1, "Expected column vector" + csc = l.tocsc() + return SparseVector(l.shape[0], csc.indices, csc.data) + else: + raise TypeError("Cannot convert type %s into Vector" % type(l)) + + +def _vector_size(v): + """ + Returns the size of the vector. + + >>> _vector_size([1., 2., 3.]) + 3 + >>> _vector_size((1., 2., 3.)) + 3 + >>> _vector_size(array.array('d', [1., 2., 3.])) + 3 + >>> _vector_size(np.zeros(3)) + 3 + >>> _vector_size(np.zeros((3, 1))) + 3 + >>> _vector_size(np.zeros((1, 3))) + Traceback (most recent call last): + ... + ValueError: Cannot treat an ndarray of shape (1, 3) as a vector + """ + if isinstance(v, Vector): + return len(v) + elif type(v) in (array.array, list, tuple, xrange): + return len(v) + elif type(v) == np.ndarray: + if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): + return len(v) + else: + raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape)) + elif _have_scipy and scipy.sparse.issparse(v): + assert v.shape[1] == 1, "Expected column vector" + return v.shape[0] + else: + raise TypeError("Cannot treat type %s as a vector" % type(v)) + + +def _format_float(f, digits=4): + s = str(round(f, digits)) + if '.' in s: + s = s[:s.index('.') + 1 + digits] + return s + + +def _format_float_list(l): + return [_format_float(x) for x in l] + + +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "vector" + + +class MatrixUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Matrix. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("numRows", IntegerType(), False), + StructField("numCols", IntegerType(), False), + StructField("colPtrs", ArrayType(IntegerType(), False), True), + StructField("rowIndices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True), + StructField("isTransposed", BooleanType(), False)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.MatrixUDT" + + def serialize(self, obj): + if isinstance(obj, SparseMatrix): + colPtrs = [int(i) for i in obj.colPtrs] + rowIndices = [int(i) for i in obj.rowIndices] + values = [float(v) for v in obj.values] + return (0, obj.numRows, obj.numCols, colPtrs, + rowIndices, values, bool(obj.isTransposed)) + elif isinstance(obj, DenseMatrix): + values = [float(v) for v in obj.values] + return (1, obj.numRows, obj.numCols, None, None, values, + bool(obj.isTransposed)) + else: + raise TypeError("cannot serialize type %r" % (type(obj))) + + def deserialize(self, datum): + assert len(datum) == 7, \ + "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseMatrix(*datum[1:]) + elif tpe == 1: + return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "matrix" + + +class Vector(object): + + __UDT__ = VectorUDT() + + """ + Abstract class for DenseVector and SparseVector + """ + def toArray(self): + """ + Convert the vector into an numpy.ndarray + :return: numpy.ndarray + """ + raise NotImplementedError + + +class DenseVector(Vector): + """ + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) + """ + def __init__(self, ar): + if isinstance(ar, bytes): + ar = np.frombuffer(ar, dtype=np.float64) + elif not isinstance(ar, np.ndarray): + ar = np.array(ar, dtype=np.float64) + if ar.dtype != np.float64: + ar = ar.astype(np.float64) + self.array = ar + + @staticmethod + def parse(s): + """ + Parse string representation back into the DenseVector. + + >>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]') + DenseVector([0.0, 1.0, 2.0, 3.0]) + """ + start = s.find('[') + if start == -1: + raise ValueError("Array should start with '['.") + end = s.find(']') + if end == -1: + raise ValueError("Array should end with ']'.") + s = s[start + 1: end] + + try: + values = [float(val) for val in s.split(',')] + except ValueError: + raise ValueError("Unable to parse values from %s" % s) + return DenseVector(values) + + def __reduce__(self): + return DenseVector, (self.array.tostring(),) + + def numNonzeros(self): + return np.count_nonzero(self.array) + + def norm(self, p): + """ + Calculte the norm of a DenseVector. + + >>> a = DenseVector([0, -1, 2, -3]) + >>> a.norm(2) + 3.7... + >>> a.norm(1) + 6.0 + """ + return np.linalg.norm(self.array, p) + + def dot(self, other): + """ + Compute the dot product of two Vectors. We support + (Numpy array, list, SparseVector, or SciPy sparse) + and a target NumPy array that is either 1- or 2-dimensional. + Equivalent to calling numpy.dot of the two vectors. + + >>> dense = DenseVector(array.array('d', [1., 2.])) + >>> dense.dot(dense) + 5.0 + >>> dense.dot(SparseVector(2, [0, 1], [2., 1.])) + 4.0 + >>> dense.dot(range(1, 3)) + 5.0 + >>> dense.dot(np.array(range(1, 3))) + 5.0 + >>> dense.dot([1.,]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F')) + array([ 5., 11.]) + >>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F')) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + if type(other) == np.ndarray: + if other.ndim > 1: + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.array, other) + elif _have_scipy and scipy.sparse.issparse(other): + assert len(self) == other.shape[0], "dimension mismatch" + return other.transpose().dot(self.toArray()) + else: + assert len(self) == _vector_size(other), "dimension mismatch" + if isinstance(other, SparseVector): + return other.dot(self) + elif isinstance(other, Vector): + return np.dot(self.toArray(), other.toArray()) + else: + return np.dot(self.toArray(), other) + + def squared_distance(self, other): + """ + Squared distance of two Vectors. + + >>> dense1 = DenseVector(array.array('d', [1., 2.])) + >>> dense1.squared_distance(dense1) + 0.0 + >>> dense2 = np.array([2., 1.]) + >>> dense1.squared_distance(dense2) + 2.0 + >>> dense3 = [2., 1.] + >>> dense1.squared_distance(dense3) + 2.0 + >>> sparse1 = SparseVector(2, [0, 1], [2., 1.]) + >>> dense1.squared_distance(sparse1) + 2.0 + >>> dense1.squared_distance([1.,]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> dense1.squared_distance(SparseVector(1, [0,], [1.,])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + assert len(self) == _vector_size(other), "dimension mismatch" + if isinstance(other, SparseVector): + return other.squared_distance(self) + elif _have_scipy and scipy.sparse.issparse(other): + return _convert_to_vector(other).squared_distance(self) + + if isinstance(other, Vector): + other = other.toArray() + elif not isinstance(other, np.ndarray): + other = np.array(other) + diff = self.toArray() - other + return np.dot(diff, diff) + + def toArray(self): + return self.array + + def __getitem__(self, item): + return self.array[item] + + def __len__(self): + return len(self.array) + + def __str__(self): + return "[" + ",".join([str(v) for v in self.array]) + "]" + + def __repr__(self): + return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) + + def __eq__(self, other): + return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) + + def __ne__(self, other): + return not self == other + + def __getattr__(self, item): + return getattr(self.array, item) + + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __truediv__ = _delegate("__truediv__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rtruediv__ = _delegate("__rtruediv__") + __rmod__ = _delegate("__rmod__") + + +class SparseVector(Vector): + """ + A simple sparse vector class for passing data to MLlib. Users may + alternatively pass SciPy's {scipy.sparse} data types. + """ + def __init__(self, size, *args): + """ + Create a sparse vector, using either a dictionary, a list of + (index, value) pairs, or two separate arrays of indices and + values (sorted by index). + + :param size: Size of the vector. + :param args: Active entries, as a dictionary {index: value, ...}, + a list of tuples [(index, value), ...], or a list of strictly i + ncreasing indices and a list of corresponding values [index, ...], + [value, ...]. Inactive entries are treated as zeros. + + >>> SparseVector(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) + """ + self.size = int(size) + """ Size of the vector. """ + assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" + if len(args) == 1: + pairs = args[0] + if type(pairs) == dict: + pairs = pairs.items() + pairs = sorted(pairs) + self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + """ A list of indices corresponding to active entries. """ + self.values = np.array([p[1] for p in pairs], dtype=np.float64) + """ A list of values corresponding to active entries. """ + else: + if isinstance(args[0], bytes): + assert isinstance(args[1], bytes), "values should be string too" + if args[0]: + self.indices = np.frombuffer(args[0], np.int32) + self.values = np.frombuffer(args[1], np.float64) + else: + # np.frombuffer() doesn't work well with empty string in older version + self.indices = np.array([], dtype=np.int32) + self.values = np.array([], dtype=np.float64) + else: + self.indices = np.array(args[0], dtype=np.int32) + self.values = np.array(args[1], dtype=np.float64) + assert len(self.indices) == len(self.values), "index and value arrays not same length" + for i in xrange(len(self.indices) - 1): + if self.indices[i] >= self.indices[i + 1]: + raise TypeError("indices array must be sorted") + + def numNonzeros(self): + return np.count_nonzero(self.values) + + def norm(self, p): + """ + Calculte the norm of a SparseVector. + + >>> a = SparseVector(4, [0, 1], [3., -4.]) + >>> a.norm(1) + 7.0 + >>> a.norm(2) + 5.0 + """ + return np.linalg.norm(self.values, p) + + def __reduce__(self): + return ( + SparseVector, + (self.size, self.indices.tostring(), self.values.tostring())) + + @staticmethod + def parse(s): + """ + Parse string representation back into the DenseVector. + + >>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )') + SparseVector(4, {0: 4.0, 1: 5.0}) + """ + start = s.find('(') + if start == -1: + raise ValueError("Tuple should start with '('") + end = s.find(')') + if start == -1: + raise ValueError("Tuple should end with ')'") + s = s[start + 1: end].strip() + + size = s[: s.find(',')] + try: + size = int(size) + except ValueError: + raise ValueError("Cannot parse size %s." % size) + + ind_start = s.find('[') + if ind_start == -1: + raise ValueError("Indices array should start with '['.") + ind_end = s.find(']') + if ind_end == -1: + raise ValueError("Indices array should end with ']'") + new_s = s[ind_start + 1: ind_end] + ind_list = new_s.split(',') + try: + indices = [int(ind) for ind in ind_list] + except ValueError: + raise ValueError("Unable to parse indices from %s." % new_s) + s = s[ind_end + 1:].strip() + + val_start = s.find('[') + if val_start == -1: + raise ValueError("Values array should start with '['.") + val_end = s.find(']') + if val_end == -1: + raise ValueError("Values array should end with ']'.") + val_list = s[val_start + 1: val_end].split(',') + try: + values = [float(val) for val in val_list] + except ValueError: + raise ValueError("Unable to parse values from %s." % s) + return SparseVector(size, indices, values) + + def dot(self, other): + """ + Dot product with a SparseVector or 1- or 2-dimensional Numpy array. + + >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) + >>> a.dot(a) + 25.0 + >>> a.dot(array.array('d', [1., 2., 3., 4.])) + 22.0 + >>> b = SparseVector(4, [2], [1.0]) + >>> a.dot(b) + 0.0 + >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) + array([ 22., 22.]) + >>> a.dot([1., 2., 3.]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(np.array([1., 2.])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(DenseVector([1., 2.])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(np.zeros((3, 2))) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + + if isinstance(other, np.ndarray): + if other.ndim not in [2, 1]: + raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim) + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.values, other[self.indices]) + + assert len(self) == _vector_size(other), "dimension mismatch" + + if isinstance(other, DenseVector): + return np.dot(other.array[self.indices], self.values) + + elif isinstance(other, SparseVector): + # Find out common indices. + self_cmind = np.in1d(self.indices, other.indices, assume_unique=True) + self_values = self.values[self_cmind] + if self_values.size == 0: + return 0.0 + else: + other_cmind = np.in1d(other.indices, self.indices, assume_unique=True) + return np.dot(self_values, other.values[other_cmind]) + + else: + return self.dot(_convert_to_vector(other)) + + def squared_distance(self, other): + """ + Squared distance from a SparseVector or 1-dimensional NumPy array. + + >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) + >>> a.squared_distance(a) + 0.0 + >>> a.squared_distance(array.array('d', [1., 2., 3., 4.])) + 11.0 + >>> a.squared_distance(np.array([1., 2., 3., 4.])) + 11.0 + >>> b = SparseVector(4, [2], [1.0]) + >>> a.squared_distance(b) + 26.0 + >>> b.squared_distance(a) + 26.0 + >>> b.squared_distance([1., 2.]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> b.squared_distance(SparseVector(3, [1,], [1.0,])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + assert len(self) == _vector_size(other), "dimension mismatch" + + if isinstance(other, np.ndarray) or isinstance(other, DenseVector): + if isinstance(other, np.ndarray) and other.ndim != 1: + raise Exception("Cannot call squared_distance with %d-dimensional array" % + other.ndim) + if isinstance(other, DenseVector): + other = other.array + sparse_ind = np.zeros(other.size, dtype=bool) + sparse_ind[self.indices] = True + dist = other[sparse_ind] - self.values + result = np.dot(dist, dist) + + other_ind = other[~sparse_ind] + result += np.dot(other_ind, other_ind) + return result + + elif isinstance(other, SparseVector): + result = 0.0 + i, j = 0, 0 + while i < len(self.indices) and j < len(other.indices): + if self.indices[i] == other.indices[j]: + diff = self.values[i] - other.values[j] + result += diff * diff + i += 1 + j += 1 + elif self.indices[i] < other.indices[j]: + result += self.values[i] * self.values[i] + i += 1 + else: + result += other.values[j] * other.values[j] + j += 1 + while i < len(self.indices): + result += self.values[i] * self.values[i] + i += 1 + while j < len(other.indices): + result += other.values[j] * other.values[j] + j += 1 + return result + else: + return self.squared_distance(_convert_to_vector(other)) + + def toArray(self): + """ + Returns a copy of this SparseVector as a 1-dimensional NumPy array. + """ + arr = np.zeros((self.size,), dtype=np.float64) + arr[self.indices] = self.values + return arr + + def __len__(self): + return self.size + + def __str__(self): + inds = "[" + ",".join([str(i) for i in self.indices]) + "]" + vals = "[" + ",".join([str(v) for v in self.values]) + "]" + return "(" + ",".join((str(self.size), inds, vals)) + ")" + + def __repr__(self): + inds = self.indices + vals = self.values + entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) + for i in xrange(len(inds))]) + return "SparseVector({0}, {{{1}}})".format(self.size, entries) + + def __eq__(self, other): + """ + Test SparseVectors for equality. + + >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + >>> v1 == v2 + True + >>> v1 != v2 + False + """ + return (isinstance(other, self.__class__) + and other.size == self.size + and np.array_equal(other.indices, self.indices) + and np.array_equal(other.values, self.values)) + + def __getitem__(self, index): + inds = self.indices + vals = self.values + if not isinstance(index, int): + raise TypeError( + "Indices must be of type integer, got type %s" % type(index)) + if index < 0: + index += self.size + if index >= self.size or index < 0: + raise ValueError("Index %d out of bounds." % index) + + insert_index = np.searchsorted(inds, index) + row_ind = inds[insert_index] + if row_ind == index: + return vals[insert_index] + return 0. + + def __ne__(self, other): + return not self.__eq__(other) + + +class Vectors(object): + + """ + Factory methods for working with vectors. Note that dense vectors + are simply represented as NumPy array objects, so there is no need + to covert them for use in MLlib. For sparse vectors, the factory + methods in this class create an MLlib-compatible type, or users + can pass in SciPy's C{scipy.sparse} column vectors. + """ + + @staticmethod + def sparse(size, *args): + """ + Create a sparse vector, using either a dictionary, a list of + (index, value) pairs, or two separate arrays of indices and + values (sorted by index). + + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, + or two sorted lists containing indices and values. + + >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) + """ + return SparseVector(size, *args) + + @staticmethod + def dense(*elements): + """ + Create a dense vector of 64-bit floats from a Python list or numbers. + + >>> Vectors.dense([1, 2, 3]) + DenseVector([1.0, 2.0, 3.0]) + >>> Vectors.dense(1.0, 2.0) + DenseVector([1.0, 2.0]) + """ + if len(elements) == 1 and not isinstance(elements[0], (float, int, long)): + # it's list, numpy.array or other iterable object. + elements = elements[0] + return DenseVector(elements) + + @staticmethod + def stringify(vector): + """ + Converts a vector into a string, which can be recognized by + Vectors.parse(). + + >>> Vectors.stringify(Vectors.sparse(2, [1], [1.0])) + '(2,[1],[1.0])' + >>> Vectors.stringify(Vectors.dense([0.0, 1.0])) + '[0.0,1.0]' + """ + return str(vector) + + @staticmethod + def squared_distance(v1, v2): + """ + Squared distance between two vectors. + a and b can be of type SparseVector, DenseVector, np.ndarray + or array.array. + + >>> a = Vectors.sparse(4, [(0, 1), (3, 4)]) + >>> b = Vectors.dense([2, 5, 4, 1]) + >>> a.squared_distance(b) + 51.0 + """ + v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2) + return v1.squared_distance(v2) + + @staticmethod + def norm(vector, p): + """ + Find norm of the given vector. + """ + return _convert_to_vector(vector).norm(p) + + @staticmethod + def parse(s): + """Parse a string representation back into the Vector. + + >>> Vectors.parse('[2,1,2 ]') + DenseVector([2.0, 1.0, 2.0]) + >>> Vectors.parse(' ( 100, [0], [2])') + SparseVector(100, {0: 2.0}) + """ + if s.find('(') == -1 and s.find('[') != -1: + return DenseVector.parse(s) + elif s.find('(') != -1: + return SparseVector.parse(s) + else: + raise ValueError( + "Cannot find tokens '[' or '(' from the input string.") + + @staticmethod + def zeros(size): + return DenseVector(np.zeros(size)) + + +class Matrix(object): + + __UDT__ = MatrixUDT() + + """ + Represents a local matrix. + """ + def __init__(self, numRows, numCols, isTransposed=False): + self.numRows = numRows + self.numCols = numCols + self.isTransposed = isTransposed + + def toArray(self): + """ + Returns its elements in a NumPy ndarray. + """ + raise NotImplementedError + + @staticmethod + def _convert_to_array(array_like, dtype): + """ + Convert Matrix attributes which are array-like or buffer to array. + """ + if isinstance(array_like, bytes): + return np.frombuffer(array_like, dtype=dtype) + return np.asarray(array_like, dtype=dtype) + + +class DenseMatrix(Matrix): + """ + Column-major dense matrix. + """ + def __init__(self, numRows, numCols, values, isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) + values = self._convert_to_array(values, np.float64) + assert len(values) == numRows * numCols + self.values = values + + def __reduce__(self): + return DenseMatrix, ( + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) + + def __str__(self): + """ + Pretty printing of a DenseMatrix + + >>> dm = DenseMatrix(2, 2, range(4)) + >>> print(dm) + DenseMatrix([[ 0., 2.], + [ 1., 3.]]) + >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True) + >>> print(dm) + DenseMatrix([[ 0., 1.], + [ 2., 3.]]) + """ + # Inspired by __repr__ in scipy matrices. + array_lines = repr(self.toArray()).splitlines() + + # We need to adjust six spaces which is the difference in number + # of letters between "DenseMatrix" and "array" + x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]]) + return array_lines[0].replace("array", "DenseMatrix") + "\n" + x + + def __repr__(self): + """ + Representation of a DenseMatrix + + >>> dm = DenseMatrix(2, 2, range(4)) + >>> dm + DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False) + """ + # If the number of values are less than seventeen then return as it is. + # Else return first eight values and last eight values. + if len(self.values) < 17: + entries = _format_float_list(self.values) + else: + entries = ( + _format_float_list(self.values[:8]) + + ["..."] + + _format_float_list(self.values[-8:]) + ) + + entries = ", ".join(entries) + return "DenseMatrix({0}, {1}, [{2}], {3})".format( + self.numRows, self.numCols, entries, self.isTransposed) + + def toArray(self): + """ + Return an numpy.ndarray + + >>> m = DenseMatrix(2, 2, range(4)) + >>> m.toArray() + array([[ 0., 2.], + [ 1., 3.]]) + """ + if self.isTransposed: + return np.asfortranarray( + self.values.reshape((self.numRows, self.numCols))) + else: + return self.values.reshape((self.numRows, self.numCols), order='F') + + def toSparse(self): + """Convert to SparseMatrix""" + if self.isTransposed: + values = np.ravel(self.toArray(), order='F') + else: + values = self.values + indices = np.nonzero(values)[0] + colCounts = np.bincount(indices // self.numRows) + colPtrs = np.cumsum(np.hstack( + (0, colCounts, np.zeros(self.numCols - colCounts.size)))) + values = values[indices] + rowIndices = indices % self.numRows + + return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + if self.isTransposed: + return self.values[i * self.numCols + j] + else: + return self.values[i + j * self.numRows] + + def __eq__(self, other): + if (not isinstance(other, DenseMatrix) or + self.numRows != other.numRows or + self.numCols != other.numCols): + return False + + self_values = np.ravel(self.toArray(), order='F') + other_values = np.ravel(other.toArray(), order='F') + return all(self_values == other_values) + + +class SparseMatrix(Matrix): + """Sparse Matrix stored in CSC format.""" + def __init__(self, numRows, numCols, colPtrs, rowIndices, values, + isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) + self.colPtrs = self._convert_to_array(colPtrs, np.int32) + self.rowIndices = self._convert_to_array(rowIndices, np.int32) + self.values = self._convert_to_array(values, np.float64) + + if self.isTransposed: + if self.colPtrs.size != numRows + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numRows + 1, self.colPtrs.size)) + else: + if self.colPtrs.size != numCols + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numCols + 1, self.colPtrs.size)) + if self.rowIndices.size != self.values.size: + raise ValueError("Expected rowIndices of length %d, got %d." + % (self.rowIndices.size, self.values.size)) + + def __str__(self): + """ + Pretty printing of a SparseMatrix + + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + >>> print(sm1) + 2 X 2 CSCMatrix + (0,0) 2.0 + (1,0) 3.0 + (1,1) 4.0 + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + >>> print(sm1) + 2 X 2 CSRMatrix + (0,0) 2.0 + (0,1) 3.0 + (1,1) 4.0 + """ + spstr = "{0} X {1} ".format(self.numRows, self.numCols) + if self.isTransposed: + spstr += "CSRMatrix\n" + else: + spstr += "CSCMatrix\n" + + cur_col = 0 + smlist = [] + + # Display first 16 values. + if len(self.values) <= 16: + zipindval = zip(self.rowIndices, self.values) + else: + zipindval = zip(self.rowIndices[:16], self.values[:16]) + for i, (rowInd, value) in enumerate(zipindval): + if self.colPtrs[cur_col + 1] <= i: + cur_col += 1 + if self.isTransposed: + smlist.append('({0},{1}) {2}'.format( + cur_col, rowInd, _format_float(value))) + else: + smlist.append('({0},{1}) {2}'.format( + rowInd, cur_col, _format_float(value))) + spstr += "\n".join(smlist) + + if len(self.values) > 16: + spstr += "\n.." * 2 + return spstr + + def __repr__(self): + """ + Representation of a SparseMatrix + + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + >>> sm1 + SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False) + """ + rowIndices = list(self.rowIndices) + colPtrs = list(self.colPtrs) + + if len(self.values) <= 16: + values = _format_float_list(self.values) + + else: + values = ( + _format_float_list(self.values[:8]) + + ["..."] + + _format_float_list(self.values[-8:]) + ) + rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:] + + if len(self.colPtrs) > 16: + colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:] + + values = ", ".join(values) + rowIndices = ", ".join([str(ind) for ind in rowIndices]) + colPtrs = ", ".join([str(ptr) for ptr in colPtrs]) + return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format( + self.numRows, self.numCols, colPtrs, rowIndices, + values, self.isTransposed) + + def __reduce__(self): + return SparseMatrix, ( + self.numRows, self.numCols, self.colPtrs.tostring(), + self.rowIndices.tostring(), self.values.tostring(), + int(self.isTransposed)) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j < 0 or j >= self.numCols: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + # If a CSR matrix is given, then the row index should be searched + # for in ColPtrs, and the column index should be searched for in the + # corresponding slice obtained from rowIndices. + if self.isTransposed: + j, i = i, j + + colStart = self.colPtrs[j] + colEnd = self.colPtrs[j + 1] + nz = self.rowIndices[colStart: colEnd] + ind = np.searchsorted(nz, i) + colStart + if ind < colEnd and self.rowIndices[ind] == i: + return self.values[ind] + else: + return 0.0 + + def toArray(self): + """ + Return an numpy.ndarray + """ + A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') + for k in xrange(self.colPtrs.size - 1): + startptr = self.colPtrs[k] + endptr = self.colPtrs[k + 1] + if self.isTransposed: + A[k, self.rowIndices[startptr:endptr]] = self.values[startptr:endptr] + else: + A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] + return A + + def toDense(self): + densevals = np.ravel(self.toArray(), order='F') + return DenseMatrix(self.numRows, self.numCols, densevals) + + # TODO: More efficient implementation: + def __eq__(self, other): + return np.all(self.toArray() == other.toArray()) + + +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + @staticmethod + def sparse(numRows, numCols, colPtrs, rowIndices, values): + """ + Create a SparseMatrix + """ + return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + + +def _test(): + import doctest + (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py new file mode 100644 index 000000000000..aec407de90aa --- /dev/null +++ b/python/pyspark/mllib/linalg/distributed.py @@ -0,0 +1,853 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Package for distributed linear algebra. +""" + +import sys + +if sys.version >= '3': + long = int + +from py4j.java_gateway import JavaObject + +from pyspark import RDD +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector, Matrix + + +__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', + 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', + 'BlockMatrix'] + + +class DistributedMatrix(object): + """ + .. note:: Experimental + + Represents a distributively stored matrix backed by one or + more RDDs. + + """ + def numRows(self): + """Get or compute the number of rows.""" + raise NotImplementedError + + def numCols(self): + """Get or compute the number of cols.""" + raise NotImplementedError + + +class RowMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a row-oriented distributed Matrix with no meaningful + row indices. + + :param rows: An RDD of vectors. + :param numRows: Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the number of + records in the `rows` RDD. + :param numCols: Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the size of + the first row. + """ + def __init__(self, rows, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java RowMatrix. + + Publicly, we require that `rows` be an RDD. However, for + internal usage, `rows` can also be a Java RowMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) + >>> mat = RowMatrix(rows) + + >>> mat_diff = RowMatrix(rows) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = RowMatrix(mat._java_matrix_wrapper._java_model) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(rows, RDD): + rows = rows.map(_convert_to_vector) + java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols)) + elif (isinstance(rows, JavaObject) + and rows.getClass().getSimpleName() == "RowMatrix"): + java_matrix = rows + else: + raise TypeError("rows should be an RDD of vectors, got %s" % type(rows)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def rows(self): + """ + Rows of the RowMatrix stored as an RDD of vectors. + + >>> mat = RowMatrix(sc.parallelize([[1, 2, 3], [4, 5, 6]])) + >>> rows = mat.rows + >>> rows.first() + DenseVector([1.0, 2.0, 3.0]) + """ + return self._java_matrix_wrapper.call("rows") + + def numRows(self): + """ + Get or compute the number of rows. + + >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6], + ... [7, 8, 9], [10, 11, 12]]) + + >>> mat = RowMatrix(rows) + >>> print(mat.numRows()) + 4 + + >>> mat = RowMatrix(rows, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6], + ... [7, 8, 9], [10, 11, 12]]) + + >>> mat = RowMatrix(rows) + >>> print(mat.numCols()) + 3 + + >>> mat = RowMatrix(rows, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + +class IndexedRow(object): + """ + .. note:: Experimental + + Represents a row of an IndexedRowMatrix. + + Just a wrapper over a (long, vector) tuple. + + :param index: The index for the given row. + :param vector: The row in the matrix at the given index. + """ + def __init__(self, index, vector): + self.index = long(index) + self.vector = _convert_to_vector(vector) + + def __repr__(self): + return "IndexedRow(%s, %s)" % (self.index, self.vector) + + +def _convert_to_indexed_row(row): + if isinstance(row, IndexedRow): + return row + elif isinstance(row, tuple) and len(row) == 2: + return IndexedRow(*row) + else: + raise TypeError("Cannot convert type %s into IndexedRow" % type(row)) + + +class IndexedRowMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a row-oriented distributed Matrix with indexed rows. + + :param rows: An RDD of IndexedRows or (long, vector) tuples. + :param numRows: Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the max row + index plus one. + :param numCols: Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the size of + the first row. + """ + def __init__(self, rows, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java IndexedRowMatrix. + + Publicly, we require that `rows` be an RDD. However, for + internal usage, `rows` can also be a Java IndexedRowMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows) + + >>> mat_diff = IndexedRowMatrix(rows) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = IndexedRowMatrix(mat._java_matrix_wrapper._java_model) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(rows, RDD): + rows = rows.map(_convert_to_indexed_row) + # We use DataFrames for serialization of IndexedRows from + # Python, so first convert the RDD to a DataFrame on this + # side. This will convert each IndexedRow to a Row + # containing the 'index' and 'vector' values, which can + # both be easily serialized. We will convert back to + # IndexedRows on the Scala side. + java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(), + long(numRows), int(numCols)) + elif (isinstance(rows, JavaObject) + and rows.getClass().getSimpleName() == "IndexedRowMatrix"): + java_matrix = rows + else: + raise TypeError("rows should be an RDD of IndexedRows or (long, vector) tuples, " + "got %s" % type(rows)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def rows(self): + """ + Rows of the IndexedRowMatrix stored as an RDD of IndexedRows. + + >>> mat = IndexedRowMatrix(sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6])])) + >>> rows = mat.rows + >>> rows.first() + IndexedRow(0, [1.0,2.0,3.0]) + """ + # We use DataFrames for serialization of IndexedRows from + # Java, so we first convert the RDD of rows to a DataFrame + # on the Scala/Java side. Then we map each Row in the + # DataFrame back to an IndexedRow on this side. + rows_df = callMLlibFunc("getIndexedRows", self._java_matrix_wrapper._java_model) + rows = rows_df.map(lambda row: IndexedRow(row[0], row[1])) + return rows + + def numRows(self): + """ + Get or compute the number of rows. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6]), + ... IndexedRow(2, [7, 8, 9]), + ... IndexedRow(3, [10, 11, 12])]) + + >>> mat = IndexedRowMatrix(rows) + >>> print(mat.numRows()) + 4 + + >>> mat = IndexedRowMatrix(rows, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(1, [4, 5, 6]), + ... IndexedRow(2, [7, 8, 9]), + ... IndexedRow(3, [10, 11, 12])]) + + >>> mat = IndexedRowMatrix(rows) + >>> print(mat.numCols()) + 3 + + >>> mat = IndexedRowMatrix(rows, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toRowMatrix(self): + """ + Convert this matrix to a RowMatrix. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows).toRowMatrix() + >>> mat.rows.collect() + [DenseVector([1.0, 2.0, 3.0]), DenseVector([4.0, 5.0, 6.0])] + """ + java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix") + return RowMatrix(java_row_matrix) + + def toCoordinateMatrix(self): + """ + Convert this matrix to a CoordinateMatrix. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 0]), + ... IndexedRow(6, [0, 5])]) + >>> mat = IndexedRowMatrix(rows).toCoordinateMatrix() + >>> mat.entries.take(3) + [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 0.0), MatrixEntry(6, 0, 0.0)] + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") + return CoordinateMatrix(java_coordinate_matrix) + + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows).toBlockMatrix() + + >>> # This IndexedRowMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> print(mat.numCols()) + 3 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + + +class MatrixEntry(object): + """ + .. note:: Experimental + + Represents an entry of a CoordinateMatrix. + + Just a wrapper over a (long, long, float) tuple. + + :param i: The row index of the matrix. + :param j: The column index of the matrix. + :param value: The (i, j)th entry of the matrix, as a float. + """ + def __init__(self, i, j, value): + self.i = long(i) + self.j = long(j) + self.value = float(value) + + def __repr__(self): + return "MatrixEntry(%s, %s, %s)" % (self.i, self.j, self.value) + + +def _convert_to_matrix_entry(entry): + if isinstance(entry, MatrixEntry): + return entry + elif isinstance(entry, tuple) and len(entry) == 3: + return MatrixEntry(*entry) + else: + raise TypeError("Cannot convert type %s into MatrixEntry" % type(entry)) + + +class CoordinateMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a matrix in coordinate format. + + :param entries: An RDD of MatrixEntry inputs or + (long, long, float) tuples. + :param numRows: Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the max row + index plus one. + :param numCols: Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the max row + index plus one. + """ + def __init__(self, entries, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java CoordinateMatrix. + + Publicly, we require that `rows` be an RDD. However, for + internal usage, `rows` can also be a Java CoordinateMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries) + + >>> mat_diff = CoordinateMatrix(entries) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = CoordinateMatrix(mat._java_matrix_wrapper._java_model) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(entries, RDD): + entries = entries.map(_convert_to_matrix_entry) + # We use DataFrames for serialization of MatrixEntry entries + # from Python, so first convert the RDD to a DataFrame on + # this side. This will convert each MatrixEntry to a Row + # containing the 'i', 'j', and 'value' values, which can + # each be easily serialized. We will convert back to + # MatrixEntry inputs on the Scala side. + java_matrix = callMLlibFunc("createCoordinateMatrix", entries.toDF(), + long(numRows), long(numCols)) + elif (isinstance(entries, JavaObject) + and entries.getClass().getSimpleName() == "CoordinateMatrix"): + java_matrix = entries + else: + raise TypeError("entries should be an RDD of MatrixEntry entries or " + "(long, long, float) tuples, got %s" % type(entries)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def entries(self): + """ + Entries of the CoordinateMatrix stored as an RDD of + MatrixEntries. + + >>> mat = CoordinateMatrix(sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)])) + >>> entries = mat.entries + >>> entries.first() + MatrixEntry(0, 0, 1.2) + """ + # We use DataFrames for serialization of MatrixEntry entries + # from Java, so we first convert the RDD of entries to a + # DataFrame on the Scala/Java side. Then we map each Row in + # the DataFrame back to a MatrixEntry on this side. + entries_df = callMLlibFunc("getMatrixEntries", self._java_matrix_wrapper._java_model) + entries = entries_df.map(lambda row: MatrixEntry(row[0], row[1], row[2])) + return entries + + def numRows(self): + """ + Get or compute the number of rows. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(1, 0, 2), + ... MatrixEntry(2, 1, 3.7)]) + + >>> mat = CoordinateMatrix(entries) + >>> print(mat.numRows()) + 3 + + >>> mat = CoordinateMatrix(entries, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(1, 0, 2), + ... MatrixEntry(2, 1, 3.7)]) + + >>> mat = CoordinateMatrix(entries) + >>> print(mat.numCols()) + 2 + + >>> mat = CoordinateMatrix(entries, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toRowMatrix(self): + """ + Convert this matrix to a RowMatrix. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toRowMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, but the ensuing RowMatrix + >>> # will only have 2 rows since there are only entries on 2 + >>> # unique rows. + >>> print(mat.numRows()) + 2 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing RowMatrix + >>> # will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix") + return RowMatrix(java_row_matrix) + + def toIndexedRowMatrix(self): + """ + Convert this matrix to an IndexedRowMatrix. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # IndexedRowMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing + >>> # IndexedRowMatrix will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") + return IndexedRowMatrix(java_indexed_row_matrix) + + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toBlockMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing + >>> # BlockMatrix will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + + +def _convert_to_matrix_block_tuple(block): + if (isinstance(block, tuple) and len(block) == 2 + and isinstance(block[0], tuple) and len(block[0]) == 2 + and isinstance(block[1], Matrix)): + blockRowIndex = int(block[0][0]) + blockColIndex = int(block[0][1]) + subMatrix = block[1] + return ((blockRowIndex, blockColIndex), subMatrix) + else: + raise TypeError("Cannot convert type %s into a sub-matrix block tuple" % type(block)) + + +class BlockMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a distributed matrix in blocks of local matrices. + + :param blocks: An RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that + form this distributed matrix. If multiple blocks + with the same index exist, the results for + operations like add and multiply will be + unpredictable. + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + :param numRows: Number of rows of this matrix. If the supplied + value is less than or equal to zero, the number + of rows will be calculated when `numRows` is + invoked. + :param numCols: Number of columns of this matrix. If the supplied + value is less than or equal to zero, the number + of columns will be calculated when `numCols` is + invoked. + """ + def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java BlockMatrix. + + Publicly, we require that `blocks` be an RDD. However, for + internal usage, `blocks` can also be a Java BlockMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + + >>> mat_diff = BlockMatrix(blocks, 3, 2) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = BlockMatrix(mat._java_matrix_wrapper._java_model, 3, 2) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(blocks, RDD): + blocks = blocks.map(_convert_to_matrix_block_tuple) + # We use DataFrames for serialization of sub-matrix blocks + # from Python, so first convert the RDD to a DataFrame on + # this side. This will convert each sub-matrix block + # tuple to a Row containing the 'blockRowIndex', + # 'blockColIndex', and 'subMatrix' values, which can + # each be easily serialized. We will convert back to + # ((blockRowIndex, blockColIndex), sub-matrix) tuples on + # the Scala side. + java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(), + int(rowsPerBlock), int(colsPerBlock), + long(numRows), long(numCols)) + elif (isinstance(blocks, JavaObject) + and blocks.getClass().getSimpleName() == "BlockMatrix"): + java_matrix = blocks + else: + raise TypeError("blocks should be an RDD of sub-matrix blocks as " + "((int, int), matrix) tuples, got %s" % type(blocks)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def blocks(self): + """ + The RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that form this + distributed matrix. + + >>> mat = BlockMatrix( + ... sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2) + >>> blocks = mat.blocks + >>> blocks.first() + ((0, 0), DenseMatrix(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 0)) + + """ + # We use DataFrames for serialization of sub-matrix blocks + # from Java, so we first convert the RDD of blocks to a + # DataFrame on the Scala/Java side. Then we map each Row in + # the DataFrame back to a sub-matrix block on this side. + blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model) + blocks = blocks_df.map(lambda row: ((row[0][0], row[0][1]), row[1])) + return blocks + + @property + def rowsPerBlock(self): + """ + Number of rows that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.rowsPerBlock + 3 + """ + return self._java_matrix_wrapper.call("rowsPerBlock") + + @property + def colsPerBlock(self): + """ + Number of columns that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.colsPerBlock + 2 + """ + return self._java_matrix_wrapper.call("colsPerBlock") + + @property + def numRowBlocks(self): + """ + Number of rows of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numRowBlocks + 2 + """ + return self._java_matrix_wrapper.call("numRowBlocks") + + @property + def numColBlocks(self): + """ + Number of columns of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numColBlocks + 1 + """ + return self._java_matrix_wrapper.call("numColBlocks") + + def numRows(self): + """ + Get or compute the number of rows. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numRows()) + 6 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numCols()) + 2 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toLocalMatrix(self): + """ + Collect the distributed matrix on the driver as a DenseMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing DenseMatrix will also have 6 rows. + >>> print(mat.numRows) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 + >>> # columns. The ensuing DenseMatrix will also have 2 columns. + >>> print(mat.numCols) + 2 + """ + return self._java_matrix_wrapper.call("toLocalMatrix") + + def toIndexedRowMatrix(self): + """ + Convert this matrix to an IndexedRowMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing IndexedRowMatrix will also have 6 rows. + >>> print(mat.numRows()) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 columns. + >>> # The ensuing IndexedRowMatrix will also have 2 columns. + >>> print(mat.numCols()) + 2 + """ + java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") + return IndexedRowMatrix(java_indexed_row_matrix) + + def toCoordinateMatrix(self): + """ + Convert this matrix to a CoordinateMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])), + ... ((1, 0), Matrices.dense(1, 2, [7, 8]))]) + >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix() + >>> mat.entries.take(3) + [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 2.0), MatrixEntry(1, 0, 7.0)] + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") + return CoordinateMatrix(java_coordinate_matrix) + + +def _test(): + import doctest + from pyspark import SparkContext + from pyspark.sql import SQLContext + from pyspark.mllib.linalg import Matrices + import pyspark.mllib.linalg.distributed + globs = pyspark.mllib.linalg.distributed.__dict__.copy() + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + globs['sqlContext'] = SQLContext(globs['sc']) + globs['Matrices'] = Matrices + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py deleted file mode 100644 index 20ee9d78bf5b..000000000000 --- a/python/pyspark/mllib/rand.py +++ /dev/null @@ -1,410 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Python package for random data generation. -""" - -from functools import wraps - -from pyspark.mllib.common import callMLlibFunc - - -__all__ = ['RandomRDDs', ] - - -def toArray(f): - @wraps(f) - def func(sc, *a, **kw): - rdd = f(sc, *a, **kw) - return rdd.map(lambda vec: vec.toArray()) - return func - - -class RandomRDDs(object): - """ - Generator methods for creating RDDs comprised of i.i.d samples from - some distribution. - """ - - @staticmethod - def uniformRDD(sc, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the - uniform distribution U(0.0, 1.0). - - To transform the distribution in the generated RDD from U(0.0, 1.0) - to U(a, b), use - C{RandomRDDs.uniformRDD(sc, n, p, seed)\ - .map(lambda v: a + (b - a) * v)} - - :param sc: SparkContext used to create the RDD. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. - - >>> x = RandomRDDs.uniformRDD(sc, 100).collect() - >>> len(x) - 100 - >>> max(x) <= 1.0 and min(x) >= 0.0 - True - >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() - 4 - >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() - >>> parts == sc.defaultParallelism - True - """ - return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) - - @staticmethod - def normalRDD(sc, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the standard normal - distribution. - - To transform the distribution in the generated RDD from standard normal - to some other normal N(mean, sigma^2), use - C{RandomRDDs.normal(sc, n, p, seed)\ - .map(lambda v: mean + sigma * v)} - - :param sc: SparkContext used to create the RDD. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). - - >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - 0.0) < 0.1 - True - >>> abs(stats.stdev() - 1.0) < 0.1 - True - """ - return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) - - @staticmethod - def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the log normal - distribution with the input mean and standard distribution. - - :param sc: SparkContext used to create the RDD. - :param mean: mean for the log Normal distribution - :param std: std for the log Normal distribution - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std). - - >>> from math import sqrt, exp - >>> mean = 0.0 - >>> std = 1.0 - >>> expMean = exp(mean + 0.5 * std * std) - >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - expMean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - expStd) < 0.5 - True - """ - return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std), - size, numPartitions, seed) - - @staticmethod - def poissonRDD(sc, mean, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Poisson - distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or lambda, for the Poisson distribution. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). - - >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) - - @staticmethod - def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Exponential - distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or 1 / lambda, for the Exponential distribution. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). - - >>> mean = 2.0 - >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) - - @staticmethod - def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Gamma - distribution with the input shape and scale. - - :param sc: SparkContext used to create the RDD. - :param shape: shape (> 0) parameter for the Gamma distribution - :param scale: scale (> 0) parameter for the Gamma distribution - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale). - - >>> from math import sqrt - >>> shape = 1.0 - >>> scale = 2.0 - >>> expMean = shape * scale - >>> expStd = sqrt(shape * scale * scale) - >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - expMean) < 0.5 - True - >>> abs(stats.stdev() - expStd) < 0.5 - True - """ - return callMLlibFunc("gammaRDD", sc._jsc, float(shape), - float(scale), size, numPartitions, seed) - - @staticmethod - @toArray - def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the uniform distribution U(0.0, 1.0). - - :param sc: SparkContext used to create the RDD. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD. - :param seed: Seed for the RNG that generates the seed for the generator in each partition. - :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. - - >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) - >>> mat.shape - (10, 10) - >>> mat.max() <= 1.0 and mat.min() >= 0.0 - True - >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() - 4 - """ - return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the standard normal distribution. - - :param sc: SparkContext used to create the RDD. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. - - >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - 0.0) < 0.1 - True - >>> abs(mat.std() - 1.0) < 0.1 - True - """ - return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the log normal distribution. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean of the log normal distribution - :param std: Standard Deviation of the log normal distribution - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`. - - >>> import numpy as np - >>> from math import sqrt, exp - >>> mean = 0.0 - >>> std = 1.0 - >>> expMean = exp(mean + 0.5 * std * std) - >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \ - 100, 100, seed=1L).collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 - True - >>> abs(mat.std() - expStd) < 0.1 - True - """ - return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std), - numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Poisson distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or lambda, for the Poisson distribution. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). - - >>> import numpy as np - >>> mean = 100.0 - >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) - >>> mat = np.mat(rdd.collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, - numPartitions, seed) - - @staticmethod - @toArray - def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Exponential distribution with the input mean. - - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or 1 / lambda, for the Exponential distribution. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean). - - >>> import numpy as np - >>> mean = 0.5 - >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L) - >>> mat = np.mat(rdd.collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 - True - """ - return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols, - numPartitions, seed) - - @staticmethod - @toArray - def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Gamma distribution. - - :param sc: SparkContext used to create the RDD. - :param shape: Shape (> 0) of the Gamma distribution - :param scale: Scale (> 0) of the Gamma distribution - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale). - - >>> import numpy as np - >>> from math import sqrt - >>> shape = 1.0 - >>> scale = 2.0 - >>> expMean = shape * scale - >>> expStd = sqrt(shape * scale * scale) - >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \ - 100, 100, seed=1L).collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 - True - >>> abs(mat.std() - expStd) < 0.1 - True - """ - return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale), - numRows, numCols, numPartitions, seed) - - -def _test(): - import doctest - from pyspark.context import SparkContext - globs = globals().copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py new file mode 100644 index 000000000000..06fbc0eb6aef --- /dev/null +++ b/python/pyspark/mllib/random.py @@ -0,0 +1,409 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + +from functools import wraps + +from pyspark.mllib.common import callMLlibFunc + + +__all__ = ['RandomRDDs', ] + + +def toArray(f): + @wraps(f) + def func(sc, *a, **kw): + rdd = f(sc, *a, **kw) + return rdd.map(lambda vec: vec.toArray()) + return func + + +class RandomRDDs(object): + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution U(0.0, 1.0). + + To transform the distribution in the generated RDD from U(0.0, 1.0) + to U(a, b), use + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma^2), use + C{RandomRDDs.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) + + @staticmethod + def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the log normal + distribution with the input mean and standard distribution. + + :param sc: SparkContext used to create the RDD. + :param mean: mean for the log Normal distribution + :param std: std for the log Normal distribution + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std). + + >>> from math import sqrt, exp + >>> mean = 0.0 + >>> std = 1.0 + >>> expMean = exp(mean + 0.5 * std * std) + >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) + >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - expMean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - expStd) < 0.5 + True + """ + return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std), + size, numPartitions, seed) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the Poisson + distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). + + >>> mean = 100.0 + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) + + @staticmethod + def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the Exponential + distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or 1 / lambda, for the Exponential distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). + + >>> mean = 2.0 + >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) + + @staticmethod + def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the Gamma + distribution with the input shape and scale. + + :param sc: SparkContext used to create the RDD. + :param shape: shape (> 0) parameter for the Gamma distribution + :param scale: scale (> 0) parameter for the Gamma distribution + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale). + + >>> from math import sqrt + >>> shape = 1.0 + >>> scale = 2.0 + >>> expMean = shape * scale + >>> expStd = sqrt(shape * scale * scale) + >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2) + >>> stats = x.stats() + >>> stats.count() + 1000 + >>> abs(stats.mean() - expMean) < 0.5 + True + >>> abs(stats.stdev() - expStd) < 0.5 + True + """ + return callMLlibFunc("gammaRDD", sc._jsc, float(shape), + float(scale), size, numPartitions, seed) + + @staticmethod + @toArray + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the uniform distribution U(0.0, 1.0). + + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD. + :param seed: Seed for the RNG that generates the seed for the generator in each partition. + :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the standard normal distribution. + + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the log normal distribution. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean of the log normal distribution + :param std: Standard Deviation of the log normal distribution + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`. + + >>> import numpy as np + >>> from math import sqrt, exp + >>> mean = 0.0 + >>> std = 1.0 + >>> expMean = exp(mean + 0.5 * std * std) + >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) + >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect() + >>> mat = np.matrix(m) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - expMean) < 0.1 + True + >>> abs(mat.std() - expStd) < 0.1 + True + """ + return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std), + numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Poisson distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, + numPartitions, seed) + + @staticmethod + @toArray + def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Exponential distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or 1 / lambda, for the Exponential distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean). + + >>> import numpy as np + >>> mean = 0.5 + >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols, + numPartitions, seed) + + @staticmethod + @toArray + def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Gamma distribution. + + :param sc: SparkContext used to create the RDD. + :param shape: Shape (> 0) of the Gamma distribution + :param scale: Scale (> 0) of the Gamma distribution + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale). + + >>> import numpy as np + >>> from math import sqrt + >>> shape = 1.0 + >>> scale = 2.0 + >>> expMean = shape * scale + >>> expStd = sqrt(shape * scale * scale) + >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - expMean) < 0.1 + True + >>> abs(mat.std() - expStd) < 0.1 + True + """ + return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale), + numRows, numCols, numPartitions, seed) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 0d99e6dedfad..506ca2151cce 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,11 +15,14 @@ # limitations under the License. # +import array from collections import namedtuple from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.util import JavaLoader, JavaSaveable +from pyspark.sql import DataFrame __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -39,7 +42,8 @@ def __reduce__(self): return Rating, (int(self.user), int(self.product), float(self.rating)) -class MatrixFactorizationModel(JavaModelWrapper): +@inherit_doc +class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): """A matrix factorisation model trained by regularized alternating least-squares. @@ -50,7 +54,7 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> model.predict(2, 2) - 0.43... + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 2, seed=0) @@ -61,6 +65,13 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> model.userFeatures().collect() [(1, array('d', [...])), (2, array('d', [...]))] + >>> model.recommendUsers(1, 2) + [Rating(user=2, product=1, rating=1.9...), Rating(user=1, product=1, rating=1.0...)] + >>> model.recommendProducts(1, 2) + [Rating(user=1, product=2, rating=1.9...), Rating(user=1, product=1, rating=1.0...)] + >>> model.rank + 4 + >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] >>> len(latents) == 4 @@ -75,41 +86,105 @@ class MatrixFactorizationModel(JavaModelWrapper): True >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) + >>> model.predict(2, 2) + 3.8... + + >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) + >>> model = ALS.train(df, 1, nonnegative=True, seed=10) + >>> model.predict(2, 2) 3.8... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) - 0.43... + >>> model.predict(2, 2) + 0.4... + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = MatrixFactorizationModel.load(sc, path) + >>> sameModel.predict(2, 2) + 0.4... + >>> sameModel.predictAll(testset).collect() + [Rating(... + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def predict(self, user, product): + """ + Predicts rating for the given user and product. + """ return self._java_model.predict(int(user), int(product)) def predictAll(self, user_product): + """ + Returns a list of predicted ratings for input user and product pairs. + """ assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() assert len(first) == 2, "user_product should be RDD of (user, product)" - user_product = user_product.map(lambda (u, p): (int(u), int(p))) + user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) def userFeatures(self): - return self.call("getUserFeatures") + """ + Returns a paired RDD, where the first element is the user and the + second is an array of features corresponding to that user. + """ + return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) def productFeatures(self): - return self.call("getProductFeatures") + """ + Returns a paired RDD, where the first element is the product and the + second is an array of features corresponding to that product. + """ + return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + + def recommendUsers(self, product, num): + """ + Recommends the top "num" number of users for a given product and returns a list + of Rating objects sorted by the predicted rating in descending order. + """ + return list(self.call("recommendUsers", product, num)) + + def recommendProducts(self, user, num): + """ + Recommends the top "num" number of products for a given user and returns a list + of Rating objects sorted by the predicted rating in descending order. + """ + return list(self.call("recommendProducts", user, num)) + + @property + def rank(self): + return self.call("rank") + + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + return MatrixFactorizationModel(wrapper) class ALS(object): @classmethod def _prepare(cls, ratings): - assert isinstance(ratings, RDD), "ratings should be RDD" + if isinstance(ratings, RDD): + pass + elif isinstance(ratings, DataFrame): + ratings = ratings.rdd + else: + raise TypeError("Ratings should be represented by either an RDD or a DataFrame, " + "but got %s." % type(ratings)) first = ratings.first() - if not isinstance(first, Rating): - if isinstance(first, (tuple, list)): - ratings = ratings.map(lambda x: Rating(*x)) - else: - raise ValueError("rating should be RDD of Rating or tuple/list") + if isinstance(first, Rating): + pass + elif isinstance(first, (tuple, list)): + ratings = ratings.map(lambda x: Rating(*x)) + else: + raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first)) return ratings @classmethod @@ -130,8 +205,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp def _test(): import doctest import pyspark.mllib.recommendation + from pyspark.sql import SQLContext globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 210060140fd9..41946e3674fb 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,21 +18,30 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark import RDD +from pyspark.streaming.dstream import DStream +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc +from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector +from pyspark.mllib.util import Saveable, Loader -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', - 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] +__all__ = ['LabeledPoint', 'LinearModel', + 'LinearRegressionModel', 'LinearRegressionWithSGD', + 'RidgeRegressionModel', 'RidgeRegressionWithSGD', + 'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel', + 'IsotonicRegression'] class LabeledPoint(object): """ - The features and labels of a data point. + Class that represents the features and labels of a data point. :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, list, - pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) + :param features: Vector of features for this point (NumPy array, + list, pyspark.mllib.linalg.SparseVector, or scipy.sparse + column matrix) + + Note: 'label' and 'features' are accessible as class attributes. """ def __init__(self, label, features): @@ -51,7 +60,12 @@ def __repr__(self): class LinearModel(object): - """A linear model that has a vector of coefficients and an intercept.""" + """ + A linear model that has a vector of coefficients and an intercept. + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. + """ def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) @@ -69,6 +83,7 @@ def __repr__(self): return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) +@inherit_doc class LinearRegressionModelBase(LinearModel): """A linear regression model. @@ -82,13 +97,16 @@ class LinearRegressionModelBase(LinearModel): def predict(self, x): """ - Predict the value of the dependent variable given a vector x - containing values for the independent variables. + Predict the value of the dependent variable given a vector or + an RDD of vectors containing values for the independent variables. """ + if isinstance(x, RDD): + return x.map(self.predict) x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept +@inherit_doc class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. @@ -100,57 +118,116 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=np.array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=np.array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LinearRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) + >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, + ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", + ... intercept=True, validateData=True) >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LinearRegressionModel(weights, intercept) + return model # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): - raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) - initial_weights = initial_weights or [0.0] * len(data.first().features) - weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + raise TypeError("data should be an RDD of LabeledPoint, but got %s" % type(first)) + if initial_weights is None: + initial_weights = [0.0] * len(data.first().features) + if (modelClass == LogisticRegressionModel): + weights, intercept, numFeatures, numClasses = train_func( + data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept, numFeatures, numClasses) + else: + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.0, regType=None, intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False, + validateData=True): """ - Train a linear regression model on the given data. + Train a linear regression model using Stochastic Gradient + Descent (SGD). + This solves the least squares regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2, - :param data: The training data. - :param iterations: The number of iterations (default: 100). + which is the mean squared error. + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.0). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.0). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization (lasso), @@ -159,23 +236,28 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - @param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept)) + regType, bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) +@inherit_doc class LassoModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_1 penalty term. + """A linear regression model derived from a least-squares fit with + an l_1 penalty term. >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -184,44 +266,116 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LassoModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LassoModel(weights, intercept) + return model class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): - """Train a Lasso regression model on the given data.""" + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): + """ + Train a regression model with L1-regularization using + Stochastic Gradient Descent. + This solves the l1-regularized least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. + + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). + :param step: The step parameter used in SGD + (default: 1.0). + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). + :param initialWeights: The initial weights (default: None). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) + """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) +@inherit_doc class RidgeRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_2 penalty term. + """A linear regression model derived from a least-squares fit with + an l_2 penalty term. >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -230,46 +384,302 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = RidgeRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = RidgeRegressionModel(weights, intercept) + return model class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): - """Train a ridge regression model on the given data.""" + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): + """ + Train a regression model with L2-regularization using + Stochastic Gradient Descent. + This solves the l2-regularized least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. + + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). + :param step: The step parameter used in SGD + (default: 1.0). + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). + :param initialWeights: The initial weights (default: None). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) + """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) +class IsotonicRegressionModel(Saveable, Loader): + + """ + Regression model for isotonic regression. + + :param boundaries: Array of boundaries for which predictions are + known. Boundaries must be sorted in increasing order. + :param predictions: Array of predictions associated to the + boundaries at the same index. Results of isotonic + regression and therefore monotone. + :param isotonic: indicates whether this is isotonic or antitonic. + + >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] + >>> irm = IsotonicRegression.train(sc.parallelize(data)) + >>> irm.predict(3) + 2.0 + >>> irm.predict(5) + 16.5 + >>> irm.predict(sc.parallelize([3, 5])).collect() + [2.0, 16.5] + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> irm.save(sc, path) + >>> sameModel = IsotonicRegressionModel.load(sc, path) + >>> sameModel.predict(3) + 2.0 + >>> sameModel.predict(5) + 16.5 + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + """ + + def __init__(self, boundaries, predictions, isotonic): + self.boundaries = boundaries + self.predictions = predictions + self.isotonic = isotonic + + def predict(self, x): + """ + Predict labels for provided features. + Using a piecewise linear function. + 1) If x exactly matches a boundary then associated prediction + is returned. In case there are multiple predictions with the + same boundary then one of them is returned. Which one is + undefined (same as java.util.Arrays.binarySearch). + 2) If x is lower or higher than all boundaries then first or + last prediction is returned respectively. In case there are + multiple predictions with the same boundary then the lowest + or highest is returned respectively. + 3) If x falls between two values in boundary array then + prediction is treated as piecewise linear function and + interpolated value is returned. In case there are multiple + values with the same boundary then the same rules as in 2) + are used. + + :param x: Feature or RDD of Features to be labeled. + """ + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) + return np.interp(x, self.boundaries, self.predictions) + + def save(self, sc, path): + java_boundaries = _py2java(sc, self.boundaries.tolist()) + java_predictions = _py2java(sc, self.predictions.tolist()) + java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel( + java_boundaries, java_predictions, self.isotonic) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load( + sc._jsc.sc(), path) + py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray() + py_predictions = _java2py(sc, java_model.predictionVector()).toArray() + return IsotonicRegressionModel(py_boundaries, py_predictions, java_model.isotonic) + + +class IsotonicRegression(object): + + @classmethod + def train(cls, data, isotonic=True): + """ + Train a isotonic regression model on the given data. + + :param data: RDD of (label, feature, weight) tuples. + :param isotonic: Whether this is isotonic or antitonic. + """ + boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", + data.map(_convert_to_vector), bool(isotonic)) + return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic) + + +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. + """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LinearRegression with SGD on a batch of data. + + The problem minimized is (1 / n_samples) * (y - weights'X)**2. + After training on a batch of data, the weights obtained at the end of + training are used as initial weights for the next batch. + + :param: stepSize Step size for each iteration of gradient descent. + :param: numIterations Total number of iterations run. + :param: miniBatchFraction Fraction of data on which SGD is run for each + iteration. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + self.stepSize = stepSize + self.numIterations = numIterations + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLinearRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn + """ + initialWeights = _convert_to_vector(initialWeights) + self._model = LinearRegressionModel(initialWeights, 0) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LinearRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LinearRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights, + self._model.intercept) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext import pyspark.mllib.regression globs = pyspark.mllib.regression.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py new file mode 100644 index 000000000000..7da921976d4d --- /dev/null +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version > '3': + xrange = range + +import numpy as np + +from pyspark.mllib.common import callMLlibFunc +from pyspark.rdd import RDD + + +class KernelDensity(object): + """ + .. note:: Experimental + + Estimate probability density at required points given a RDD of samples + from the population. + + >>> kd = KernelDensity() + >>> sample = sc.parallelize([0.0, 1.0]) + >>> kd.setSample(sample) + >>> kd.estimate([0.0, 1.0]) + array([ 0.12938758, 0.12938758]) + """ + def __init__(self): + self._bandwidth = 1.0 + self._sample = None + + def setBandwidth(self, bandwidth): + """Set bandwidth of each sample. Defaults to 1.0""" + self._bandwidth = bandwidth + + def setSample(self, sample): + """Set sample points from the population. Should be a RDD""" + if not isinstance(sample, RDD): + raise TypeError("samples should be a RDD, received %s" % type(sample)) + self._sample = sample + + def estimate(self, points): + """Estimate the probability density at points""" + points = list(points) + densities = callMLlibFunc( + "estimateKernelDensity", self._sample, self._bandwidth, points) + return np.asarray(densities) diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index b686d955a008..c8a721d3fe41 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -21,5 +21,8 @@ from pyspark.mllib.stat._statistics import * from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.stat.test import ChiSqTestResult +from pyspark.mllib.stat.KernelDensity import KernelDensity -__all__ = ["Statistics", "MultivariateStatisticalSummary", "MultivariateGaussian"] +__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", + "MultivariateGaussian", "KernelDensity"] diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 218ac148ca99..36c8f48a4a88 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,11 +15,15 @@ # limitations under the License. # -from pyspark import RDD +import sys +if sys.version >= '3': + basestring = str + +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.stat.test import ChiSqTestResult +from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult __all__ = ['MultivariateStatisticalSummary', 'Statistics'] @@ -38,7 +42,7 @@ def variance(self): return self.call("variance").toArray() def count(self): - return self.call("count") + return int(self.call("count")) def numNonzeros(self): return self.call("numNonzeros").toArray() @@ -49,6 +53,12 @@ def max(self): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() + + def normL2(self): + return self.call("normL2").toArray() + class Statistics(object): @@ -72,7 +82,7 @@ def colStats(rdd): >>> cStats.variance() array([ 4., 13., 0., 25.]) >>> cStats.count() - 3L + 3 >>> cStats.numNonzeros() array([ 3., 2., 0., 3.]) >>> cStats.max() @@ -118,20 +128,20 @@ def corr(x, y=None, method=None): >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) >>> pearsonCorr = Statistics.corr(rdd) - >>> print str(pearsonCorr).replace('nan', 'NaN') + >>> print(str(pearsonCorr).replace('nan', 'NaN')) [[ 1. 0.05564149 NaN 0.40047142] [ 0.05564149 1. NaN 0.91359586] [ NaN NaN 1. NaN] [ 0.40047142 0.91359586 NaN 1. ]] >>> spearmanCorr = Statistics.corr(rdd, method="spearman") - >>> print str(spearmanCorr).replace('nan', 'NaN') + >>> print(str(spearmanCorr).replace('nan', 'NaN')) [[ 1. 0.10540926 NaN 0.4 ] [ 0.10540926 1. NaN 0.9486833 ] [ NaN NaN 1. NaN] [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") - ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... print("Method name as second argument without 'method=' shouldn't be allowed.") ... except TypeError: ... pass """ @@ -147,6 +157,7 @@ def corr(x, y=None, method=None): return callMLlibFunc("corr", x.map(float), y.map(float), method) @staticmethod + @ignore_unicode_prefix def chiSqTest(observed, expected=None): """ .. note:: Experimental @@ -182,11 +193,11 @@ def chiSqTest(observed, expected=None): >>> from pyspark.mllib.linalg import Vectors, Matrices >>> observed = Vectors.dense([4, 6, 5]) >>> pearson = Statistics.chiSqTest(observed) - >>> print pearson.statistic + >>> print(pearson.statistic) 0.4 >>> pearson.degreesOfFreedom 2 - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.8187 >>> pearson.method u'pearson' @@ -196,12 +207,12 @@ def chiSqTest(observed, expected=None): >>> observed = Vectors.dense([21, 38, 43, 80]) >>> expected = Vectors.dense([3, 5, 7, 20]) >>> pearson = Statistics.chiSqTest(observed, expected) - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.0027 >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - >>> print round(chi.statistic, 4) + >>> print(round(chi.statistic, 4)) 21.9958 >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), @@ -212,9 +223,9 @@ def chiSqTest(observed, expected=None): ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] >>> rdd = sc.parallelize(data, 4) >>> chi = Statistics.chiSqTest(rdd) - >>> print chi[0].statistic + >>> print(chi[0].statistic) 0.75 - >>> print chi[1].statistic + >>> print(chi[1].statistic) 1.5 """ if isinstance(observed, RDD): @@ -231,6 +242,67 @@ def chiSqTest(observed, expected=None): jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) return ChiSqTestResult(jmodel) + @staticmethod + @ignore_unicode_prefix + def kolmogorovSmirnovTest(data, distName="norm", *params): + """ + .. note:: Experimental + + Performs the Kolmogorov-Smirnov (KS) test for data sampled from + a continuous distribution. It tests the null hypothesis that + the data is generated from a particular distribution. + + The given data is sorted and the Empirical Cumulative + Distribution Function (ECDF) is calculated + which for a given point is the number of points having a CDF + value lesser than it divided by the total number of points. + + Since the data is sorted, this is a step function + that rises by (1 / length of data) for every ordered point. + + The KS statistic gives us the maximum distance between the + ECDF and the CDF. Intuitively if this statistic is large, the + probabilty that the null hypothesis is true becomes small. + For specific details of the implementation, please have a look + at the Scala documentation. + + :param data: RDD, samples from the data + :param distName: string, currently only "norm" is supported. + (Normal distribution) to calculate the + theoretical distribution of the data. + :param params: additional values which need to be provided for + a certain distribution. + If not provided, the default values are used. + :return: KolmogorovSmirnovTestResult object containing the test + statistic, degrees of freedom, p-value, + the method used, and the null hypothesis. + + >>> kstest = Statistics.kolmogorovSmirnovTest + >>> data = sc.parallelize([-1.0, 0.0, 1.0]) + >>> ksmodel = kstest(data, "norm") + >>> print(round(ksmodel.pValue, 3)) + 1.0 + >>> print(round(ksmodel.statistic, 3)) + 0.175 + >>> ksmodel.nullHypothesis + u'Sample follows theoretical distribution' + + >>> data = sc.parallelize([2.0, 3.0, 4.0]) + >>> ksmodel = kstest(data, "norm", 3.0, 1.0) + >>> print(round(ksmodel.pValue, 3)) + 1.0 + >>> print(round(ksmodel.statistic, 3)) + 0.175 + """ + if not isinstance(data, RDD): + raise TypeError("data should be an RDD, got %s." % type(data)) + if not isinstance(distName, basestring): + raise TypeError("distName should be a string, got %s." % type(distName)) + + params = [float(param) for param in params] + return KolmogorovSmirnovTestResult( + callMLlibFunc("kolmogorovSmirnovTest", data, distName, params)) + def _test(): import doctest diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py index 07792e153204..46f7a1d2f277 100644 --- a/python/pyspark/mllib/stat/distribution.py +++ b/python/pyspark/mllib/stat/distribution.py @@ -22,7 +22,8 @@ class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): - """ Represents a (mu, sigma) tuple + """Represents a (mu, sigma) tuple + >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0))) >>> (m.mu, m.sigma.toArray()) (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py index 762506e952b4..0abe104049ff 100644 --- a/python/pyspark/mllib/stat/test.py +++ b/python/pyspark/mllib/stat/test.py @@ -15,24 +15,16 @@ # limitations under the License. # -from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.common import inherit_doc, JavaModelWrapper -__all__ = ["ChiSqTestResult"] +__all__ = ["ChiSqTestResult", "KolmogorovSmirnovTestResult"] -class ChiSqTestResult(JavaModelWrapper): +class TestResult(JavaModelWrapper): """ - .. note:: Experimental - - Object containing the test results for the chi-squared hypothesis test. + Base class for all test results. """ - @property - def method(self): - """ - Name of the test method - """ - return self._java_model.method() @property def pValue(self): @@ -67,3 +59,24 @@ def nullHypothesis(self): def __str__(self): return self._java_model.toString() + + +@inherit_doc +class ChiSqTestResult(TestResult): + """ + Contains test results for the chi-squared hypothesis test. + """ + + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + +@inherit_doc +class KolmogorovSmirnovTestResult(TestResult): + """ + Contains test results for the Kolmogorov-Smirnov test. + """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 42aa22873772..5097c5e8ba4c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,12 +19,22 @@ Fuller unit tests for Python MLlib. """ +import os import sys +import tempfile import array as pyarray +from time import time, sleep +from shutil import rmtree + +from numpy import ( + array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) +from numpy import sum as array_sum -from numpy import array, array_equal from py4j.protocol import Py4JJavaError +if sys.version > '3': + basestring = str + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -34,14 +44,24 @@ else: import unittest +from pyspark import SparkContext +from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, Vectors, Matrices -from pyspark.mllib.regression import LabeledPoint + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import Word2Vec +from pyspark.mllib.feature import IDF +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct +from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer +from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.streaming import StreamingContext _have_scipy = False try: @@ -52,6 +72,59 @@ pass ser = PickleSerializer() +sc = SparkContext('local[4]', "MLlib tests") + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc + + +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + + @staticmethod + def _eventually(condition, timeout=30.0, catch_assertions=False): + """ + Wait a given amount of time for a condition to pass, else fail with an error. + This is a helper utility for streaming ML tests. + :param condition: Function that checks for termination conditions. + condition() can return: + - True: Conditions met. Return without error. + - other value: Conditions not met yet. Continue. Upon timeout, + include last such value in error message. + Note that this method may be called at any time during + streaming execution (e.g., even before any results + have been created). + :param timeout: Number of seconds to wait. Default 30 seconds. + :param catch_assertions: If False (default), do not catch AssertionErrors. + If True, catch AssertionErrors; continue, but save + error to throw upon timeout. + """ + start_time = time() + lastValue = None + while time() - start_time < timeout: + if catch_assertions: + try: + lastValue = condition() + except AssertionError as e: + lastValue = e + else: + lastValue = condition() + if lastValue is True: + return + sleep(0.01) + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue)) def _squared_distance(a, b): @@ -61,16 +134,16 @@ def _squared_distance(a, b): return b.squared_distance(a) -class VectorTests(PySparkTestCase): +class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) + nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs))) + nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): @@ -80,6 +153,9 @@ def test_serialize(self): self._test_serialize(SparseVector(4, {1: 1, 3: 2})) self._test_serialize(SparseVector(3, {})) self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -89,17 +165,22 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) self.assertEquals(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) self.assertEquals(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) self.assertEquals(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEquals(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) dv = DenseVector(array([1., 2., 3., 4.])) lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) self.assertEquals(15.0, _squared_distance(sv, dv)) self.assertEquals(25.0, _squared_distance(sv, lst)) self.assertEquals(20.0, _squared_distance(dv, lst)) @@ -109,6 +190,9 @@ def test_squared_distance(self): self.assertEquals(0.0, _squared_distance(sv, sv)) self.assertEquals(0.0, _squared_distance(dv, dv)) self.assertEquals(0.0, _squared_distance(lst, lst)) + self.assertEquals(25.0, _squared_distance(sv, lst1)) + self.assertEquals(3.0, _squared_distance(sv, arr)) + self.assertEquals(3.0, _squared_distance(sv, narr)) def test_conversion(self): # numpy arrays should be automatically upcast to float64 @@ -129,11 +213,157 @@ def test_sparse_vector_indexing(self): self.assertEquals(sv[-1], 2) self.assertEquals(sv[-2], 0) self.assertEquals(sv[-4], 0) - for ind in [4, -5, 7.8]: + for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) - - -class ListTests(PySparkTestCase): + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEquals(sm1.numRows, 3) + self.assertEquals(sm1.numCols, 4) + self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEquals(sm1.numRows, smnew.numRows) + self.assertEquals(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEquals(sm1t.numRows, 3) + self.assertEquals(sm1t.numCols, 4) + self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEquals(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEquals(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_parse_vector(self): + a = DenseVector([3, 4, 6, 7]) + self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]') + self.assertTrue(Vectors.parse(str(a)), a) + a = SparseVector(4, [0, 2], [3, 4]) + self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])') + self.assertTrue(Vectors.parse(str(a)), a) + a = SparseVector(10, [0, 1], [4, 5]) + self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + +class ListTests(MLlibTestCase): """ Test MLlib algorithms on plain lists, to make sure they're passed through @@ -148,7 +378,8 @@ def test_kmeans(self): [1.1, 0], [1.2, 0], ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") + clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", + initializationSteps=7, epsilon=1e-4) self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) @@ -158,9 +389,11 @@ def test_kmeans_deterministic(self): Y = range(0, 100, 10) data = [[x, y] for x, y in zip(X, Y)] clusters1 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", seed=42) + 3, initializationMode="k-means||", + seed=42, initializationSteps=7, epsilon=1e-4) clusters2 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", seed=42) + 3, initializationMode="k-means||", + seed=42, initializationSteps=7, epsilon=1e-4) centers1 = clusters1.centers centers2 = clusters2.centers for c1, c2 in zip(centers1, centers2): @@ -176,7 +409,7 @@ def test_gmm(self): [-6, -7], ]) clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=100, seed=56) + maxIterations=10, seed=56) labels = clusters.predict(data).collect() self.assertEquals(labels[0], labels[1]) self.assertEquals(labels[2], labels[3]) @@ -187,15 +420,16 @@ def test_gmm_deterministic(self): y = range(0, 100, 10) data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=100, seed=63) + maxIterations=10, seed=63) clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=100, seed=63) + maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEquals(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees + from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ + RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -205,13 +439,15 @@ def test_classification(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] - lr_model = LogisticRegressionWithSGD.train(rdd) + temp_dir = tempfile.mkdtemp() + + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - svm_model = SVMWithSGD.train(rdd) + svm_model = SVMWithSGD.train(rdd, iterations=10) self.assertTrue(svm_model.predict(features[0]) <= 0) self.assertTrue(svm_model.predict(features[1]) > 0) self.assertTrue(svm_model.predict(features[2]) <= 0) @@ -225,26 +461,47 @@ def test_classification(self): categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + dt_model_dir = os.path.join(temp_dir, "dt") + dt_model.save(self.sc, dt_model_dir) + same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) + self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) + rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) self.assertTrue(rf_model.predict(features[3]) > 0) + rf_model_dir = os.path.join(temp_dir, "rf") + rf_model.save(self.sc, rf_model_dir) + same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) + self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) + gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) self.assertTrue(gbt_model.predict(features[0]) <= 0) self.assertTrue(gbt_model.predict(features[1]) > 0) self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) + gbt_model_dir = os.path.join(temp_dir, "gbt") + gbt_model.save(self.sc, gbt_model_dir) + same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) + self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) + + try: + rmtree(temp_dir) + except OSError: + pass + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD @@ -258,19 +515,19 @@ def test_regression(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] - lr_model = LinearRegressionWithSGD.train(rdd) + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - lasso_model = LassoWithSGD.train(rdd) + lasso_model = LassoWithSGD.train(rdd, iterations=10) self.assertTrue(lasso_model.predict(features[0]) <= 0) self.assertTrue(lasso_model.predict(features[1]) > 0) self.assertTrue(lasso_model.predict(features[2]) <= 0) self.assertTrue(lasso_model.predict(features[3]) > 0) - rr_model = RidgeRegressionWithSGD.train(rdd) + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(rr_model.predict(features[0]) <= 0) self.assertTrue(rr_model.predict(features[1]) > 0) self.assertTrue(rr_model.predict(features[2]) <= 0) @@ -278,28 +535,42 @@ def test_regression(self): categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) self.assertTrue(rf_model.predict(features[3]) > 0) gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) self.assertTrue(gbt_model.predict(features[0]) <= 0) self.assertTrue(gbt_model.predict(features[1]) > 0) self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + except ValueError: + self.fail() + + # Verify that maxBins is being passed through + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) + with self.assertRaises(Exception) as cm: + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) + -class StatTests(PySparkTestCase): +class StatTests(MLlibTestCase): # SPARK-4023 def test_col_with_different_rdds(self): # numpy @@ -315,8 +586,21 @@ def test_col_with_different_rdds(self): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) -class VectorUDTTests(PySparkTestCase): + +class VectorUDTTests(MLlibTestCase): dv0 = DenseVector([]) dv1 = DenseVector([1.0, 2.0]) @@ -334,11 +618,11 @@ def test_serialization(self): def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = rdd.toDF() + schema = df.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) - vectors = srdd.map(lambda p: p.features).collect() + vectors = df.map(lambda p: p.features).collect() self.assertEqual(len(vectors), 2) for v in vectors: if isinstance(v, SparseVector): @@ -346,11 +630,43 @@ def test_infer_schema(self): elif isinstance(v, DenseVector): self.assertEqual(v, self.dv1) else: - raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) @unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): +class SciPyTests(MLlibTestCase): """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, @@ -491,7 +807,7 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) -class ChiSqTestTests(PySparkTestCase): +class ChiSqTestTests(MLlibTestCase): def test_goodness_of_fit(self): from numpy import inf @@ -588,9 +904,597 @@ def test_right_number_of_results(self): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class KolmogorovSmirnovTest(MLlibTestCase): + + def test_R_implementation_equivalence(self): + data = self.sc.parallelize([ + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ]) + model = Statistics.kolmogorovSmirnovTest(data, "norm") + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + +class SerDeTest(MLlibTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +class FeatureTest(MLlibTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(MLlibTestCase): + def test_word2vec_setters(self): + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) + self.assertEquals(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEquals(model.numPartitions, 2) + self.assertEquals(model.numIterations, 10) + self.assertEquals(model.seed, 1024) + self.assertEquals(model.minCount, 3) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEquals(len(model.getVectors()), 3) + + +class StandardScalerTests(MLlibTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEquals(stkm._k, 5) + self.assertEquals(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEquals( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + self.ssc.start() + + def condition(): + self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + return True + self._eventually(condition, catch_assertions=True) + + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. + centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offset for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + self.ssc.start() + + # Give enough time to train the model. + def condition(): + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + return True + self._eventually(condition, catch_assertions=True) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + self.ssc.start() + + def condition(): + self.assertEquals(result, [[0], [1], [2], [3]]) + return True + + self._eventually(condition, catch_assertions=True) + + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + self.ssc.start() + + def condition(): + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + return True + + self._eventually(condition, catch_assertions=True) + + +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + self.ssc.start() + + def condition(): + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + self.ssc.start() + + def condition(): + self.assertEquals(len(models), len(input_batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, 60.0, catch_assertions=True) + + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + # Test that weights improve with a small tolerance + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + self.ssc.start() + + def condition(): + self.assertEquals(len(true_predicted), len(input_batches)) + return True + + self._eventually(condition, catch_assertions=True) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + self.ssc.start() + + def condition(): + # Test that the improvement in error is > 0.3 + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 0.3) + if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertEquals(len(model_weights), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + w = array(model_weights) + diff = w[1:] - w[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + + def condition(): + self.assertEquals(len(samples), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(mean(abs(true) - abs(predicted))) + + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + + def condition(): + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 2) + if len(errors) >= 3 and errors[1] - errors[-1] > 2: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") + sc.stop() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index aae48f213246..372b86a7c95d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -20,19 +20,24 @@ import random from pyspark import SparkContext, RDD -from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', - 'RandomForest', 'GradientBoostedTrees'] + 'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees'] -class TreeEnsembleModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): def predict(self, x): """ Predict values for a single data point or an RDD of points using the model trained. + + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -48,7 +53,8 @@ def numTrees(self): def totalNumNodes(self): """ - Get total number of nodes, summed over all trees in the ensemble. + Get total number of nodes, summed over all trees in the + ensemble. """ return self.call("totalNumNodes") @@ -61,7 +67,7 @@ def toDebugString(self): return self._java_model.toDebugString() -class DecisionTreeModel(JavaModelWrapper): +class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -71,6 +77,10 @@ def predict(self, x): """ Predict the label of one or more examples. + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ @@ -94,12 +104,17 @@ def toDebugString(self): """ full model. """ return self._java_model.toDebugString() + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.DecisionTreeModel" + class DecisionTree(object): """ .. note:: Experimental - Learning algorithm for a decision tree model for classification or regression. + Learning algorithm for a decision tree model for classification or + regression. """ @classmethod @@ -148,14 +163,16 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) - >>> print model, # it already has newline + >>> print(model) DecisionTreeModel classifier of depth 1 with 3 nodes - >>> print model.toDebugString(), # it already has newline + + >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.0) Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 + >>> model.predict(array([1.0])) 1.0 >>> model.predict(array([0.0])) @@ -176,17 +193,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, :param data: Training data: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. + :param categoricalFeaturesInfo: Map from categorical feature + index to number of categories. + Any feature not in this map is treated as continuous. :param impurity: Supported values: "variance" :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each + node. + :param minInstancesPerNode: Min number of instances required at + child nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel @@ -216,19 +233,25 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -class RandomForestModel(TreeEnsembleModel): +@inherit_doc +class RandomForestModel(TreeEnsembleModel, JavaLoader): """ .. note:: Experimental Represents a random forest model. """ + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.RandomForestModel" + class RandomForest(object): """ .. note:: Experimental - Learning algorithm for a random forest model for classification or regression. + Learning algorithm for a random forest model for classification or + regression. """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -255,26 +278,30 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Method to train a decision tree model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels should take - values {0, 1, ..., numClasses-1}. + :param data: Training dataset: RDD of LabeledPoint. Labels + should take values {0, 1, ..., numClasses-1}. :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical features. - E.g., an entry (n -> k) indicates that feature n is categorical - with k categories indexed from 0: {0, 1, ..., k-1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that + feature n is categorical with k categories indexed + from 0: {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for splits at - each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". + :param featureSubsetStrategy: Number of features to consider for + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; - depth 1 means 1 internal node + 2 leaf nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + :param maxDepth: Maximum depth of the tree. + E.g., depth 0 means 1 leaf node; depth 1 means + 1 internal node + 2 leaf nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features + (default: 32) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -293,9 +320,10 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, 3 >>> model.totalNumNodes() 7 - >>> print model, + >>> print(model) TreeEnsembleModel classifier with 3 trees - >>> print model.toDebugString(), + + >>> print(model.toDebugString()) TreeEnsembleModel classifier with 3 trees Tree 0: @@ -310,6 +338,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Predict: 0.0 Else (feature 0 > 1.0) Predict: 1.0 + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) @@ -336,19 +365,21 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + :param impurity: Criterion used for information gain + calculation. + Supported values: "variance". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features (default: 32) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -381,52 +412,66 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt featureSubsetStrategy, impurity, maxDepth, maxBins, seed) -class GradientBoostedTreesModel(TreeEnsembleModel): +@inherit_doc +class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): """ .. note:: Experimental Represents a gradient-boosted tree model. """ + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + class GradientBoostedTrees(object): """ .. note:: Experimental - Learning algorithm for a gradient boosted trees model for classification or regression. + Learning algorithm for a gradient boosted trees model for + classification or regression. """ @classmethod def _train(cls, data, algo, categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth): + loss, numIterations, learningRate, maxDepth, maxBins): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) return GradientBoostedTreesModel(model) @classmethod def trainClassifier(cls, data, categoricalFeaturesInfo, - loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, + maxBins=32): """ - Method to train a gradient-boosted trees model for classification. + Method to train a gradient-boosted trees model for + classification. - :param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}. + :param data: Training dataset: RDD of LabeledPoint. + Labels should take values {0, 1}. :param categoricalFeaturesInfo: Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient boosting. - Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. :param numIterations: Number of iterations of boosting. (default: 100) - :param learningRate: Learning rate for shrinking the contribution of each estimator. - The learning rate should be between in the interval (0, 1] - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 3) - :return: GradientBoostedTreesModel that can be used for prediction + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :param maxBins: maximum number of bins used for splitting + features (default: 32) DecisionTree requires maxBins >= max categories + :return: GradientBoostedTreesModel that can be used for + prediction Example usage: @@ -440,13 +485,14 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> - >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}) + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10) >>> model.numTrees() - 100 + 10 >>> model.totalNumNodes() - 300 - >>> print model, # it already has newline - TreeEnsembleModel classifier with 100 trees + 30 + >>> print(model) # it already has newline + TreeEnsembleModel classifier with 10 trees + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) @@ -456,11 +502,12 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, [1.0, 0.0] """ return cls._train(data, "classification", categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) @classmethod def trainRegressor(cls, data, categoricalFeaturesInfo, - loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, + maxBins=32): """ Method to train a gradient-boosted trees model for regression. @@ -470,17 +517,22 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient boosting. - Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. :param numIterations: Number of iterations of boosting. (default: 100) - :param learningRate: Learning rate for shrinking the contribution of each estimator. - The learning rate should be between in the interval (0, 1] - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 3) - :return: GradientBoostedTreesModel that can be used for prediction + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be between in the interval (0, 1]. + (default: 0.1) + :param maxBins: maximum number of bins used for splitting + features (default: 32) DecisionTree requires maxBins >= max categories + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction Example usage: @@ -495,11 +547,12 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> - >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {}) + >>> data = sc.parallelize(sparse_data) + >>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10) >>> model.numTrees() - 100 + 10 >>> model.totalNumNodes() - 102 + 12 >>> model.predict(SparseVector(2, {1: 1.0})) 1.0 >>> model.predict(SparseVector(2, {0: 1.0})) @@ -509,7 +562,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, [1.0, 0.0] """ return cls._train(data, "regression", categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) def _test(): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 4ed978b45409..10a1e4b3eb0f 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -15,12 +15,17 @@ # limitations under the License. # +import sys import numpy as np import warnings -from pyspark.mllib.common import callMLlibFunc +if sys.version > '3': + xrange = range + basestring = str + +from pyspark import SparkContext +from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint class MLUtils(object): @@ -50,6 +55,7 @@ def _parse_libsvm_line(line, multiclass=None): @staticmethod def _convert_labeled_point_to_libsvm(p): """Converts a LabeledPoint to a string in LIBSVM format.""" + from pyspark.mllib.regression import LabeledPoint assert isinstance(p, LabeledPoint) items = [str(p.label)] v = _convert_to_vector(p.features) @@ -92,24 +98,20 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> tempFile = NamedTemporaryFile(delete=True) - >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") + >>> _ = tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() - >>> type(examples[0]) == LabeledPoint - True - >>> print examples[0] - (1.0,(6,[0,2,4],[1.0,2.0,3.0])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (-1.0,(6,[],[])) - >>> type(examples[2]) == LabeledPoint - True - >>> print examples[2] - (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) + >>> examples[0] + LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0])) + >>> examples[1] + LabeledPoint(-1.0, (6,[],[])) + >>> examples[2] + LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0])) """ + from pyspark.mllib.regression import LabeledPoint if multiclass is not None: warnings.warn("deprecated", DeprecationWarning) @@ -130,6 +132,7 @@ def saveAsLibSVMFile(data, dir): >>> from tempfile import NamedTemporaryFile >>> from fileinput import input + >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ @@ -156,6 +159,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) @@ -167,6 +171,155 @@ def loadLabeledPoints(sc, path, minPartitions=None): minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) + @staticmethod + def appendBias(data): + """ + Returns a new vector with `1.0` (bias) appended to + the end of the input vector. + """ + vec = _convert_to_vector(data) + if isinstance(vec, SparseVector): + newIndices = np.append(vec.indices, len(vec)) + newValues = np.append(vec.values, 1.0) + return SparseVector(len(vec) + 1, newIndices, newValues) + else: + return _convert_to_vector(np.append(vec.toArray(), 1.0)) + + @staticmethod + def loadVectors(sc, path): + """ + Loads vectors saved using `RDD[Vector].saveAsTextFile` + with the default number of partitions. + """ + return callMLlibFunc("loadVectors", sc, path) + + +class Saveable(object): + """ + Mixin for models and transformers which may be saved as files. + """ + + def save(self, sc, path): + """ + Save this model to the given path. + + This saves: + * human-readable (JSON) model metadata to path/metadata/ + * Parquet formatted data to path/data/ + + The model may be loaded using py:meth:`Loader.load`. + + :param sc: Spark context used to save model data. + :param path: Path specifying the directory in which to save + this model. If the directory already exists, + this method throws an exception. + """ + raise NotImplementedError + + +@inherit_doc +class JavaSaveable(Saveable): + """ + Mixin for models that provide save() through their Scala + implementation. + """ + + def save(self, sc, path): + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._java_model.save(sc._jsc.sc(), path) + + +class Loader(object): + """ + Mixin for classes which can load saved models from files. + """ + + @classmethod + def load(cls, sc, path): + """ + Load a model from the given path. The model should have been + saved using py:meth:`Saveable.save`. + + :param sc: Spark context used for loading model files. + :param path: Path specifying the directory to which the model + was saved. + :return: model instance + """ + raise NotImplemented + + +@inherit_doc +class JavaLoader(Loader): + """ + Mixin for classes which can load saved models using its Scala + implementation. + """ + + @classmethod + def _java_loader_class(cls): + """ + Returns the full class name of the Java loader. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = cls.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, cls.__name__]) + + @classmethod + def _load_java(cls, sc, path): + """ + Load a Java model from the given path. + """ + java_class = cls._java_loader_class() + java_obj = sc._jvm + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj.load(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = cls._load_java(sc, path) + return cls(java_model) + + +class LinearDataGenerator(object): + """Utils for generating linear data""" + + @staticmethod + def generateLinearInput(intercept, weights, xMean, xVariance, + nPoints, seed, eps): + """ + :param: intercept bias factor, the term c in X'w + c + :param: weights feature vector, the term w in X'w + c + :param: xMean Point around which the data X is centered. + :param: xVariance Variance of the given data + :param: nPoints Number of points to be generated + :param: seed Random Seed + :param: eps Used to scale the noise. If eps is set high, + the amount of gaussian noise added is more. + + Returns a list of LabeledPoints of length nPoints + """ + weights = [float(weight) for weight in weights] + xMean = [float(mean) for mean in xMean] + xVariance = [float(var) for var in xVariance] + return list(callMLlibFunc( + "generateLinearInputWrapper", float(intercept), weights, xMean, + xVariance, int(nPoints), int(seed), float(eps))) + + @staticmethod + def generateLinearRDD(sc, nexamples, nfeatures, eps, + nParts=2, intercept=0.0): + """ + Generate a RDD of LabeledPoints. + """ + return callMLlibFunc( + "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), + float(eps), int(nParts), float(intercept)) + def _test(): import doctest diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 4408996db079..44d17bd62947 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -84,15 +84,17 @@ class Profiler(object): >>> from pyspark import BasicProfiler >>> class MyCustomProfiler(BasicProfiler): ... def show(self, id): - ... print "My custom profiles for RDD:%s" % id + ... print("My custom profiles for RDD:%s" % id) ... >>> conf = SparkConf().set("spark.python.profile", "true") >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) - >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.parallelize(range(1000)).count() + 1000 >>> sc.show_profiles() My custom profiles for RDD:1 - My custom profiles for RDD:2 + My custom profiles for RDD:3 >>> sc.stop() """ @@ -111,9 +113,9 @@ def show(self, id): """ Print the profile stats to stdout, id is the RDD id """ stats = self.stats() if stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 + print("=" * 60) + print("Profile of RDD" % id) + print("=" * 60) stats.sort_stats("time", "cumulative").print_stats() def dump(self, id, path): @@ -169,4 +171,6 @@ def stats(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2f8a0edfe964..9ef60a7e2c84 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -16,21 +16,29 @@ # import copy -from collections import defaultdict -from itertools import chain, ifilter, imap -import operator -import os import sys +import os +import re +import operator import shlex -from subprocess import Popen, PIPE -from tempfile import NamedTemporaryFile -from threading import Thread import warnings import heapq import bisect import random +import socket +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread +from collections import defaultdict +from itertools import chain +from functools import reduce from math import sqrt, log, isinf, isnan, pow, ceil +if sys.version > '3': + basestring = unicode = str +else: + from itertools import imap as map, ifilter as filter + from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -41,7 +49,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -50,20 +58,21 @@ __all__ = ["RDD"] -# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized -# hash for string def portable_hash(x): """ - This function returns consistant hash code for builtin types, especially + This function returns consistent hash code for builtin types, especially for None and tuple with None. - The algrithm is similar to that one used by CPython 2.7 + The algorithm is similar to that one used by CPython 2.7 >>> portable_hash(None) 0 >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") + if x is None: return 0 if isinstance(x, tuple): @@ -71,7 +80,7 @@ def portable_hash(x): for i in x: h ^= portable_hash(i) h *= 1000003 - h &= sys.maxint + h &= sys.maxsize h ^= len(x) if h == -1: h = -2 @@ -111,6 +120,57 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) +def _load_from_socket(port, serializer): + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + try: + sock.settimeout(3) + sock.connect(sa) + except socket.error: + sock.close() + sock = None + continue + break + if not sock: + raise Exception("could not open socket") + try: + rf = sock.makefile("rb", 65536) + for item in serializer.load_stream(rf): + yield item + finally: + sock.close() + + +def ignore_unicode_prefix(f): + """ + Ignore the 'u' prefix of string in doc tests, to make it works + in both python 2 and 3 + """ + if sys.version >= '3': + # the representation of unicode string in Python 3 does not have prefix 'u', + # so remove the prefix 'u' for doc tests + literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE) + f.__doc__ = literal_re.sub(r'\1\2', f.__doc__) + return f + + +class Partitioner(object): + def __init__(self, numPartitions, partitionFunc): + self.numPartitions = numPartitions + self.partitionFunc = partitionFunc + + def __eq__(self, other): + return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions + and self.partitionFunc == other.partitionFunc) + + def __call__(self, k): + return self.partitionFunc(k) % self.numPartitions + + class RDD(object): """ @@ -126,7 +186,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() - self._partitionFunc = None + self.partitioner = None def _pickled(self): return self._reserialize(AutoBatchedSerializer(PickleSerializer())) @@ -226,7 +286,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -241,7 +301,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -304,7 +364,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -316,15 +376,21 @@ def distinct(self, numPartitions=None): """ return self.map(lambda x: (x, None)) \ .reduceByKey(lambda x, _: x, numPartitions) \ - .map(lambda (x, _): x) + .map(lambda x: x[0]) def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD. + :param withReplacement: can elements be sampled multiple times (replaced when sampled out) + :param fraction: expected size of the sample as a fraction of this RDD's size + without replacement: probability that each element is chosen; fraction must be [0, 1] + with replacement: expected number of times each element is chosen; fraction must be >= 0 + :param seed: seed for the random number generator + >>> rdd = sc.parallelize(range(100), 4) - >>> rdd.sample(False, 0.1, 81).count() - 10 + >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 + True """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -337,12 +403,14 @@ def randomSplit(self, weights, seed=None): :param seed: random seed :return: split RDDs in a list - >>> rdd = sc.parallelize(range(5), 1) + >>> rdd = sc.parallelize(range(500), 1) >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) - >>> rdd1.collect() - [1, 3] - >>> rdd2.collect() - [0, 2, 4] + >>> len(rdd1.collect() + rdd2.collect()) + 500 + >>> 150 < rdd1.count() < 250 + True + >>> 250 < rdd2.count() < 350 + True """ s = float(sum(weights)) cweights = [0.0] @@ -385,7 +453,7 @@ def takeSample(self, withReplacement, num, seed=None): rand.shuffle(samples) return samples - maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize)) if num > maxSampleSize: raise ValueError( "Sample size cannot be greater than %d." % maxSampleSize) @@ -399,7 +467,7 @@ def takeSample(self, withReplacement, num, seed=None): # See: scala/spark/RDD.scala while len(samples) < num: # TODO: add log warning for when more than one iteration was run - seed = rand.randint(0, sys.maxint) + seed = rand.randint(0, sys.maxsize) samples = self.sample(withReplacement, fraction, seed).collect() rand.shuffle(samples) @@ -450,14 +518,17 @@ def union(self, other): if self._jrdd_deserializer == other._jrdd_deserializer: rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer) - return rdd else: # These RDDs contain data in different serialized formats, so we # must normalize them to the default serializer. self_copy = self._reserialize() other_copy = other._reserialize() - return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, - self.ctx.serializer) + rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + if (self.partitioner == other.partitioner and + self.getNumPartitions() == rdd.getNumPartitions()): + rdd.partitioner = self.partitioner + return rdd def intersection(self, other): """ @@ -473,7 +544,7 @@ def intersection(self, other): """ return self.map(lambda v: (v, None)) \ .cogroup(other.map(lambda v: (v, None))) \ - .filter(lambda (k, vs): all(vs)) \ + .filter(lambda k_vs: all(k_vs[1])) \ .keys() def _reserialize(self, serializer=None): @@ -515,7 +586,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -539,13 +610,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: if self.getNumPartitions() > 1: @@ -560,12 +631,12 @@ def sortPartition(iterator): return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) + samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect() + samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary - bounds = [samples[len(samples) * (i + 1) / numPartitions] + bounds = [samples[int(len(samples) * (i + 1) / numPartitions)] for i in range(0, numPartitions - 1)] def rangePartitioner(k): @@ -628,30 +699,47 @@ def groupBy(self, f, numPartitions=None): """ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) - def pipe(self, command, env={}): + @ignore_unicode_prefix + def pipe(self, command, env=None, checkCode=False): """ Return an RDD created by piping elements to a forked external process. >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect() - ['1', '2', '', '3'] + [u'1', u'2', u'', u'3'] + + :param checkCode: whether or not to check the return value of the shell command. """ + if env is None: + env = dict() + def func(iterator): pipe = Popen( shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) def pipe_objs(out): for obj in iterator: - out.write(str(obj).rstrip('\n') + '\n') + s = str(obj).rstrip('\n') + '\n' + out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() - return (x.rstrip('\n') for x in iter(pipe.stdout.readline, '')) + + def check_return_code(): + pipe.wait() + if checkCode and pipe.returncode: + raise Exception("Pipe function `%s' exited " + "with error code %d" % (command, pipe.returncode)) + else: + for i in range(0): + yield i + return (x.rstrip(b'\n').decode('utf-8') for x in + chain(iter(pipe.stdout.readline, b''), check_return_code())) return self.mapPartitions(func) def foreach(self, f): """ Applies a function to all elements of this RDD. - >>> def f(x): print x + >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ def processPartition(iterator): @@ -666,7 +754,7 @@ def foreachPartition(self, f): >>> def f(iterator): ... for x in iterator: - ... print x + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): @@ -682,21 +770,8 @@ def collect(self): Return a list that contains all of the elements in this RDD. """ with SCCallSiteSync(self.context) as css: - bytesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(bytesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. - tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) - tempFile.close() - self.ctx._writeToFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in self._jrdd_deserializer.load_stream(tempFile): - yield item - os.unlink(tempFile.name) + port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(port, self._jrdd_deserializer)) def reduce(self, f): """ @@ -766,13 +841,21 @@ def op(x, y): def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative function and a neutral "zero - value." + the partitions, using a given associative and commutative function and + a neutral "zero value." The function C{op(t1, t2)} is allowed to modify C{t1} and return it as its result value to avoid object allocation; however, it should not modify C{t2}. + This behaves somewhat differently from fold operations implemented + for non-distributed collections in functional languages like Scala. + This fold operation may be applied to partitions individually, and then + fold those results into the final result, rather than apply the fold + to each element sequentially in some defined ordering. For functions + that are not commutative, the result may differ from that of a fold + applied to a non-distributed collection. + >>> from operator import add >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 @@ -782,6 +865,9 @@ def func(iterator): for obj in iterator: acc = op(obj, acc) yield acc + # collecting result of mapPartitions here ensures that the copy of + # zeroValue provided to each partition is unique from the one provided + # to the final reduce call vals = self.mapPartitions(func).collect() return reduce(op, vals, zeroValue) @@ -811,8 +897,11 @@ def func(iterator): for obj in iterator: acc = seqOp(acc, obj) yield acc - - return self.mapPartitions(func).fold(zeroValue, combOp) + # collecting result of mapPartitions here ensures that the copy of + # zeroValue provided to each partition is unique from the one provided + # to the final reduce call + vals = self.mapPartitions(func).collect() + return reduce(combOp, vals, zeroValue) def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): """ @@ -853,7 +942,7 @@ def aggregatePartition(iterator): # aggregation. while numPartitions > scale + numPartitions / scale: numPartitions /= scale - curNumPartitions = numPartitions + curNumPartitions = int(numPartitions) def mapPartition(i, iterator): for obj in iterator: @@ -905,7 +994,7 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add) def count(self): """ @@ -963,7 +1052,7 @@ def histogram(self, buckets): (('a', 'b', 'c'), [2, 2]) """ - if isinstance(buckets, (int, long)): + if isinstance(buckets, int): if buckets < 1: raise ValueError("number of buckets must be >= 1") @@ -999,6 +1088,7 @@ def minmax(a, b): raise ValueError("Can not generate buckets with infinite value") # keep them as integer if possible + inc = int(inc) if inc * buckets != maxv - minv: inc = (maxv - minv) * 1.0 / buckets @@ -1116,7 +1206,7 @@ def countPartition(iterator): yield counts def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) @@ -1176,7 +1266,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1206,7 +1296,7 @@ def takeUpToNumLeft(iterator): taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) - res = self.context.runJob(self, takeUpToNumLeft, p, True) + res = self.context.runJob(self, takeUpToNumLeft, p) items += res partsScanned += numPartsToTry @@ -1239,7 +1329,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1357,8 +1447,8 @@ def saveAsPickleFile(self, path, batchSize=10): >>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() >>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3) - >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) - [1, 2, 'rdd', 'spark'] + >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect()) + ['1', '2', 'rdd', 'spark'] """ if batchSize == 0: ser = AutoBatchedSerializer(PickleSerializer()) @@ -1366,10 +1456,15 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) - def saveAsTextFile(self, path): + @ignore_unicode_prefix + def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. + @param path: path to text file + @param compressionCodecClass: (None by default) string i.e. + "org.apache.hadoop.io.compress.GzipCodec" + >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) @@ -1385,17 +1480,32 @@ def saveAsTextFile(self, path): >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name) >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*")))) '\\n\\n\\nbar\\nfoo\\n' + + Using compressionCodecClass + + >>> tempFile3 = NamedTemporaryFile(delete=True) + >>> tempFile3.close() + >>> codec = "org.apache.hadoop.io.compress.GzipCodec" + >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec) + >>> from fileinput import input, hook_compressed + >>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)) + >>> b''.join(result).decode('utf-8') + u'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: - if not isinstance(x, basestring): + if not isinstance(x, (unicode, bytes)): x = unicode(x) if isinstance(x, unicode): x = x.encode("utf-8") yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) + if compressionCodecClass: + compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass) + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec) + else: + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) # Pair functions @@ -1419,7 +1529,7 @@ def keys(self): >>> m.collect() [1, 3] """ - return self.map(lambda (k, v): k) + return self.map(lambda x: x[0]) def values(self): """ @@ -1429,7 +1539,7 @@ def values(self): >>> m.collect() [2, 4] """ - return self.map(lambda (k, v): v) + return self.map(lambda x: x[1]) def reduceByKey(self, func, numPartitions=None): """ @@ -1468,7 +1578,7 @@ def reducePartition(iterator): yield m def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1565,11 +1675,14 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) >>> sets = pairs.partitionBy(2).glom().collect() - >>> set(sets[0]).intersection(set(sets[1])) - set([]) + >>> len(set(sets[0]).intersection(set(sets[1]))) + 0 """ if numPartitions is None: numPartitions = self._defaultReducePartitions() + partitioner = Partitioner(numPartitions, partitionFunc) + if self.partitioner == partitioner: + return self # Transferring O(n) objects to Java is too expensive. # Instead, we'll form the hash buckets in Python, @@ -1595,37 +1708,35 @@ def add_shuffle_key(split, iterator): if (c % 1000 == 0 and get_used_memory() > limit or c > batch): n, size = len(buckets), 0 - for split in buckets.keys(): + for split in list(buckets.keys()): yield pack_long(split) d = outputSerializer.dumps(buckets[split]) del buckets[split] yield d size += len(d) - avg = (size / n) >> 20 + avg = int(size / n) >> 20 # let 1M < avg < 10M if avg < 1: batch *= 1.5 elif avg > 10: - batch = max(batch / 1.5, 1) + batch = max(int(batch / 1.5), 1) c = 0 - for split, items in buckets.iteritems(): + for split, items in buckets.items(): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = self.mapPartitionsWithIndex(add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True) keyed._bypass_serializer = True with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, - id(partitionFunc)) - jrdd = pairRDD.partitionBy(partitioner).values() + jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) + jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner)) rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) - # This is required so that id(partitionFunc) remains unique, - # even if partitionFunc is a lambda: - rdd._partitionFunc = partitionFunc + rdd.partitioner = partitioner return rdd # TODO: add control over map-side aggregation @@ -1659,28 +1770,26 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() - == 'true') - memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) - return merger.iteritems() + return merger.items() - locally_combined = self.mapPartitions(combineLocally) + locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) - return merger.iteritems() + return merger.items() - return shuffled.mapPartitions(_mergeCombiners, True) + return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ @@ -1707,7 +1816,7 @@ def foldByKey(self, zeroValue, func, numPartitions=None): >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add - >>> rdd.foldByKey(0, add).collect() + >>> sorted(rdd.foldByKey(0, add).collect()) [('a', 2), ('b', 1)] """ def createZero(): @@ -1715,21 +1824,28 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + def _can_spill(self): + return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + + def _memory_limit(self): + return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions. + Hash-partitions the resulting RDD with numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.groupByKey().mapValues(len).collect()) + [('a', 2), ('b', 1)] + >>> sorted(rdd.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): return [x] @@ -1741,8 +1857,27 @@ def mergeCombiners(a, b): a.extend(b) return a - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + spill = self._can_spill() + memory = self._memory_limit() + serializer = self._jrdd_deserializer + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + + def combine(iterator): + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.items() + + locally_combined = self.mapPartitions(combine, preservesPartitioning=True) + shuffled = locally_combined.partitionBy(numPartitions) + + def groupByKey(it): + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.items() + + return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) def flatMapValues(self, f): """ @@ -1755,7 +1890,7 @@ def flatMapValues(self, f): >>> x.flatMapValues(f).collect() [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')] """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def mapValues(self, f): @@ -1769,7 +1904,7 @@ def mapValues(self, f): >>> x.mapValues(f).collect() [('a', 3), ('b', 1)] """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def groupWith(self, other, *others): @@ -1780,8 +1915,7 @@ def groupWith(self, other, *others): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) >>> z = sc.parallelize([("b", 42)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ - sorted(list(w.groupWith(x, y, z).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))] [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] """ @@ -1796,7 +1930,7 @@ def cogroup(self, other, numPartitions=None): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))] [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup((self, other), numPartitions) @@ -1832,8 +1966,9 @@ def subtractByKey(self, other, numPartitions=None): >>> sorted(x.subtractByKey(y).collect()) [('b', 4), ('b', 5)] """ - def filter_func((key, vals)): - return vals[0] and not vals[1] + def filter_func(pair): + key, (val1, val2) = pair + return val1 and not val2 return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) def subtract(self, other, numPartitions=None): @@ -1855,8 +1990,8 @@ def keyBy(self, f): >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) >>> y = sc.parallelize(zip(range(0,5), range(0,5))) - >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect())) - [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] + >>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())] + [(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])] """ return self.map(lambda x: (f(x), x)) @@ -1915,7 +2050,7 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: + if my_batch != other_batch or not my_batch: # use the smallest batchSize for both of them batchSize = min(my_batch, other_batch) if batchSize <= 0: @@ -1985,17 +2120,18 @@ def name(self): """ Return the name of this RDD. """ - name_ = self._jrdd.name() - if name_: - return name_.encode('utf-8') + n = self._jrdd.name() + if n: + return n + @ignore_unicode_prefix def setName(self, name): """ Assign a name to this RDD. - >>> rdd1 = sc.parallelize([1,2]) + >>> rdd1 = sc.parallelize([1, 2]) >>> rdd1.setName('RDD1').name() - 'RDD1' + u'RDD1' """ self._jrdd.setName(name) return self @@ -2057,10 +2193,10 @@ def lookup(self, key): >>> sorted.lookup(1024) [] """ - values = self.filter(lambda (k, v): k == key).values() + values = self.filter(lambda kv: kv[0] == key).values() - if self._partitionFunc is not None: - return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False) + if self.partitioner is not None: + return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)]) return values.collect() @@ -2076,6 +2212,7 @@ def _to_java_object_rdd(self): def countApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate version of count() that returns a potentially incomplete result within a timeout, even if not all tasks have finished. @@ -2089,12 +2226,13 @@ def countApprox(self, timeout, confidence=0.95): def sumApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the sum within a timeout or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) - >>> (rdd.sumApprox(1000) - r) / r < 0.05 + >>> r = sum(range(1000)) + >>> abs(rdd.sumApprox(1000) - r) / r < 0.05 True """ jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd() @@ -2105,12 +2243,13 @@ def sumApprox(self, timeout, confidence=0.95): def meanApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the mean within a timeout or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) / 1000.0 - >>> (rdd.meanApprox(1000) - r) / r < 0.05 + >>> r = sum(range(1000)) / 1000.0 + >>> abs(rdd.meanApprox(1000) - r) / r < 0.05 True """ jrdd = self.map(float)._to_java_object_rdd() @@ -2121,6 +2260,7 @@ def meanApprox(self, timeout, confidence=0.95): def countApproxDistinct(self, relativeSD=0.05): """ .. note:: Experimental + Return approximate number of distinct elements in the RDD. The algorithm used is based on streamlib's implementation of @@ -2133,16 +2273,14 @@ def countApproxDistinct(self, relativeSD=0.05): It must be greater than 0.000017. >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct() - >>> 950 < n < 1050 + >>> 900 < n < 1100 True >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct() - >>> 18 < n < 22 + >>> 16 < n < 24 True """ if relativeSD < 0.000017: raise ValueError("relativeSD should be greater than 0.000017") - if relativeSD > 0.37: - raise ValueError("relativeSD should be smaller than 0.37") # the hash space in Java is 2^32 hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) @@ -2155,13 +2293,32 @@ def toLocalIterator(self): >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ - partitions = xrange(self.getNumPartitions()) - for partition in partitions: + for partition in range(self.getNumPartitions()): rows = self.context.runJob(self, lambda x: x, [partition]) for row in rows: yield row +def _prepare_for_python_RDD(sc, command, obj=None): + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + # There is a bug in py4j.java_gateway.JavaClass with auto_convert + # https://github.com/bartdag/py4j/issues/161 + # TODO: use auto_convert once py4j fix the bug + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + return pickled_command, broadcast_vars, env, includes + + class PipelinedRDD(RDD): """ @@ -2206,13 +2363,10 @@ def pipeline_func(split, iterator): self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False - self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None - self._broadcast = None + self.partitioner = prev.partitioner if self.preservesPartitioning else None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): @@ -2228,25 +2382,12 @@ def _jrdd(self): command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - # the serialized command will be compressed by broadcast - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - self._broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(self._broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx._gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - env = MapConverter().convert(self.ctx.environment, - self.ctx._gateway._gateway_client) - includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), + bytearray(pickled_cmd), env, includes, self.preservesPartitioning, - self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator) + self.ctx.pythonExec, self.ctx.pythonVer, + bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() if profiler: diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 459e1427803c..fe8f87324804 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -23,7 +23,7 @@ class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - self._seed = seed if seed is not None else random.randint(0, sys.maxint) + self._seed = seed if seed is not None else random.randint(0, sys.maxsize) self._withReplacement = withReplacement self._random = None @@ -31,7 +31,7 @@ def initRandomGenerator(self, split): self._random = random.Random(self._seed ^ split) # mixing because the initial seeds are close to each other - for _ in xrange(10): + for _ in range(10): self._random.randint(0, 1) def getUniformSample(self): diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index ef04c82866e6..1ab5ce14c353 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,15 +15,16 @@ # limitations under the License. # -__all__ = ["ResultIterable"] - import collections +__all__ = ["ResultIterable"] + class ResultIterable(collections.Iterable): """ - A special result iterable. This is used because the standard iterator can not be pickled + A special result iterable. This is used because the standard + iterator can not be pickled """ def __init__(self, data): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0ffb41d02f6f..411b4dbf481f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -44,21 +44,29 @@ >>> rdd.glom().collect() [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -8L +>>> int(rdd._jrdd.count()) +8 >>> sc.stop() """ -import cPickle -from itertools import chain, izip, product +import sys +from itertools import chain, product import marshal import struct -import sys import types import collections import zlib import itertools +if sys.version < '3': + import cPickle as pickle + protocol = 2 + from itertools import izip as zip +else: + import pickle + protocol = 3 + xrange = range + from pyspark import cloudpickle @@ -97,7 +105,7 @@ def _load_stream_without_unbatching(self, stream): # subclasses should override __eq__ as appropriate. def __eq__(self, other): - return isinstance(other, self.__class__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -212,14 +220,33 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): return self.serializer.load_stream(stream) - def __eq__(self, other): - return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer and other.batchSize == self.batchSize) - def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) +class FlattenedValuesSerializer(BatchedSerializer): + + """ + Serializes a stream of list of pairs, split the list of values + which contain more than a certain number of objects to make them + have similar sizes. + """ + def __init__(self, serializer, batchSize=10): + BatchedSerializer.__init__(self, serializer, batchSize) + + def _batched(self, iterator): + n = self.batchSize + for key, values in iterator: + for i in range(0, len(values), n): + yield key, values[i:i + n] + + def load_stream(self, stream): + return self.serializer.load_stream(stream) + + def __repr__(self): + return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize) + + class AutoBatchedSerializer(BatchedSerializer): """ Choose the size of batch automatically based on the size of object @@ -245,14 +272,10 @@ def dump_stream(self, iterator, stream): if size < best: batch *= 2 elif size > best * 10 and batch > 1: - batch /= 2 + batch //= 2 - def __eq__(self, other): - return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer and other.bestSize == self.bestSize) - - def __str__(self): - return "AutoBatchedSerializer(%s)" % str(self.serializer) + def __repr__(self): + return "AutoBatchedSerializer(%s)" % self.serializer class CartesianDeserializer(FramedSerializer): @@ -262,6 +285,7 @@ class CartesianDeserializer(FramedSerializer): """ def __init__(self, key_ser, val_ser): + FramedSerializer.__init__(self) self.key_ser = key_ser self.val_ser = val_ser @@ -270,7 +294,7 @@ def prepare_keys_values(self, stream): val_stream = self.val_ser._load_stream_without_unbatching(stream) key_is_batched = isinstance(self.key_ser, BatchedSerializer) val_is_batched = isinstance(self.val_ser, BatchedSerializer) - for (keys, vals) in izip(key_stream, val_stream): + for (keys, vals) in zip(key_stream, val_stream): keys = keys if key_is_batched else [keys] vals = vals if val_is_batched else [vals] yield (keys, vals) @@ -280,10 +304,6 @@ def load_stream(self, stream): for pair in product(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, CartesianDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -295,22 +315,14 @@ class PairDeserializer(CartesianDeserializer): Deserializes the JavaRDD zip() of two PythonRDDs. """ - def __init__(self, key_ser, val_ser): - self.key_ser = key_ser - self.val_ser = val_ser - def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): if len(keys) != len(vals): raise ValueError("Can not deserialize RDD with different number of items" " in pair: (%d, %d)" % (len(keys), len(vals))) - for pair in izip(keys, vals): + for pair in zip(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, PairDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) @@ -359,8 +371,8 @@ def _hijack_namedtuple(): global _old_namedtuple # or it will put in closure def _copy_func(f): - return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + return types.FunctionType(f.__code__, f.__globals__, f.__name__, + f.__defaults__, f.__closure__) _old_namedtuple = _copy_func(collections.namedtuple) @@ -369,15 +381,15 @@ def namedtuple(*args, **kwargs): return _hack_namedtuple(cls) # replace namedtuple with new one - collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple - collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple - collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, # so only hack those in __main__ module - for n, o in sys.modules["__main__"].__dict__.iteritems(): + for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple and hasattr(o, "_fields") and "__reduce__" not in o.__dict__): @@ -390,7 +402,7 @@ def namedtuple(*args, **kwargs): class PickleSerializer(FramedSerializer): """ - Serializes objects using Python's cPickle serializer: + Serializes objects using Python's pickle serializer: http://docs.python.org/2/library/pickle.html @@ -399,10 +411,14 @@ class PickleSerializer(FramedSerializer): """ def dumps(self, obj): - return cPickle.dumps(obj, 2) + return pickle.dumps(obj, protocol) - def loads(self, obj): - return cPickle.loads(obj) + if sys.version >= '3': + def loads(self, obj, encoding="bytes"): + return pickle.loads(obj, encoding=encoding) + else: + def loads(self, obj, encoding=None): + return pickle.loads(obj) class CloudPickleSerializer(PickleSerializer): @@ -431,7 +447,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol automatically + Choose marshal or pickle as serialization protocol automatically """ def __init__(self): @@ -440,19 +456,19 @@ def __init__(self): def dumps(self, obj): if self._type is not None: - return 'P' + cPickle.dumps(obj, -1) + return b'P' + pickle.dumps(obj, -1) try: - return 'M' + marshal.dumps(obj) + return b'M' + marshal.dumps(obj) except Exception: - self._type = 'P' - return 'P' + cPickle.dumps(obj, -1) + self._type = b'P' + return b'P' + pickle.dumps(obj, -1) def loads(self, obj): _type = obj[0] - if _type == 'M': + if _type == b'M': return marshal.loads(obj[1:]) - elif _type == 'P': - return cPickle.loads(obj[1:]) + elif _type == b'P': + return pickle.loads(obj[1:]) else: raise ValueError("invalid sevialization type: %s" % _type) @@ -472,8 +488,8 @@ def dumps(self, obj): def loads(self, obj): return self.serializer.loads(zlib.decompress(obj)) - def __eq__(self, other): - return isinstance(other, CompressedSerializer) and self.serializer == other.serializer + def __repr__(self): + return "CompressedSerializer(%s)" % self.serializer class UTF8Deserializer(Serializer): @@ -482,7 +498,7 @@ class UTF8Deserializer(Serializer): Deserializes streams written by String.getBytes. """ - def __init__(self, use_unicode=False): + def __init__(self, use_unicode=True): self.use_unicode = use_unicode def loads(self, stream): @@ -503,13 +519,13 @@ def load_stream(self, stream): except EOFError: return - def __eq__(self, other): - return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode + def __repr__(self): + return "UTF8Deserializer(%s)" % self.use_unicode def read_long(stream): length = stream.read(8) - if length == "": + if not length: raise EOFError return struct.unpack("!q", length)[0] @@ -524,7 +540,7 @@ def pack_long(value): def read_int(stream): length = stream.read(4) - if length == "": + if not length: raise EOFError return struct.unpack("!i", length)[0] @@ -540,4 +556,6 @@ def write_with_length(obj, stream): if __name__ == '__main__': import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 89cf76920e35..99331297c19f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -21,30 +21,40 @@ This file is designed to be launched as a PYTHONSTARTUP script. """ -import sys -if sys.version_info[0] != 2: - print("Error: Default Python used is Python%s" % sys.version_info.major) - print("\tSet env variable PYSPARK_PYTHON to Python2 binary and re-run it.") - sys.exit(1) - - import atexit import os import platform + +import py4j + import pyspark from pyspark.context import SparkContext +from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the equivalent of ADD_JARS -add_files = (os.environ.get("ADD_FILES").split(',') - if os.environ.get("ADD_FILES") is not None else None) +# this is the deprecated equivalent of ADD_JARS +add_files = None +if os.environ.get("ADD_FILES") is not None: + add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext(appName="PySparkShell", pyFiles=add_files) +sc = SparkContext(pyFiles=add_files) atexit.register(lambda: sc.stop()) +try: + # Try to access HiveConf, it will raise exception if Hive is not added + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + sqlContext = HiveContext(sc) +except py4j.protocol.Py4JError: + sqlContext = SQLContext(sc) +except TypeError: + sqlContext = SQLContext(sc) + +# for compatibility +sqlCtx = sqlContext + print("""Welcome to ____ __ / __/__ ___ _____/ /__ @@ -56,9 +66,10 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc.") +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: + print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") print("Adding files: [%s]" % ", ".join(add_files)) # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 10a7ccd50200..b8118bdb7ca7 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,28 +16,35 @@ # import os -import sys import platform import shutil import warnings import gc import itertools +import operator import random import pyspark.heapq3 as heapq -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ + CompressedSerializer, AutoBatchedSerializer + try: import psutil + process = None + def get_used_memory(): """ Return the used memory in MB """ - process = psutil.Process(os.getpid()) + global process + if process is None or process._pid != os.getpid(): + process = psutil.Process(os.getpid()) if hasattr(process, "memory_info"): info = process.memory_info() else: info = process.get_memory_info() return info.rss >> 20 + except ImportError: def get_used_memory(): @@ -46,6 +53,7 @@ def get_used_memory(): for line in open('/proc/self/status'): if line.startswith('VmRSS:'): return int(line.split()[1]) >> 10 + else: warnings.warn("Please install psutil to have better " "support with spilling") @@ -54,6 +62,7 @@ def get_used_memory(): rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss return rss >> 20 # TODO: support windows + return 0 @@ -69,8 +78,8 @@ def _get_local_dirs(sub): # global stats -MemoryBytesSpilled = 0L -DiskBytesSpilled = 0L +MemoryBytesSpilled = 0 +DiskBytesSpilled = 0 class Aggregator(object): @@ -117,7 +126,7 @@ def mergeCombiners(self, iterator): """ Merge the combined items by mergeCombiner """ raise NotImplementedError - def iteritems(self): + def items(self): """ Return the merged items ad iterator """ raise NotImplementedError @@ -147,9 +156,15 @@ def mergeCombiners(self, iterator): for k, v in iterator: d[k] = comb(d[k], v) if k in d else v - def iteritems(self): + def items(self): """ Return the merged items ad iterator """ - return self.data.iteritems() + return iter(self.data.items()) + + +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) class ExternalMerger(Merger): @@ -173,7 +188,7 @@ class ExternalMerger(Merger): dict. Repeat this again until combine all the items. - Before return any items, it will load each partition and - combine them seperately. Yield them before loading next + combine them separately. Yield them before loading next partition. - During loading a partition, if the memory goes over limit, @@ -182,7 +197,7 @@ class ExternalMerger(Merger): `data` and `pdata` are used to hold the merged items in memory. At first, all the data are merged into `data`. Once the used - memory goes over limit, the items in `data` are dumped indo + memory goes over limit, the items in `data` are dumped into disks, `data` will be cleared, all rest of items will be merged into `pdata` and then dumped into disks. Before returning, all the items in `pdata` will be dumped into disks. @@ -193,16 +208,16 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeValues(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) - 499950000 + >>> sum(v for k,v in merger.items()) + 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeCombiners(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) - 499950000 + >>> sum(v for k,v in merger.items()) + 49995000 """ # the max total partitions created recursively @@ -212,8 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - # default serializer is only used for tests - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -221,7 +235,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, self.batch = batch # scale is used to scale down the hash of key for recursive hash map self.scale = scale - # unpartitioned merged data + # un-partitioned merged data self.data = {} # partitioned merged data, list of dicts self.pdata = [] @@ -244,72 +258,63 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if get_used_memory() >= limit: + self._spill() def _partition(self, key): """ Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + def _object_size(self, obj): + """ How much of memory for this obj, assume that all the objects + consume similar bytes of memory + """ + return 1 - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition + comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size + c, data, pdata, batch = 0, self.data, self.pdata, self.batch for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue - c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + c += objsize(v) + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if limit and get_used_memory() >= limit: + self._spill() def _spill(self): """ @@ -330,12 +335,12 @@ def _spill(self): # above limit at the first time. # open all the files for writing - streams = [open(os.path.join(path, str(i)), 'w') + streams = [open(os.path.join(path, str(i)), 'wb') for i in range(self.partitions)] - for k, v in self.data.iteritems(): + for k, v in self.data.items(): h = self._partition(k) - # put one item in batch, make it compatitable with load_stream + # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch self.serializer.dump_stream([(k, v)], streams[h]) @@ -344,25 +349,25 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): p = os.path.join(path, str(i)) - with open(p, "w") as f: + with open(p, "wb") as f: # dump items in batch - self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.serializer.dump_stream(iter(self.pdata[i].items()), f) self.pdata[i].clear() DiskBytesSpilled += os.path.getsize(p) self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 - def iteritems(self): + def items(self): """ Return all merged items as iterator """ if not self.pdata and not self.spills: - return self.data.iteritems() + return iter(self.data.items()) return self._external_items() def _external_items(self): @@ -370,29 +375,12 @@ def _external_items(self): assert not self.data if any(self.pdata): self._spill() - hard_limit = self._next_limit() + # disable partitioning and spilling when merge combiners from disk + self.pdata = [] try: for i in range(self.partitions): - self.data = {} - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) - - # limit the total partitions - if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS - and j < self.spills - 1 - and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible - for v in self._recursive_merged_items(i): - yield v - return - - for v in self.data.iteritems(): + for v in self._merged_items(i): yield v self.data.clear() @@ -400,53 +388,58 @@ def _external_items(self): for j in range(self.spills): path = self._get_spill_dir(j) os.remove(os.path.join(path, str(i))) - finally: self._cleanup() - def _cleanup(self): - """ Clean up all the files in disks """ - for d in self.localdirs: - shutil.rmtree(d, True) - - def _recursive_merged_items(self, start): + def _merged_items(self, index): + self.data = {} + limit = self._next_limit() + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + return self._recursive_merged_items(index) + + return self.data.items() + + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. - assert not self.data - if any(self.pdata): - self._spill() - assert self.spills > 0 - - for i in range(start, self.partitions): - subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] - m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions, self.partitions) - m.pdata = [{} for _ in range(self.partitions)] - limit = self._next_limit() - - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) - - if get_used_memory() > limit: - m._spill() - limit = self._next_limit() + subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, + self.scale * self.partitions, self.partitions, self.batch) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + with open(p, 'rb') as f: + m.mergeCombiners(self.serializer.load_stream(f), 0) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() - for v in m._external_items(): - yield v + return m._external_items() - # remove the merged partition - for j in range(self.spills): - path = self._get_spill_dir(j) - os.remove(os.path.join(path, str(i))) + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) class ExternalSorter(object): @@ -457,9 +450,10 @@ class ExternalSorter(object): The spilling will only happen when the used memory goes above the limit. + >>> sorter = ExternalSorter(1) # 1M >>> import random - >>> l = range(1024) + >>> l = list(range(1024)) >>> random.shuffle(l) >>> sorted(l) == list(sorter.sorted(l)) True @@ -469,7 +463,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -503,21 +497,27 @@ def sorted(self, iterator, key=None, reverse=False): break used_memory = get_used_memory() - if used_memory > self.memory_limit: + if used_memory > limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) - with open(path, 'w') as f: + with open(path, 'wb') as f: self.serializer.dump_stream(current_chunk, f) - chunks.append(self.serializer.load_stream(open(path))) + + def load(f): + for v in self.serializer.load_stream(f): + yield v + # close the file explicit once we consume all the items + # to avoid ResourceWarning in Python3 + f.close() + chunks.append(load(open(path, 'rb'))) current_chunk = [] - gc.collect() - limit = self._next_limit() - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) + os.unlink(path) # data will be deleted after close elif not chunks: - batch = min(batch * 2, 10000) + batch = min(int(batch * 1.5), 10000) current_chunk.sort(key=key, reverse=reverse) if not chunks: @@ -529,6 +529,315 @@ def sorted(self, iterator, key=None, reverse=False): return heapq.merge(chunks, key=key, reverse=reverse) +class ExternalList(object): + """ + ExternalList can have many items which cannot be hold in memory in + the same time. + + >>> l = ExternalList(list(range(100))) + >>> len(l) + 100 + >>> l.append(10) + >>> len(l) + 101 + >>> for i in range(20240): + ... l.append(i) + >>> len(l) + 20341 + >>> import pickle + >>> l2 = pickle.loads(pickle.dumps(l)) + >>> len(l2) + 20341 + >>> list(l2)[100] + 10 + """ + LIMIT = 10240 + + def __init__(self, values): + self.values = values + self.count = len(values) + self._file = None + self._ser = None + + def __getstate__(self): + if self._file is not None: + self._file.flush() + with os.fdopen(os.dup(self._file.fileno()), "rb") as f: + f.seek(0) + serialized = f.read() + else: + serialized = b'' + return self.values, self.count, serialized + + def __setstate__(self, item): + self.values, self.count, serialized = item + if serialized: + self._open_file() + self._file.write(serialized) + else: + self._file = None + self._ser = None + + def __iter__(self): + if self._file is not None: + self._file.flush() + # read all items from disks first + with os.fdopen(os.dup(self._file.fileno()), 'rb') as f: + f.seek(0) + for v in self._ser.load_stream(f): + yield v + + for v in self.values: + yield v + + def __len__(self): + return self.count + + def append(self, value): + self.values.append(value) + self.count += 1 + # dump them into disk if the key is huge + if len(self.values) >= self.LIMIT: + self._spill() + + def _open_file(self): + dirs = _get_local_dirs("objects") + d = dirs[id(self) % len(dirs)] + if not os.path.exists(d): + os.makedirs(d) + p = os.path.join(d, str(id(self))) + self._file = open(p, "w+b", 65536) + self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + os.unlink(p) + + def __del__(self): + if self._file: + self._file.close() + self._file = None + + def _spill(self): + """ dump the values into disk """ + global MemoryBytesSpilled, DiskBytesSpilled + if self._file is None: + self._open_file() + + used_memory = get_used_memory() + pos = self._file.tell() + self._ser.dump_stream(self.values, self._file) + self.values = [] + gc.collect() + DiskBytesSpilled += self._file.tell() - pos + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 + + +class ExternalListOfList(ExternalList): + """ + An external list for list. + + >>> l = ExternalListOfList([[i, i] for i in range(100)]) + >>> len(l) + 200 + >>> l.append(range(10)) + >>> len(l) + 210 + >>> len(list(l)) + 210 + """ + + def __init__(self, values): + ExternalList.__init__(self, values) + self.count = sum(len(i) for i in values) + + def append(self, value): + ExternalList.append(self, value) + # already counted 1 in ExternalList.append + self.count += len(value) - 1 + + def __iter__(self): + for values in ExternalList.__iter__(self): + for v in values: + yield v + + +class GroupByKey(object): + """ + Group a sorted iterator as [(k1, it1), (k2, it2), ...] + + >>> k = [i // 3 for i in range(6)] + >>> v = [[i] for i in range(6)] + >>> g = GroupByKey(zip(k, v)) + >>> [(k, list(it)) for k, it in g] + [(0, [0, 1, 2]), (1, [3, 4, 5])] + """ + + def __init__(self, iterator): + self.iterator = iterator + + def __iter__(self): + key, values = None, None + for k, v in self.iterator: + if values is not None and k == key: + values.append(v) + else: + if values is not None: + yield (key, values) + key = k + values = ExternalListOfList([v]) + if values is not None: + yield (key, values) + + +class ExternalGroupBy(ExternalMerger): + + """ + Group by the items by key. If any partition of them can not been + hold in memory, it will do sort based group by. + + This class works as follows: + + - It repeatedly group the items by key and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. If the number of keys + in one partitions is smaller than 1000, it will sort them + by key before dumping into disk. + + - Then it goes through the rest of the iterator, group items + by key into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. It + also will try to sort the items by key in each partition + before dumping into disks. + + - It will yield the grouped items partitions by partitions. + If the data in one partitions can be hold in memory, then it + will load and combine them in memory and yield. + + - If the dataset in one partition cannot be hold in memory, + it will sort them first. If all the files are already sorted, + it merge them by heap.merge(), so it will do external sort + for all the files. + + - After sorting, `GroupByKey` class will put all the continuous + items with the same key as a group, yield the values as + an iterator. + """ + SORT_KEY_LIMIT = 1000 + + def flattened_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) + ser = self.serializer + return FlattenedValuesSerializer(ser, 20) + + def _object_size(self, obj): + return len(obj) + + def _spill(self): + """ + dump already partitioned data into disks. + """ + global MemoryBytesSpilled, DiskBytesSpilled + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + used_memory = get_used_memory() + if not self.pdata: + # The data has not been partitioned, it will iterator the + # data once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. + + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'wb') + for i in range(self.partitions)] + + # If the number of keys is small, then the overhead of sort is small + # sort them before dumping into disks + self._sorted = len(self.data) < self.SORT_KEY_LIMIT + if self._sorted: + self.serializer = self.flattened_serializer() + for k in sorted(self.data.keys()): + h = self._partition(k) + self.serializer.dump_stream([(k, self.data[k])], streams[h]) + else: + for k, v in self.data.items(): + h = self._partition(k) + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + DiskBytesSpilled += s.tell() + s.close() + + self.data.clear() + # self.pdata is cached in `mergeValues` and `mergeCombiners` + self.pdata.extend([{} for i in range(self.partitions)]) + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "wb") as f: + # dump items in batch + if self._sorted: + # sort by key only (stable) + sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0)) + self.serializer.dump_stream(sorted_items, f) + else: + self.serializer.dump_stream(self.pdata[i].items(), f) + self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) + + self.spills += 1 + gc.collect() # release the memory as much as possible + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 + + def _merged_items(self, index): + size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) + for j in range(self.spills)) + # if the memory can not hold all the partition, + # then use sort based merge. Because of compression, + # the data on disks will be much smaller than needed memory + if size >= self.memory_limit << 17: # * 1M / 8 + return self._merge_sorted_items(index) + + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) + return self.data.items() + + def _merge_sorted_items(self, index): + """ load a partition from disk, then sort and group by key """ + def load_partition(j): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + with open(p, 'rb', 65536) as f: + for v in self.serializer.load_stream(f): + yield v + + disk_items = [load_partition(j) for j in range(self.spills)] + + if self._sorted: + # all the partitions are already sorted + sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0)) + + else: + # Flatten the combined values, so it will not consume huge + # memory during merging sort. + ser = self.flattened_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + return ((k, vs) for k, vs in GroupByKey(sorted_items)) + + if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py deleted file mode 100644 index 32bff0c7e8c5..000000000000 --- a/python/pyspark/sql.py +++ /dev/null @@ -1,2600 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -public classes of Spark SQL: - - - L{SQLContext} - Main entry point for SQL functionality. - - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. - - L{GroupedDataFrame} - - L{Column} - Column is a DataFrame with a single column. - - L{Row} - A Row of data returned by a Spark SQL query. - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. -""" - -import sys -import itertools -import decimal -import datetime -import keyword -import warnings -import json -import re -import random -import os -from tempfile import NamedTemporaryFile -from array import array -from operator import itemgetter -from itertools import imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer, UTF8Deserializer -from pyspark.storagelevel import StorageLevel -from pyspark.traceback_utils import SCCallSiteSync - - -__all__ = [ - "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", - "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", - "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", - "SchemaRDD"] - - -class DataType(object): - - """Spark SQL DataType""" - - def __repr__(self): - return self.__class__.__name__ - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - @classmethod - def typeName(cls): - return cls.__name__[:-4].lower() - - def jsonValue(self): - return self.typeName() - - def json(self): - return json.dumps(self.jsonValue(), - separators=(',', ':'), - sort_keys=True) - - -class PrimitiveTypeSingleton(type): - - """Metaclass for PrimitiveType""" - - _instances = {} - - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() - return cls._instances[cls] - - -class PrimitiveType(DataType): - - """Spark SQL PrimitiveType""" - - __metaclass__ = PrimitiveTypeSingleton - - def __eq__(self, other): - # because they should be the same object - return self is other - - -class NullType(PrimitiveType): - - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. - """ - - -class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. - """ - - -class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. - """ - - -class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. - """ - - -class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. - """ - - -class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. - """ - - -class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. - """ - - def __init__(self, precision=None, scale=None): - self.precision = precision - self.scale = scale - self.hasPrecisionInfo = precision is not None - - def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" - - def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" - - -class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. - """ - - -class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. - """ - - -class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. - """ - - -class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. - """ - - -class LongType(PrimitiveType): - - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. - """ - - -class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. - """ - - -class ArrayType(DataType): - - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - - """ - - def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - - >>> ArrayType(StringType) == ArrayType(StringType, True) - True - >>> ArrayType(StringType, False) == ArrayType(StringType) - False - """ - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self): - return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull} - - @classmethod - def fromJson(cls, json): - return ArrayType(_parse_datatype_json_value(json["elementType"]), - json["containsNull"]) - - -class MapType(DataType): - - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. - - """ - - def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) - True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) - False - """ - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self): - return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull} - - @classmethod - def fromJson(cls, json): - return MapType(_parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), - json["valueContainsNull"]) - - -class StructField(DataType): - - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - - """ - - def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) - True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) - False - """ - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self): - return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) - - def jsonValue(self): - return {"name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata} - - @classmethod - def fromJson(cls, json): - return StructField(json["name"], - _parse_datatype_json_value(json["type"]), - json["nullable"], - json["metadata"]) - - -class StructType(DataType): - - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - - """ - - def __init__(self, fields): - """Creates a StructType - - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) - >>> struct1 == struct2 - True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) - >>> struct1 == struct2 - False - """ - self.fields = fields - - def __repr__(self): - return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) - - def jsonValue(self): - return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} - - @classmethod - def fromJson(cls, json): - return StructType([StructField.fromJson(f) for f in json["fields"]]) - - -class UserDefinedType(DataType): - """ - .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). - """ - - @classmethod - def typeName(cls): - return cls.__name__.lower() - - @classmethod - def sqlType(cls): - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") - - @classmethod - def module(cls): - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") - - @classmethod - def scalaUDT(cls): - """ - The class name of the paired Scala UDT. - """ - raise NotImplementedError("UDT must have a paired Scala UDT.") - - def serialize(self, obj): - """ - Converts the a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement serialize().") - - def deserialize(self, datum): - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement deserialize().") - - def json(self): - return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - - def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } - return schema - - @classmethod - def fromJson(cls, json): - pyUDT = json["pyClass"] - split = pyUDT.rfind(".") - pyModule = pyUDT[:split] - pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) - UDT = getattr(m, pyClass) - return UDT() - - def __eq__(self, other): - return type(self) == type(other) - - -_all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - - -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) - - -def _parse_datatype_json_string(json_string): - """Parses the given data type JSON string. - >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) - ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True - >>> # Simple ArrayType. - >>> simple_arraytype = ArrayType(StringType(), True) - >>> check_datatype(simple_arraytype) - True - >>> # Simple MapType. - >>> simple_maptype = MapType(StringType(), LongType()) - >>> check_datatype(simple_maptype) - True - >>> # Simple StructType. - >>> simple_structtype = StructType([ - ... StructField("a", DecimalType(), False), - ... StructField("b", BooleanType(), True), - ... StructField("c", LongType(), True), - ... StructField("d", BinaryType(), False)]) - >>> check_datatype(simple_structtype) - True - >>> # Complex StructType. - >>> complex_structtype = StructType([ - ... StructField("simpleArray", simple_arraytype, True), - ... StructField("simpleMap", simple_maptype, True), - ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False), - ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) - >>> check_datatype(complex_structtype) - True - >>> # Complex ArrayType. - >>> complex_arraytype = ArrayType(complex_structtype, True) - >>> check_datatype(complex_arraytype) - True - >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, - ... complex_arraytype, False) - >>> check_datatype(complex_maptype) - True - >>> check_datatype(ExamplePointUDT()) - True - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) - True - """ - return _parse_datatype_json_value(json.loads(json_string)) - - -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - -def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() - elif json_value == u'decimal': - return DecimalType() - elif _FIXED_DECIMAL.match(json_value): - m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % json_value) - else: - tpe = json_value["type"] - if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) - elif tpe == 'udt': - return UserDefinedType.fromJson(json_value) - else: - raise ValueError("not supported type: %s" % tpe) - - -# Mapping Python types to Spark SQL DataType -_type_mappings = { - type(None): NullType, - bool: BooleanType, - int: IntegerType, - long: LongType, - float: DoubleType, - str: StringType, - unicode: StringType, - bytearray: BinaryType, - decimal.Decimal: DecimalType, - datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, -} - - -def _infer_type(obj): - """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT - """ - if obj is None: - raise ValueError("Can not infer type for None") - - if hasattr(obj, '__UDT__'): - return obj.__UDT__ - - dataType = _type_mappings.get(type(obj)) - if dataType is not None: - return dataType() - - if isinstance(obj, dict): - for key, value in obj.iteritems(): - if key is not None and value is not None: - return MapType(_infer_type(key), _infer_type(value), True) - else: - return MapType(NullType(), NullType(), True) - elif isinstance(obj, (list, array)): - for v in obj: - if v is not None: - return ArrayType(_infer_type(obj[0]), True) - else: - return ArrayType(NullType(), True) - else: - try: - return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) - - -def _infer_schema(row): - """Infer the schema from dict/namedtuple/object""" - if isinstance(row, dict): - items = sorted(row.items()) - - elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row - else: - raise ValueError("Can't infer schema from tuple") - - elif hasattr(row, "__dict__"): # object - items = sorted(row.__dict__.items()) - - else: - raise ValueError("Can not infer schema for type: %s" % type(row)) - - fields = [StructField(k, _infer_type(v), True) for k, v in items] - return StructType(fields) - - -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - False - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - else: - return False - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = map(_python_to_sql_converter, types) - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) - else: - raise ValueError("Unexpected type %r" % dataType) - - -def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ - if isinstance(dt, StructType): - return any(_has_nulltype(f.dataType) for f in dt.fields) - elif isinstance(dt, ArrayType): - return _has_nulltype((dt.elementType)) - elif isinstance(dt, MapType): - return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) - else: - return isinstance(dt, NullType) - - -def _merge_type(a, b): - if isinstance(a, NullType): - return b - elif isinstance(b, NullType): - return a - elif type(a) is not type(b): - # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) - - # same type - if isinstance(a, StructType): - nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) - for f in a.fields] - names = set([f.name for f in fields]) - for n in nfs: - if n not in names: - fields.append(StructField(n, nfs[n])) - return StructType(fields) - - elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) - - elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), - True) - else: - return a - - -def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ - if isinstance(dataType, ArrayType): - conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) - - elif isinstance(dataType, MapType): - kconv = _create_converter(dataType.keyType) - vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) - - elif isinstance(dataType, NullType): - return lambda x: None - - elif not isinstance(dataType, StructType): - return lambda x: x - - # dataType must be StructType - names = [f.name for f in dataType.fields] - converters = [_create_converter(f.dataType) for f in dataType.fields] - - def convert_struct(obj): - if obj is None: - return - - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) - - elif isinstance(obj, dict): - d = obj - elif hasattr(obj, "__dict__"): # object - d = obj.__dict__ - else: - raise ValueError("Unexpected obj: %s" % obj) - - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - - return convert_struct - - -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _split_schema_abstract(s): - """ - split the schema abstract into fields - - >>> _split_schema_abstract("a b c") - ['a', 'b', 'c'] - >>> _split_schema_abstract("a(a b)") - ['a(a b)'] - >>> _split_schema_abstract("a b[] c{a b}") - ['a', 'b[]', 'c{a b}'] - >>> _split_schema_abstract(" ") - [] - """ - - r = [] - w = '' - brackets = [] - for c in s: - if c == ' ' and not brackets: - if w: - r.append(w) - w = '' - else: - w += c - if c in _BRACKETS: - brackets.append(c) - elif c in _BRACKETS.values(): - if not brackets or c != _BRACKETS[brackets.pop()]: - raise ValueError("unexpected " + c) - - if brackets: - raise ValueError("brackets not closed: %s" % brackets) - if w: - r.append(w) - return r - - -def _parse_field_abstract(s): - """ - Parse a field in schema abstract - - >>> _parse_field_abstract("a") - StructField(a,None,true) - >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,None,true),StructField(d... - >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(None,true),true) - >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(None,ArrayType(None,true),true),true) - """ - if set(_BRACKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRACKETS if c in s)) - name = s[:idx] - return StructField(name, _parse_schema_abstract(s[idx:]), True) - else: - return StructField(s, None, True) - - -def _parse_schema_abstract(s): - """ - parse abstract into schema - - >>> _parse_schema_abstract("a b c") - StructType...a...b...c... - >>> _parse_schema_abstract("a[b c] b{}") - StructType...a,ArrayType...b...c...b,MapType... - >>> _parse_schema_abstract("c{} d{a b}") - StructType...c,MapType...d,MapType...a...b... - >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,None,true))),true) - """ - s = s.strip() - if not s: - return - - elif s.startswith('('): - return _parse_schema_abstract(s[1:-1]) - - elif s.startswith('['): - return ArrayType(_parse_schema_abstract(s[1:-1]), True) - - elif s.startswith('{'): - return MapType(None, _parse_schema_abstract(s[1:-1])) - - parts = _split_schema_abstract(s) - fields = [_parse_field_abstract(p) for p in parts] - return StructType(fields) - - -def _infer_schema_type(obj, dataType): - """ - Fill the dataType with types inferred from obj - - >>> schema = _parse_schema_abstract("a b c d") - >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) - >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType...DateType... - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... - """ - if dataType is None: - return _infer_type(obj) - - if not obj: - return NullType() - - if isinstance(dataType, ArrayType): - eType = _infer_schema_type(obj[0], dataType.elementType) - return ArrayType(eType, True) - - elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() - return MapType(_infer_schema_type(k, dataType.keyType), - _infer_schema_type(v, dataType.valueType)) - - elif isinstance(dataType, StructType): - fs = dataType.fields - assert len(fs) == len(obj), \ - "Obj(%s) have different length with fields(%s)" % (obj, fs) - fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] - return StructType(fields) - - else: - raise ValueError("Unexpected dataType: %s" % dataType) - - -_acceptable_types = { - BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), - FloatType: (float,), - DoubleType: (float,), - DecimalType: (decimal.Decimal,), - StringType: (str, unicode), - BinaryType: (bytearray,), - DateType: (datetime.date,), - TimestampType: (datetime.datetime,), - ArrayType: (list, tuple, array), - MapType: (dict,), - StructType: (tuple, list), -} - - -def _verify_type(obj, dataType): - """ - Verify the type of obj against dataType, raise an exception if - they do not match. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) - >>> _verify_type(range(3), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - """ - # all objects are nullable - if obj is None: - return - - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) - return - - _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType - - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) - - if isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType) - - elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) - - elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) - - -_cached_cls = {} - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for primitive types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) - - def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) - - return Row - - -class SQLContext(object): - - """Main entry point for Spark SQL functionality. - - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as - tables, execute SQL over tables, cache tables, and read parquet files. - """ - - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new - SQLContext in the JVM, instead we make all calls to this object. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - - >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, - ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), - ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> df = sqlCtx.inferSchema(allTypes) - >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' - ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] - """ - self._sc = sparkContext - self._jsc = self._sc._jsc - self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - - @property - def _ssql_ctx(self): - """Accessor for the JVM Spark SQL context. - - Subclasses can override this property to provide their own - JVM Contexts. - """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext - - def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, - AutoBatchedSerializer(PickleSerializer()), - AutoBatchedSerializer(PickleSerializer())) - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - broadcast = self._sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self._sc._pickled_broadcast_vars], - self._sc._gateway._gateway_client) - self._sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(self._sc.environment, - self._sc._gateway._gateway_client) - includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) - self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_command), - env, - includes, - self._sc.pythonExec, - broadcast_vars, - self._sc._javaAccumulator, - returnType.json()) - - def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> df = sqlCtx.inferSchema(nestedRdd1) - >>> df.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> df = sqlCtx.inferSchema(nestedRdd2) - >>> df.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... CustomRow(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio > 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) - - def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT * from table1") - >>> df2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) - >>> df = sqlCtx.applySchema(rdd, schema) - >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> df.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] - - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> df = sqlCtx.applySchema(rdd, typedSchema) - >>> df.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) - - def registerRDDAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of - SQLContext. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerRDDAsTable(df, tableName) - else: - raise ValueError("Can only register DataFrame as table") - - def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{DataFrame}. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - jdf = self._ssql_ctx.parquetFile(path) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() - >>> df1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] - """ - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] - """ - return DataFrame(self._ssql_ctx.sql(sqlQuery), self) - - def table(self, tableName): - """Returns the specified table as a L{DataFrame}. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - return DataFrame(self._ssql_ctx.table(tableName), self) - - def cacheTable(self, tableName): - """Caches the specified table in-memory.""" - self._ssql_ctx.cacheTable(tableName) - - def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache.""" - self._ssql_ctx.uncacheTable(tableName) - - -class HiveContext(SQLContext): - - """A variant of Spark SQL that integrates with data stored in Hive. - - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - """ - - def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ - SQLContext.__init__(self, sparkContext) - - if hiveContext: - self._scala_HiveContext = hiveContext - - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) - - def _get_hive_ctx(self): - return self._jvm.HiveContext(self._jsc.sc()) - - -class LocalHiveContext(HiveContext): - - def __init__(self, sparkContext, sqlContext=None): - HiveContext.__init__(self, sparkContext, sqlContext) - warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) - - def _get_hive_ctx(self): - return self._jvm.LocalHiveContext(self._jsc.sc()) - - -def _create_row(fields, values): - row = Row(*values) - row.__FIELDS__ = fields - return row - - -class Row(tuple): - - """ - A row in L{DataFrame}. The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) - row.__FIELDS__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__FIELDS__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) - - # let obect acs like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - - def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - if hasattr(self, "__FIELDS__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) - else: - return "" % ", ".join(self) - - -class DataFrame(object): - - """A collection of rows that have the same columns. - - A :class:`DataFrame` is equivalent to a relational table in Spark SQL, - and can be created using various functions in :class:`SQLContext`:: - - people = sqlContext.parquetFile("...") - - Once created, it can be manipulated using the various domain-specific-language - (DSL) functions defined in: [[DataFrame]], [[Column]]. - - To select a column from the data frame, use the apply method:: - - ageCol = people.age - - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - - A more concrete example:: - - # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") - - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ - .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - self._sc = sql_ctx and sql_ctx._sc - self.is_cached = False - - @property - def rdd(self): - """Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row`s. """ - if not hasattr(self, '_lazy_rdd'): - jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema() - - def applySchema(it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - - return self._lazy_rdd - - def limit(self, num): - """Limit the result count to the number specified. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.limit(2).collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> df.limit(0).collect() - [] - """ - jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) - - def toJSON(self, use_unicode=False): - """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( "SELECT * from table1") - >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' - True - >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] - True - """ - rdd = self._jdf.toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - - def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. - - Files that are written out using this method can be read back in as - a DataFrame using the L{SQLContext.parquetFile} method. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True - """ - self._jdf.saveAsParquetFile(path) - - def registerTempTable(self, name): - """Registers this RDD as a temporary table using the given name. - - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.registerTempTable("test") - >>> df2 = sqlCtx.sql("select * from test") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - self._jdf.registerTempTable(name) - - def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this DataFrame into the specified table. - - Optionally overwriting any existing data. - """ - self._jdf.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName): - """Creates a new table with the contents of this DataFrame.""" - self._jdf.saveAsTable(tableName) - - def schema(self): - """Returns the schema of this DataFrame (represented by - a L{StructType}).""" - return _parse_datatype_json_string(self._jdf.schema().json()) - - def printSchema(self): - """Prints out the schema in the tree format.""" - print (self._jdf.schema().treeString()) - - def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.count() - 3L - >>> df.count() == df.map(lambda x: x).count() - True - """ - return self._jdf.count() - - def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect() - [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] - """ - with SCCallSiteSync(self._sc) as css: - bytesInJava = self._jdf.javaToPython().collect().iterator() - cls = _create_cls(self.schema()) - tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile.close() - self._sc._writeToFile(bytesInJava, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) - os.unlink(tempFile.name) - return [cls(r) for r in rs] - - def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.take(2) - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - """ - return self.limit(num).collect() - - def map(self, f): - """ Return a new RDD by applying a function to each Row, it's a - shorthand for df.rdd.map() - """ - return self.rdd.map(f) - - def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(iterator): yield 1 - >>> rdd.mapPartitions(f).sum() - 4 - """ - return self.rdd.mapPartitions(f, preservesPartitioning) - - def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - self._jdf.cache() - return self - - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) - self._jdf.persist(javaStorageLevel) - return self - - def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from - memory and disk. - """ - self.is_cached = False - self._jdf.unpersist(blocking) - return self - - # def coalesce(self, numPartitions, shuffle=False): - # rdd = self._jdf.coalesce(numPartitions, shuffle, None) - # return DataFrame(rdd, self.sql_ctx) - - def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. - """ - rdd = self._jdf.repartition(numPartitions, None) - return DataFrame(rdd, self.sql_ctx) - - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jdf.sample(withReplacement, fraction, long(seed)) - return DataFrame(rdd, self.sql_ctx) - - # def takeSample(self, withReplacement, num, seed=None): - # """Return a fixed-size sampled subset of this DataFrame. - # - # >>> df = sqlCtx.inferSchema(rdd) - # >>> df.takeSample(False, 2, 97) - # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - # """ - # seed = seed if seed is not None else random.randint(0, sys.maxint) - # with SCCallSiteSync(self.context) as css: - # bytesInJava = self._jdf \ - # .takeSampleToPython(withReplacement, num, long(seed)) \ - # .iterator() - # cls = _create_cls(self.schema()) - # return map(cls, self._collect_iterator_through_file(bytesInJava)) - - @property - def dtypes(self): - """Return all column names and their data types as a list. - """ - return [(f.name, str(f.dataType)) for f in self.schema().fields] - - @property - def columns(self): - """ Return all column names as a list. - """ - return [f.name for f in self.schema().fields] - - def show(self): - raise NotImplemented - - def join(self, other, joinExprs=None, joinType=None): - """ - Join with another DataFrame, using the given join expression. - The following performs a full outer join between `df1` and `df2`:: - - df1.join(df2, df1.key == df2.key, "outer") - - :param other: Right side of the join - :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, - `semijoin`. - """ - if joinType is None: - if joinExprs is None: - jdf = self._jdf.join(other._jdf) - else: - jdf = self._jdf.join(other._jdf, joinExprs) - else: - jdf = self._jdf.join(other._jdf, joinExprs, joinType) - return DataFrame(jdf, self.sql_ctx) - - def sort(self, *cols): - """ Return a new [[DataFrame]] sorted by the specified column, - in ascending column. - - :param cols: The columns or expressions used for sorting - """ - if not cols: - raise ValueError("should sort by at least one column") - for i, c in enumerate(cols): - if isinstance(c, basestring): - cols[i] = Column(c) - jcols = [c._jc for c in cols] - jdf = self._jdf.join(*jcols) - return DataFrame(jdf, self.sql_ctx) - - sortBy = sort - - def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. """ - if n is None: - rs = self.head(1) - return rs[0] if rs else None - return self.take(n) - - def first(self): - """ Return the first row. """ - return self.head() - - def tail(self): - raise NotImplemented - - def __getitem__(self, item): - if isinstance(item, basestring): - return Column(self._jdf.apply(item)) - - # TODO projection - raise IndexError - - def __getattr__(self, name): - """ Return the column by given name """ - if name.startswith("__"): - raise AttributeError(name) - return Column(self._jdf.apply(name)) - - def alias(self, name): - """ Alias the current DataFrame """ - return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) - - def select(self, *cols): - """ Selecting a set of expressions.:: - - df.select() - df.select('colA', 'colB') - df.select(df.colA, df.colB + 1) - - """ - if not cols: - cols = ["*"] - if isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) - return DataFrame(jdf, self.sql_ctx) - - def filter(self, condition): - """ Filtering rows using the given condition:: - - df.filter(df.age > 15) - df.where(df.age > 15) - - """ - return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) - - where = filter - - def groupBy(self, *cols): - """ Group the [[DataFrame]] using the specified columns, - so we can run aggregation on them. See :class:`GroupedDataFrame` - for all the available aggregate functions:: - - df.groupBy(df.department).avg() - df.groupBy("department", "gender").agg({ - "salary": "avg", - "age": "max", - }) - """ - if cols and isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) - return GroupedDataFrame(jdf, self.sql_ctx) - - def agg(self, *exprs): - """ Aggregate on the entire [[DataFrame]] without groups - (shorthand for df.groupBy.agg()):: - - df.agg({"age": "max", "salary": "avg"}) - """ - return self.groupBy().agg(*exprs) - - def unionAll(self, other): - """ Return a new DataFrame containing union of rows in this - frame and another frame. - - This is equivalent to `UNION ALL` in SQL. - """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - - def intersect(self, other): - """ Return a new [[DataFrame]] containing rows only in - both this frame and another frame. - - This is equivalent to `INTERSECT` in SQL. - """ - return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) - - def subtract(self, other): - """ Return a new [[DataFrame]] containing rows in this frame - but not in another frame. - - This is equivalent to `EXCEPT` in SQL. - """ - return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - - def sample(self, withReplacement, fraction, seed=None): - """ Return a new DataFrame by sampling a fraction of rows. """ - if seed is None: - jdf = self._jdf.sample(withReplacement, fraction) - else: - jdf = self._jdf.sample(withReplacement, fraction, seed) - return DataFrame(jdf, self.sql_ctx) - - def addColumn(self, colName, col): - """ Return a new [[DataFrame]] by adding a column. """ - return self.select('*', col.alias(colName)) - - def removeColumn(self, colName): - raise NotImplemented - - -# Having SchemaRDD for backward compatibility (for docs) -class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame - """ - - -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedDataFrame(object): - - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. - - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. - - :param exprs: list or aggregate columns or a map from column - name to agregate methods. - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """ Count the number of rows for each group. """ - - @dfapi - def mean(self): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`.""" - - @dfapi - def avg(self): - """Compute the average value for each numeric columns - for each group.""" - - @dfapi - def max(self): - """Compute the max value for each numeric columns for - each group. """ - - @dfapi - def min(self): - """Compute the min value for each numeric column for - each group.""" - - @dfapi - def sum(self): - """Compute the sum for each numeric columns for each - group.""" - - -SCALA_METHOD_MAPPINGS = { - '=': '$eq', - '>': '$greater', - '<': '$less', - '+': '$plus', - '-': '$minus', - '*': '$times', - '/': '$div', - '!': '$bang', - '@': '$at', - '#': '$hash', - '%': '$percent', - '^': '$up', - '&': '$amp', - '~': '$tilde', - '?': '$qmark', - '|': '$bar', - '\\': '$bslash', - ':': '$colon', -} - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.org.apache.spark.sql.Dsl.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.IncomputableColumn(name) - - -def _scalaMethod(name): - """ Translate operators into methodName in Scala - - For example: - >>> _scalaMethod('+') - '$plus' - >>> _scalaMethod('>=') - '$greater$eq' - >>> _scalaMethod('cast') - 'cast' - """ - return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) - - -def _unary_op(name): - """ Create a method for given unary operator """ - def _(self): - return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) - return _ - - -def _bin_op(name, pass_literal_through=True): - """ Create a method for given binary operator - - Keyword arguments: - pass_literal_through -- whether to pass literal value directly through to the JVM. - """ - def _(self, other): - if isinstance(other, Column): - jc = other._jc - else: - if pass_literal_through: - jc = other - else: - jc = _create_column_from_literal(other) - return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) - return _ - - -def _reverse_op(name): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), - self._jdf, self.sql_ctx) - return _ - - -class Column(DataFrame): - - """ - A column in a DataFrame. - - `Column` instances can be created by: - {{{ - // 1. Select a column out of a DataFrame - df.colName - df["colName"] - - // 2. Create from an expression - df["colName"] + 1 - }}} - """ - - def __init__(self, jc, jdf=None, sql_ctx=None): - self._jc = jc - super(Column, self).__init__(jdf, sql_ctx) - - # arithmetic operators - __neg__ = _unary_op("unary_-") - __add__ = _bin_op("+") - __sub__ = _bin_op("-") - __mul__ = _bin_op("*") - __div__ = _bin_op("/") - __mod__ = _bin_op("%") - __radd__ = _bin_op("+") - __rsub__ = _reverse_op("-") - __rmul__ = _bin_op("*") - __rdiv__ = _reverse_op("/") - __rmod__ = _reverse_op("%") - __abs__ = _unary_op("abs") - abs = _unary_op("abs") - sqrt = _unary_op("sqrt") - - # logistic operators - __eq__ = _bin_op("===") - __ne__ = _bin_op("!==") - __lt__ = _bin_op("<") - __le__ = _bin_op("<=") - __ge__ = _bin_op(">=") - __gt__ = _bin_op(">") - # `and`, `or`, `not` cannot be overloaded in Python - And = _bin_op('&&') - Or = _bin_op('||') - Not = _unary_op('unary_!') - - # bitwise operators - __and__ = _bin_op("&") - __or__ = _bin_op("|") - __invert__ = _unary_op("unary_~") - __xor__ = _bin_op("^") - # __lshift__ = _bin_op("<<") - # __rshift__ = _bin_op(">>") - __rand__ = _bin_op("&") - __ror__ = _bin_op("|") - __rxor__ = _bin_op("^") - # __rlshift__ = _reverse_op("<<") - # __rrshift__ = _reverse_op(">>") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("getItem") - # __getattr__ = _bin_op("getField") - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - upper = _unary_op("upper") - lower = _unary_op("lower") - - def substr(self, startPos, pos): - if type(startPos) != type(pos): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, pos) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, pos._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self._jdf, self.sql_ctx) - - __getslice__ = substr - - # order - asc = _unary_op("asc") - desc = _unary_op("desc") - - isNull = _unary_op("isNull") - isNotNull = _unary_op("isNotNull") - - # `as` is keyword - def alias(self, alias): - return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) - - def cast(self, dataType): - if self.sql_ctx is None: - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - else: - ssql_ctx = self.sql_ctx._ssql_ctx - jdt = ssql_ctx.parseDataType(dataType.json()) - return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _aggregate_func(name): - """ Create a function for aggregator by name""" - def _(col): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) - return Column(jc) - - return staticmethod(_) - - -class Aggregator(object): - """ - A collections of builtin aggregators - """ - AGGS = [ - 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs', - 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct', - ] - for _name in AGGS: - locals()[_name] = _aggregate_func(_name) - del _name - - @staticmethod - def countDistinct(col, *cols): - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.Dsl.toColumns(jcols)) - return Column(jc) - - @staticmethod - def approxCountDistinct(col, rsd=None): - sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) - - -def _test(): - import doctest - from pyspark.context import SparkContext - # let doctest run in pyspark.sql, so DataTypes can be picklable - import pyspark.sql - from pyspark.sql import Row, SQLContext - from pyspark.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) - globs['rdd'] = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) - (failure_count, test_count) = doctest.testmod( - pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 000000000000..ad9c891ba1c0 --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Important classes of Spark SQL and DataFrames: + + - :class:`pyspark.sql.SQLContext` + Main entry point for :class:`DataFrame` and SQL functionality. + - :class:`pyspark.sql.DataFrame` + A distributed collection of data grouped into named columns. + - :class:`pyspark.sql.Column` + A column expression in a :class:`DataFrame`. + - :class:`pyspark.sql.Row` + A row of data in a :class:`DataFrame`. + - :class:`pyspark.sql.HiveContext` + Main entry point for accessing data stored in Apache Hive. + - :class:`pyspark.sql.GroupedData` + Aggregation methods, returned by :func:`DataFrame.groupBy`. + - :class:`pyspark.sql.DataFrameNaFunctions` + Methods for handling missing data (null values). + - :class:`pyspark.sql.DataFrameStatFunctions` + Methods for statistics functionality. + - :class:`pyspark.sql.functions` + List of built-in functions available for :class:`DataFrame`. + - :class:`pyspark.sql.types` + List of data types available. + - :class:`pyspark.sql.Window` + For working with window functions. +""" +from __future__ import absolute_import + + +def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + + def deco(f): + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) + return f + return deco + + +from pyspark.sql.types import Row +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.group import GroupedData +from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter +from pyspark.sql.window import Window, WindowSpec + + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'DataFrameReader', 'DataFrameWriter' +] diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py new file mode 100644 index 000000000000..0948f9b27cd3 --- /dev/null +++ b/python/pyspark/sql/column.py @@ -0,0 +1,460 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import warnings + +if sys.version >= '3': + basestring = str + long = int + +from pyspark.context import SparkContext +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql import since +from pyspark.sql.types import * + +__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", + "DataFrameStatFunctions"] + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toSeq(cols) + + +def _to_list(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM (Scala) List of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toList(cols) + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("apply") + + # bitwise operators + bitwiseOR = _bin_op("bitwiseOR") + bitwiseAND = _bin_op("bitwiseAND") + bitwiseXOR = _bin_op("bitwiseXOR") + + @since(1.3) + def getItem(self, key): + """ + An expression that gets an item at position ``ordinal`` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + >>> df.select(df.l[0], df.d["key"]).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + """ + return self[key] + + @since(1.3) + def getField(self, name): + """ + An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + +----+ + |r[b]| + +----+ + | b| + +----+ + >>> df.select(df.r.a).show() + +----+ + |r[a]| + +----+ + | 1| + +----+ + """ + return self[name] + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + @ignore_unicode_prefix + @since(1.3) + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column. + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + @ignore_unicode_prefix + @since(1.3) + def inSet(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + + .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. + """ + warnings.warn("inSet is deprecated. Use isin() instead.") + return self.isin(*cols) + + @ignore_unicode_prefix + @since(1.5) + def isin(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.isin("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.isin([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jc = getattr(self._jc, "isin")(_to_seq(sc, cols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + @since(1.3) + def alias(self, *alias): + """ + Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). + + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + + @ignore_unicode_prefix + @since(1.3) + def cast(self, dataType): + """ Convert the column into type ``dataType``. + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + astype = cast + + @since(1.3) + def between(self, lowerBound, upperBound): + """ + A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ + """ + return (self >= lowerBound) & (self <= upperBound) + + @since(1.4) + def when(self, condition, value): + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ + """ + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = self._jc.when(condition._jc, v) + return Column(jc) + + @since(1.4) + def otherwise(self, value): + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ + """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(v) + return Column(jc) + + @since(1.4) + def over(self, window): + """ + Define a windowing column. + + :param window: a :class:`WindowSpec` + :return: a Column + + >>> from pyspark.sql import Window + >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) + >>> from pyspark.sql.functions import rank, min + >>> # df.select(rank().over(window), min('age').over(window)) + + .. note:: Window functions is only supported with HiveContext in 1.4 + """ + from pyspark.sql.window import WindowSpec + if not isinstance(window, WindowSpec): + raise TypeError("window should be WindowSpec") + jc = self._jc.over(window._jspec) + return Column(jc) + + def __nonzero__(self): + raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " + "'~' for 'not' when building DataFrame boolean expressions.") + __bool__ = __nonzero__ + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + import pyspark.sql.column + globs = pyspark.sql.column.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.column, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 000000000000..0ef46c44644a --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,724 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import warnings +import json + +if sys.version >= '3': + basestring = unicode = str +else: + from itertools import imap as map + +from py4j.protocol import Py4JError + +from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql import since +from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.utils import install_exception_handler +from pyspark.sql.functions import UserDefinedFunction + +try: + import pandas + has_pandas = True +except Exception: + has_pandas = False + +__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] + + +def _monkey_patch_RDD(sqlContext): + def toDF(self, schema=None, sampleRatio=None): + """ + Converts current :class:`RDD` into a :class:`DataFrame` + + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` + + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd.toDF().collect() + [Row(name=u'Alice', age=1)] + """ + return sqlContext.createDataFrame(self, schema, sampleRatio) + + RDD.toDF = toDF + + +class SQLContext(object): + """Main entry point for Spark SQL functionality. + + A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as + tables, execute SQL over tables, cache tables, and read parquet files. + + :param sparkContext: The :class:`SparkContext` backing this SQLContext. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new + SQLContext in the JVM, instead we make all calls to this object. + """ + + @ignore_unicode_prefix + def __init__(self, sparkContext, sqlContext=None): + """Creates a new SQLContext. + + >>> from datetime import datetime + >>> sqlContext = SQLContext(sc) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = allTypes.toDF() + >>> df.registerTempTable("allTypes") + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ + time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + _monkey_patch_RDD(self) + install_exception_handler() + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + @since(1.3) + def setConf(self, key, value): + """Sets the given Spark SQL configuration property. + """ + self._ssql_ctx.setConf(key, value) + + @since(1.3) + def getConf(self, key, defaultValue): + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set, returns defaultValue. + """ + return self._ssql_ctx.getConf(key, defaultValue) + + @property + @since("1.3.1") + def udf(self): + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ + return UDFRegistration(self) + + @since(1.4) + def range(self, start, end=None, step=1, numPartitions=None): + """ + Create a :class:`DataFrame` with single LongType column named `id`, + containing elements in a range from `start` to `end` (exclusive) with + step value `step`. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numPartitions: the number of partitions of the DataFrame + :return: :class:`DataFrame` + + >>> sqlContext.range(1, 7, 2).collect() + [Row(id=1), Row(id=3), Row(id=5)] + + If only one argument is specified, it will be used as the end value. + + >>> sqlContext.range(3).collect() + [Row(id=0), Row(id=1), Row(id=2)] + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + if end is None: + jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) + else: + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + + return DataFrame(jdf, self) + + @ignore_unicode_prefix + @since(1.2) + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + :param name: name of the UDF + :param samplingRatio: lambda function + :param returnType: a :class:`DataType` object + + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() + [Row(_c0=u'4')] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(_c0=4)] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(_c0=4)] + """ + udf = UserDefinedFunction(f, returnType, name) + self._ssql_ctx.udf().registerPython(name, udf._judf) + + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated. " + "Use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio < 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + return schema + + @ignore_unicode_prefix + def inferSchema(self, rdd, samplingRatio=None): + """ + .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + return self.createDataFrame(rdd, None, samplingRatio) + + @ignore_unicode_prefix + def applySchema(self, rdd, schema): + """ + .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("applySchema is deprecated, please use createDataFrame instead") + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType, but got %s" % type(schema)) + + return self.createDataFrame(rdd, schema) + + def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. + """ + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from an list or pandas.DataFrame, returns + the RDD and schema. + """ + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + # make sure data could consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + + @since(1.3) + @ignore_unicode_prefix + def createDataFrame(self, data, schema=None, samplingRatio=None): + """ + Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, + list or :class:`pandas.DataFrame`. + + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. + + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. + + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. + + :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, + :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`StructType` or list of column names. default None. + :param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` + + >>> l = [('Alice', 1)] + >>> sqlContext.createDataFrame(l).collect() + [Row(_1=u'Alice', _2=1)] + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() + [Row(name=u'Alice', age=1)] + + >>> d = [{'name': 'Alice', 'age': 1}] + >>> sqlContext.createDataFrame(d).collect() + [Row(age=1, name=u'Alice')] + + >>> rdd = sc.parallelize(l) + >>> sqlContext.createDataFrame(rdd).collect() + [Row(_1=u'Alice', _2=1)] + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql import Row + >>> Person = Row('name', 'age') + >>> person = rdd.map(lambda r: Person(*r)) + >>> df2 = sqlContext.createDataFrame(person) + >>> df2.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("name", StringType(), True), + ... StructField("age", IntegerType(), True)]) + >>> df3 = sqlContext.createDataFrame(rdd, schema) + >>> df3.collect() + [Row(name=u'Alice', age=1)] + + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + [Row(name=u'Alice', age=1)] + >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]]).collect()) # doctest: +SKIP + [Row(0=1, 1=2)] + """ + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") + + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data, schema, samplingRatio) + else: + rdd, schema = self._createFromLocal(data, schema) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self) + df._schema = schema + return df + + @since(1.3) + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + """ + if (df.__class__ is DataFrame): + self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a :class:`DataFrame`. + + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. + + >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") + gateway = self._sc._gateway + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) + for i in range(0, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """Loads a text file storing one JSON object per line as a :class:`DataFrame`. + + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. + + >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes + [('age', 'bigint'), ('name', 'string')] + """ + warnings.warn("jsonFile is deprecated. Use read.json() instead.") + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + @ignore_unicode_prefix + @since(1.0) + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + + >>> df1 = sqlContext.jsonRDD(json) + >>> df1.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> df2 = sqlContext.jsonRDD(json, df1.schema) + >>> df2.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))])) + ... ]) + >>> df3 = sqlContext.jsonRDD(json, schema) + >>> df3.first() + Row(field2=u'row1', field3=Row(field5=None)) + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def load(self, path=None, source=None, schema=None, **options): + """Returns the dataset in a data source as a :class:`DataFrame`. + + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. + """ + warnings.warn("load is deprecated. Use read.load() instead.") + return self.read.load(path, source, schema, **options) + + @since(1.3) + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and + created external table. + + :return: :class:`DataFrame` + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + if schema is None: + df = self._ssql_ctx.createExternalTable(tableName, source, options) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, + options) + return DataFrame(df, self) + + @ignore_unicode_prefix + @since(1.0) + def sql(self, sqlQuery): + """Returns a :class:`DataFrame` representing the result of the given query. + + :return: :class:`DataFrame` + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + @since(1.0) + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + :return: :class:`DataFrame` + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + @ignore_unicode_prefix + @since(1.3) + def tables(self, dbName=None): + """Returns a :class:`DataFrame` containing names of tables in the given database. + + If ``dbName`` is not specified, the current database will be used. + + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + + :param dbName: string, name of the database to use. + :return: :class:`DataFrame` + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() + >>> df2.filter("tableName = 'table1'").first() + Row(tableName=u'table1', isTemporary=True) + """ + if dbName is None: + return DataFrame(self._ssql_ctx.tables(), self) + else: + return DataFrame(self._ssql_ctx.tables(dbName), self) + + @since(1.3) + def tableNames(self, dbName=None): + """Returns a list of names of tables in the database ``dbName``. + + :param dbName: string, name of the database to use. Default to the current database. + :return: list of table names, in string + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() + True + >>> "table1" in sqlContext.tableNames("db") + True + """ + if dbName is None: + return [name for name in self._ssql_ctx.tableNames()] + else: + return [name for name in self._ssql_ctx.tableNames(dbName)] + + @since(1.0) + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + @since(1.0) + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + @since(1.3) + def clearCache(self): + """Removes all cached tables from the in-memory cache. """ + self._ssql_ctx.clearCache() + + @property + @since(1.4) + def read(self): + """ + Returns a :class:`DataFrameReader` that can be used to read data + in as a :class:`DataFrame`. + + :return: :class:`DataFrameReader` + """ + return DataFrameReader(self) + + +class HiveContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from ``hive-site.xml`` on the classpath. + It supports running both SQL and HiveQL commands. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :class:`HiveContext` in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, hiveContext=None): + SQLContext.__init__(self, sparkContext) + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + def refreshTable(self, tableName): + """Invalidate and refresh all the cached the metadata of the given + table. For performance reasons, Spark SQL or the external data source + library it uses might cache certain metadata about a table, such as the + location of blocks. When those change outside of Spark SQL, users should + call this function to invalidate the cache. + """ + self._ssql_ctx.refreshTable(tableName) + + +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sqlContext): + self.sqlContext = sqlContext + + def register(self, name, f, returnType=StringType()): + return self.sqlContext.registerFunction(name, f, returnType) + + register.__doc__ = SQLContext.registerFunction.__doc__ + + +def _test(): + import os + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['rdd'] = rdd = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + globs['df'] = rdd.toDF() + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 000000000000..e269ef4304f3 --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,1383 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import warnings +import random + +if sys.version >= '3': + basestring = unicode = str + long = int + from functools import reduce +else: + from itertools import imap as map + +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql import since +from pyspark.sql.types import _parse_datatype_json_string +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column +from pyspark.sql.readwriter import DataFrameWriter +from pyspark.sql.types import * + +__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] + + +class DataFrame(object): + """A distributed collection of data grouped into named columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.read.parquet("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.read.parquet("...") + department = sqlContext.read.parquet("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + self._schema = None # initialized lazily + self._lazy_rdd = None + + @property + @since(1.3) + def rdd(self): + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. + """ + if self._lazy_rdd is None: + jrdd = self._jdf.javaToPython() + self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + return self._lazy_rdd + + @property + @since("1.3.1") + def na(self): + """Returns a :class:`DataFrameNaFunctions` for handling missing values. + """ + return DataFrameNaFunctions(self) + + @property + @since(1.4) + def stat(self): + """Returns a :class:`DataFrameStatFunctions` for statistic functions. + """ + return DataFrameStatFunctions(self) + + @ignore_unicode_prefix + @since(1.3) + def toJSON(self, use_unicode=True): + """Converts a :class:`DataFrame` into a :class:`RDD` of string. + + Each row is turned into a JSON document as one element in the returned RDD. + + >>> df.toJSON().first() + u'{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Saves the contents as a Parquet file, preserving the schema. + + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. + """ + warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") + self._jdf.saveAsParquetFile(path) + + @since(1.3) + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. + + The lifetime of this temporary table is tied to the :class:`SQLContext` + that was used to create this :class:`DataFrame`. + + >>> df.registerTempTable("people") + >>> df2 = sqlContext.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """ + .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. + """ + warnings.warn("Use registerTempTable instead of registerAsTable.") + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this :class:`DataFrame` into the specified table. + + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. + """ + warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") + self.write.insertInto(tableName, overwrite) + + def saveAsTable(self, tableName, source=None, mode="error", **options): + """Saves the contents of this :class:`DataFrame` to a data source as a table. + + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. + """ + warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") + self.write.saveAsTable(tableName, source, mode, **options) + + @since(1.3) + def save(self, path=None, source=None, mode="error", **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. + """ + warnings.warn("insertInto is deprecated. Use write.save() instead.") + return self.write.save(path, source, mode, **options) + + @property + @since(1.4) + def write(self): + """ + Interface for saving the content of the :class:`DataFrame` out into external storage. + + :return: :class:`DataFrameWriter` + """ + return DataFrameWriter(self) + + @property + @since(1.3) + def schema(self): + """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. + + >>> df.schema + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + if self._schema is None: + try: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + except AttributeError as e: + raise Exception( + "Unable to parse datatype from schema. %s" % e) + return self._schema + + @since(1.3) + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ + print(self._jdf.schema().treeString()) + + @since(1.3) + def explain(self, extended=False): + """Prints the (logical and physical) plans to the console for debugging purpose. + + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + + >>> df.explain() + Scan PhysicalRDD[age#0,name#1] + + >>> df.explain(True) + == Parsed Logical Plan == + ... + == Analyzed Logical Plan == + ... + == Optimized Logical Plan == + ... + == Physical Plan == + ... + """ + if extended: + print(self._jdf.queryExecution().toString()) + else: + print(self._jdf.queryExecution().executedPlan().toString()) + + @since(1.3) + def isLocal(self): + """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally + (without any Spark executors). + """ + return self._jdf.isLocal() + + @since(1.3) + def show(self, n=20, truncate=True): + """Prints the first ``n`` rows to the console. + + :param n: Number of rows to show. + :param truncate: Whether truncate long strings and align cells right. + + >>> df + DataFrame[age: int, name: string] + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + print(self._jdf.showString(n, truncate)) + + def __repr__(self): + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + @since(1.3) + def count(self): + """Returns the number of rows in this :class:`DataFrame`. + + >>> df.count() + 2 + """ + return int(self._jdf.count()) + + @ignore_unicode_prefix + @since(1.3) + def collect(self): + """Returns all the records as a list of :class:`Row`. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + + @ignore_unicode_prefix + @since(1.3) + def limit(self, num): + """Limits the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def take(self, num): + """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + @ignore_unicode_prefix + @since(1.3) + def map(self, f): + """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. + + This is a shorthand for ``df.rdd.map()``. + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + @ignore_unicode_prefix + @since(1.3) + def flatMap(self, f): + """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, + and then flattening the results. + + This is a shorthand for ``df.rdd.flatMap()``. + + >>> df.flatMap(lambda p: p.name).collect() + [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] + """ + return self.rdd.flatMap(f) + + @since(1.3) + def mapPartitions(self, f, preservesPartitioning=False): + """Returns a new :class:`RDD` by applying the ``f`` function to each partition. + + This is a shorthand for ``df.rdd.mapPartitions()``. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + @since(1.3) + def foreach(self, f): + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. + + This is a shorthand for ``df.rdd.foreach()``. + + >>> def f(person): + ... print(person.name) + >>> df.foreach(f) + """ + return self.rdd.foreach(f) + + @since(1.3) + def foreachPartition(self, f): + """Applies the ``f`` function to each partition of this :class:`DataFrame`. + + This a shorthand for ``df.rdd.foreachPartition()``. + + >>> def f(people): + ... for person in people: + ... print(person.name) + >>> df.foreachPartition(f) + """ + return self.rdd.foreachPartition(f) + + @since(1.3) + def cache(self): + """ Persists with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + @since(1.3) + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """Sets the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + @since(1.3) + def unpersist(self, blocking=True): + """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + @since(1.4) + def coalesce(self, numPartitions): + """ + Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + + Similar to coalesce defined on an :class:`RDD`, this operation results in a + narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, + there will not be a shuffle, instead each of the 100 new partitions will + claim 10 of the current partitions. + + >>> df.coalesce(1).rdd.getNumPartitions() + 1 + """ + return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx) + + @since(1.3) + def repartition(self, numPartitions): + """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + """ + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + + @since(1.3) + def distinct(self): + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + + >>> df.distinct().count() + 2 + """ + return DataFrame(self._jdf.distinct(), self.sql_ctx) + + @since(1.3) + def sample(self, withReplacement, fraction, seed=None): + """Returns a sampled subset of this :class:`DataFrame`. + + >>> df.sample(False, 0.5, 42).count() + 1 + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + + @since(1.4) + def randomSplit(self, weights, seed=None): + """Randomly splits this :class:`DataFrame` with the provided weights. + + :param weights: list of doubles as weights with which to split the DataFrame. Weights will + be normalized if they don't sum up to 1.0. + :param seed: The seed for sampling. + + >>> splits = df4.randomSplit([1.0, 2.0], 24) + >>> splits[0].count() + 1 + + >>> splits[1].count() + 3 + """ + for w in weights: + if w < 0.0: + raise ValueError("Weights must be positive. Found weight value: %s" % w) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) + return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] + + @property + @since(1.3) + def dtypes(self): + """Returns all column names and their data types as a list. + + >>> df.dtypes + [('age', 'int'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] + + @property + @since(1.3) + def columns(self): + """Returns all column names as a list. + + >>> df.columns + ['age', 'name'] + """ + return [f.name for f in self.schema.fields] + + @ignore_unicode_prefix + @since(1.3) + def alias(self, alias): + """Returns a new :class:`DataFrame` with an alias set. + + >>> from pyspark.sql.functions import * + >>> df_as1 = df.alias("df_as1") + >>> df_as2 = df.alias("df_as2") + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') + >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() + [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + """ + assert isinstance(alias, basestring), "alias should be a string" + return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def join(self, other, on=None, how=None): + """Joins with another :class:`DataFrame`, using the given join expression. + + The following performs a full outer join between ``df1`` and ``df2``. + + :param other: Right side of the join + :param on: a string for join column name, a list of column names, + , a join expression (Column) or a list of Columns. + If `on` is a string or a list of string indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an inner equi-join. + :param how: str, default 'inner'. + One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + + >>> cond = [df.name == df3.name, df.age == df3.age] + >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() + [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + + >>> df.join(df2, 'name').select(df.name, df2.height).collect() + [Row(name=u'Bob', height=85)] + + >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() + [Row(name=u'Bob', age=5)] + """ + + if on is not None and not isinstance(on, list): + on = [on] + + if on is None or len(on) == 0: + jdf = self._jdf.join(other._jdf) + elif isinstance(on[0], basestring): + jdf = self._jdf.join(other._jdf, self._jseq(on)) + else: + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + if how is None: + jdf = self._jdf.join(other._jdf, on._jc, "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on._jc, how) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def sort(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sort("age", ascending=False).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> from pyspark.sql.functions import * + >>> df.sort(asc("age")).collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.orderBy(desc("age"), "name").collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + + jdf = self._jdf.sort(self._jseq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + orderBy = sort + + def _jseq(self, cols, converter=None): + """Return a JVM Seq of Columns from a list of Column or names""" + return _to_seq(self.sql_ctx._sc, cols, converter) + + def _jmap(self, jm): + """Return a JVM Scala Map from a dict""" + return _to_scala_map(self.sql_ctx._sc, jm) + + def _jcols(self, *cols): + """Return a JVM Seq of Columns from a list of Column or column names + + If `cols` has only one list in it, cols[0] will be used as the list. + """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return self._jseq(cols, _to_java_column) + + @since("1.3.1") + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + + >>> df.describe().show() + +-------+---+ + |summary|age| + +-------+---+ + | count| 2| + | mean|3.5| + | stddev|1.5| + | min| 2| + | max| 5| + +-------+---+ + >>> df.describe(['age', 'name']).show() + +-------+---+-----+ + |summary|age| name| + +-------+---+-----+ + | count| 2| 2| + | mean|3.5| null| + | stddev|1.5| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+---+-----+ + """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jdf = self._jdf.describe(self._jseq(cols)) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def head(self, n=None): + """Returns the first ``n`` rows. + + :param n: int, default 1. Number of rows to return. + :return: If n is greater than 1, return a list of :class:`Row`. + If n is 1, return a single Row. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + @ignore_unicode_prefix + @since(1.3) + def first(self): + """Returns the first row as a :class:`Row`. + + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + @ignore_unicode_prefix + @since(1.3) + def __getitem__(self, item): + """Returns the column as a :class:`Column`. + + >>> df.select(df['age']).collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + >>> df[df[0] > 3].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + jc = self._jdf.apply(item) + return Column(jc) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, (list, tuple)): + return self.select(*item) + elif isinstance(item, int): + jc = self._jdf.apply(self.columns[item]) + return Column(jc) + else: + raise TypeError("unexpected item type: %s" % type(item)) + + @since(1.3) + def __getattr__(self, name): + """Returns the :class:`Column` denoted by ``name``. + + >>> df.select(df.age).collect() + [Row(age=2), Row(age=5)] + """ + if name not in self.columns: + raise AttributeError( + "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) + jc = self._jdf.apply(name) + return Column(jc) + + @ignore_unicode_prefix + @since(1.3) + def select(self, *cols): + """Projects a set of expressions and returns a new :class:`DataFrame`. + + :param cols: list of column names (string) or expressions (:class:`Column`). + If one of the column names is '*', that column is expanded to include all columns + in the current DataFrame. + + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + jdf = self._jdf.select(self._jcols(*cols)) + return DataFrame(jdf, self.sql_ctx) + + @since(1.3) + def selectExpr(self, *expr): + """Projects a set of SQL expressions and returns a new :class:`DataFrame`. + + This is a variant of :func:`select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] + """ + if len(expr) == 1 and isinstance(expr[0], list): + expr = expr[0] + jdf = self._jdf.selectExpr(self._jseq(expr)) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def filter(self, condition): + """Filters rows using the given condition. + + :func:`where` is an alias for :func:`filter`. + + :param condition: a :class:`Column` of :class:`types.BooleanType` + or a string of SQL expression. + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + @ignore_unicode_prefix + @since(1.3) + def groupBy(self, *cols): + """Groups the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + :func:`groupby` is an alias for :func:`groupBy`. + + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + + >>> df.groupBy().avg().collect() + [Row(avg(age)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] + >>> df.groupBy(['name', df.age]).count().collect() + [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] + """ + jgd = self._jdf.groupBy(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) + def rollup(self, *cols): + """ + Create a multi-dimensional rollup for the current :class:`DataFrame` using + the specified columns, so we can run aggregation on them. + + >>> df.rollup('name', df.age).count().show() + +-----+----+-----+ + | name| age|count| + +-----+----+-----+ + |Alice|null| 1| + | Bob| 5| 1| + | Bob|null| 1| + | null|null| 2| + |Alice| 2| 1| + +-----+----+-----+ + """ + jgd = self._jdf.rollup(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) + def cube(self, *cols): + """ + Create a multi-dimensional cube for the current :class:`DataFrame` using + the specified columns, so we can run aggregation on them. + + >>> df.cube('name', df.age).count().show() + +-----+----+-----+ + | name| age|count| + +-----+----+-----+ + | null| 2| 1| + |Alice|null| 1| + | Bob| 5| 1| + | Bob|null| 1| + | null| 5| 1| + | null|null| 2| + |Alice| 2| 1| + +-----+----+-----+ + """ + jgd = self._jdf.cube(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.3) + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for ``df.groupBy.agg()``). + + >>> df.agg({"age": "max"}).collect() + [Row(max(age)=5)] + >>> from pyspark.sql import functions as F + >>> df.agg(F.min(df.age)).collect() + [Row(min(age)=2)] + """ + return self.groupBy().agg(*exprs) + + @since(1.3) + def unionAll(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + @since(1.3) + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + @since(1.3) + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + @since(1.4) + def dropDuplicates(self, subset=None): + """Return a new :class:`DataFrame` with duplicate rows removed, + optionally only considering certain columns. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([ \ + Row(name='Alice', age=5, height=80), \ + Row(name='Alice', age=5, height=80), \ + Row(name='Alice', age=10, height=80)]).toDF() + >>> df.dropDuplicates().show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 5| 80|Alice| + | 10| 80|Alice| + +---+------+-----+ + + >>> df.dropDuplicates(['name', 'height']).show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 5| 80|Alice| + +---+------+-----+ + """ + if subset is None: + jdf = self._jdf.dropDuplicates() + else: + jdf = self._jdf.dropDuplicates(self._jseq(subset)) + return DataFrame(jdf, self.sql_ctx) + + @since("1.3.1") + def dropna(self, how='any', thresh=None, subset=None): + """Returns a new :class:`DataFrame` omitting rows with null values. + :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. + + :param how: 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + :param thresh: int, default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + :param subset: optional list of column names to consider. + + >>> df4.na.drop().show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 10| 80|Alice| + +---+------+-----+ + """ + if how is not None and how not in ['any', 'all']: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if subset is None: + subset = self.columns + elif isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + if thresh is None: + thresh = len(subset) if how == 'any' else 1 + + return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) + + @since("1.3.1") + def fillna(self, value, subset=None): + """Replace null values, alias for ``na.fill()``. + :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. + + :param value: int, long, float, string, or dict. + Value to replace null values with. + If the value is a dict, then `subset` is ignored and `value` must be a mapping + from column name (string) to replacement value. The replacement value must be + an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.na.fill(50).show() + +---+------+-----+ + |age|height| name| + +---+------+-----+ + | 10| 80|Alice| + | 5| 50| Bob| + | 50| 50| Tom| + | 50| 50| null| + +---+------+-----+ + + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() + +---+------+-------+ + |age|height| name| + +---+------+-------+ + | 10| 80| Alice| + | 5| null| Bob| + | 50| null| Tom| + | 50| null|unknown| + +---+------+-------+ + """ + if not isinstance(value, (float, int, long, basestring, dict)): + raise ValueError("value should be a float, int, long, string, or dict") + + if isinstance(value, (int, long)): + value = float(value) + + if isinstance(value, dict): + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + elif subset is None: + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + else: + if isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + + @since(1.4) + def replace(self, to_replace, value, subset=None): + """Returns a new :class:`DataFrame` replacing a value with another value. + :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are + aliases of each other. + + :param to_replace: int, long, float, string, or list. + Value to be replaced. + If the value is a dict, then `value` is ignored and `to_replace` must be a + mapping from column name (string) to replacement value. The value to be + replaced must be an int, long, float, or string. + :param value: int, long, float, string, or list. + Value to use to replace holes. + The replacement value must be an int, long, float, or string. If `value` is a + list or tuple, `value` should be of the same length with `to_replace`. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.na.replace(10, 20).show() + +----+------+-----+ + | age|height| name| + +----+------+-----+ + | 20| 80|Alice| + | 5| null| Bob| + |null| null| Tom| + |null| null| null| + +----+------+-----+ + + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80| A| + | 5| null| B| + |null| null| Tom| + |null| null|null| + +----+------+----+ + """ + if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + raise ValueError( + "to_replace should be a float, int, long, string, list, tuple, or dict") + + if not isinstance(value, (float, int, long, basestring, list, tuple)): + raise ValueError("value should be a float, int, long, string, list, or tuple") + + rep_dict = dict() + + if isinstance(to_replace, (float, int, long, basestring)): + to_replace = [to_replace] + + if isinstance(to_replace, tuple): + to_replace = list(to_replace) + + if isinstance(value, tuple): + value = list(value) + + if isinstance(to_replace, list) and isinstance(value, list): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length") + rep_dict = dict(zip(to_replace, value)) + elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): + rep_dict = dict([(tr, value) for tr in to_replace]) + elif isinstance(to_replace, dict): + rep_dict = to_replace + + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + elif isinstance(subset, basestring): + subset = [subset] + + if not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + + @since(1.4) + def corr(self, col1, col2, method=None): + """ + Calculates the correlation of two columns of a DataFrame as a double value. + Currently only supports the Pearson Correlation Coefficient. + :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other. + + :param col1: The name of the first column + :param col2: The name of the second column + :param method: The correlation method. Currently only supports "pearson" + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + if not method: + method = "pearson" + if not method == "pearson": + raise ValueError("Currently only the calculation of the Pearson Correlation " + + "coefficient is supported.") + return self._jdf.stat().corr(col1, col2, method) + + @since(1.4) + def cov(self, col1, col2): + """ + Calculate the sample covariance for the given columns, specified by their names, as a + double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases. + + :param col1: The name of the first column + :param col2: The name of the second column + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + return self._jdf.stat().cov(col1, col2) + + @since(1.4) + def crosstab(self, col1, col2): + """ + Computes a pair-wise frequency table of the given columns. Also known as a contingency + table. The number of distinct values for each column should be less than 1e4. At most 1e6 + non-zero pair frequencies will be returned. + The first column of each row will be the distinct values of `col1` and the column names + will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. + Pairs that have no occurrences will have zero as their counts. + :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases. + + :param col1: The name of the first column. Distinct items will make the first item of + each row. + :param col2: The name of the second column. Distinct items will make the column names + of the DataFrame. + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) + + @since(1.4) + def freqItems(self, cols, support=None): + """ + Finding frequent items for columns, possibly with false positives. Using the + frequent element count algorithm described in + "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". + :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + + :param cols: Names of the columns to calculate frequent items for as a list or tuple of + strings. + :param support: The frequency with which to consider an item 'frequent'. Default is 1%. + The support must be greater than 1e-4. + """ + if isinstance(cols, tuple): + cols = list(cols) + if not isinstance(cols, list): + raise ValueError("cols must be a list or tuple of column names as strings.") + if not support: + support = 0.01 + return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def withColumn(self, colName, col): + """ + Returns a new :class:`DataFrame` by adding a column or replacing the + existing column that has the same name. + + :param colName: string, name of the new column. + :param col: a :class:`Column` expression for the new column. + + >>> df.withColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + assert isinstance(col, Column), "col should be Column" + return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) + + @ignore_unicode_prefix + @since(1.3) + def withColumnRenamed(self, existing, new): + """Returns a new :class:`DataFrame` by renaming an existing column. + + :param existing: string, name of the existing column to rename. + :param col: string, new name of the column. + + >>> df.withColumnRenamed('age', 'age2').collect() + [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + """ + return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) + + @since(1.4) + @ignore_unicode_prefix + def drop(self, col): + """Returns a new :class:`DataFrame` that drops the specified column. + + :param col: a string name of the column to drop, or a + :class:`Column` to drop. + + >>> df.drop('age').collect() + [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.drop(df.age).collect() + [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() + [Row(age=5, height=85, name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() + [Row(age=5, name=u'Bob', height=85)] + """ + if isinstance(col, basestring): + jdf = self._jdf.drop(col) + elif isinstance(col, Column): + jdf = self._jdf.drop(col._jc) + else: + raise TypeError("col should be a string or a Column") + return DataFrame(jdf, self.sql_ctx) + + @since(1.3) + def toPandas(self): + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. + + >>> df.toPandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + ########################################################################################## + # Pandas compatibility + ########################################################################################## + + groupby = groupBy + drop_duplicates = dropDuplicates + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """SchemaRDD is deprecated, please use :class:`DataFrame`. + """ + + +def _to_scala_map(sc, jm): + """ + Convert a dict into a JVM Map. + """ + return sc._jvm.PythonUtils.toScalaMap(jm) + + +class DataFrameNaFunctions(object): + """Functionality for working with missing data in :class:`DataFrame`. + + .. versionadded:: 1.4 + """ + + def __init__(self, df): + self.df = df + + def drop(self, how='any', thresh=None, subset=None): + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + + def fill(self, value, subset=None): + return self.df.fillna(value=value, subset=subset) + + fill.__doc__ = DataFrame.fillna.__doc__ + + def replace(self, to_replace, value, subset=None): + return self.df.replace(to_replace, value, subset) + + replace.__doc__ = DataFrame.replace.__doc__ + + +class DataFrameStatFunctions(object): + """Functionality for statistic functions with :class:`DataFrame`. + + .. versionadded:: 1.4 + """ + + def __init__(self, df): + self.df = df + + def corr(self, col1, col2, method=None): + return self.df.corr(col1, col2, method) + + corr.__doc__ = DataFrame.corr.__doc__ + + def cov(self, col1, col2): + return self.df.cov(col1, col2) + + cov.__doc__ = DataFrame.cov.__doc__ + + def crosstab(self, col1, col2): + return self.df.crosstab(col1, col2) + + crosstab.__doc__ = DataFrame.crosstab.__doc__ + + def freqItems(self, cols, support=None): + return self.df.freqItems(cols, support) + + freqItems.__doc__ = DataFrame.freqItems.__doc__ + + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2), + Row(name='Bob', age=5)]).toDF() + globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py new file mode 100644 index 000000000000..4b74a501521a --- /dev/null +++ b/python/pyspark/sql/functions.py @@ -0,0 +1,1478 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A collections of builtin functions +""" +import math +import sys + +if sys.version < "3": + from itertools import imap as map + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql import since +from pyspark.sql.types import StringType +from pyspark.sql.column import Column, _to_java_column, _to_seq + + +def _create_function(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +def _create_binary_mathfunction(name, doc=""): + """ Create a binary mathfunction by name""" + def _(col1, col2): + sc = SparkContext._active_spark_context + # users might write ints for simplicity. This would throw an error on the JVM side. + jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1), + col2._jc if isinstance(col2, Column) else float(col2)) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +def _create_window_function(name, doc=''): + """ Create a window function by name """ + def _(): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)() + return Column(jc) + _.__name__ = name + _.__doc__ = 'Window function: ' + doc + return _ + + +_functions = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'asc': 'Returns a sort expression based on the ascending order of the given column name.', + 'desc': 'Returns a sort expression based on the descending order of the given column name.', + + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolute value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + +_functions_1_4 = { + # unary math functions + 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + + '0.0 through pi.', + 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2.', + 'atan': 'Computes the tangent inverse of the given value.', + 'cbrt': 'Computes the cube-root of the given value.', + 'ceil': 'Computes the ceiling of the given value.', + 'cos': 'Computes the cosine of the given value.', + 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'exp': 'Computes the exponential of the given value.', + 'expm1': 'Computes the exponential of the given value minus one.', + 'floor': 'Computes the floor of the given value.', + 'log': 'Computes the natural logarithm of the given value.', + 'log10': 'Computes the logarithm of the given value in Base 10.', + 'log1p': 'Computes the natural logarithm of the given value plus one.', + 'rint': 'Returns the double value that is closest in value to the argument and' + + ' is equal to a mathematical integer.', + 'signum': 'Computes the signum of the given value.', + 'sin': 'Computes the sine of the given value.', + 'sinh': 'Computes the hyperbolic sine of the given value.', + 'tan': 'Computes the tangent of the given value.', + 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + + 'measured in degrees.', + 'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + + 'measured in radians.', + + 'bitwiseNOT': 'Computes bitwise not.', +} + +# math functions that take two arguments as input +_binary_mathfunctions = { + 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + + 'polar coordinates (r, theta).', + 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', + 'pow': 'Returns the value of the first argument raised to the power of the second argument.', +} + +_window_functions = { + 'rowNumber': + """returns a sequential number starting at 1 within a window partition. + + This is equivalent to the ROW_NUMBER function in SQL.""", + 'denseRank': + """returns the rank of rows within a window partition, without any gaps. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the DENSE_RANK function in SQL.""", + 'rank': + """returns the rank of rows within a window partition. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the RANK function in SQL.""", + 'cumeDist': + """returns the cumulative distribution of values within a window partition, + i.e. the fraction of rows that are below the current row. + + This is equivalent to the CUME_DIST function in SQL.""", + 'percentRank': + """returns the relative rank (i.e. percentile) of rows within a window partition. + + This is equivalent to the PERCENT_RANK function in SQL.""", +} + +for _name, _doc in _functions.items(): + globals()[_name] = since(1.3)(_create_function(_name, _doc)) +for _name, _doc in _functions_1_4.items(): + globals()[_name] = since(1.4)(_create_function(_name, _doc)) +for _name, _doc in _binary_mathfunctions.items(): + globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) +for _name, _doc in _window_functions.items(): + globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) +del _name, _doc + + +@since(1.3) +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + +@since(1.4) +def coalesce(*cols): + """Returns the first column that is not null. + + >>> cDf = sqlContext.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) + >>> cDf.show() + +----+----+ + | a| b| + +----+----+ + |null|null| + | 1|null| + |null| 2| + +----+----+ + + >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() + +-------------+ + |coalesce(a,b)| + +-------------+ + | null| + | 1| + | 2| + +-------------+ + + >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() + +----+----+---------------+ + | a| b|coalesce(a,0.0)| + +----+----+---------------+ + |null|null| 0.0| + | 1|null| 1.0| + |null| 2| 0.0| + +----+----+---------------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.3) +def countDistinct(col, *cols): + """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. + + >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.4) +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. + + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + +@since(1.4) +def rand(seed=None): + """Generates a random column with i.i.d. samples from U[0.0, 1.0]. + """ + sc = SparkContext._active_spark_context + if seed is not None: + jc = sc._jvm.functions.rand(seed) + else: + jc = sc._jvm.functions.rand() + return Column(jc) + + +@since(1.4) +def randn(seed=None): + """Generates a column with i.i.d. samples from the standard normal distribution. + """ + sc = SparkContext._active_spark_context + if seed is not None: + jc = sc._jvm.functions.randn(seed) + else: + jc = sc._jvm.functions.randn() + return Column(jc) + + +@since(1.5) +def round(col, scale=0): + """ + Round the value of `e` to `scale` decimal places if `scale` >= 0 + or at integral part when `scale` < 0. + + >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect() + [Row(r=2.5)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.round(_to_java_column(col), scale)) + + +@since(1.5) +def shiftLeft(col, numBits): + """Shift the the given value numBits left. + + >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() + [Row(r=42)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)) + + +@since(1.5) +def shiftRight(col, numBits): + """Shift the the given value numBits right. + + >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() + [Row(r=21)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.5) +def shiftRightUnsigned(col, numBits): + """Unsigned shift the the given value numBits right. + + >>> df = sqlContext.createDataFrame([(-42,)], ['a']) + >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() + [Row(r=9223372036854775787)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.4) +def sparkPartitionId(): + """A column for partition ID of the Spark task. + + Note that this is indeterministic because it depends on data partitioning and task scheduling. + + >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + [Row(pid=0), Row(pid=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.sparkPartitionId()) + + +@since(1.5) +def expr(str): + """Parses the expression string into the column that it represents + + >>> df.select(expr("length(name)")).collect() + [Row('length(name)=5), Row('length(name)=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.expr(str)) + + +@ignore_unicode_prefix +@since(1.4) +def struct(*cols): + """Creates a new struct column. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> df.select(struct('age', 'name').alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + >>> df.select(struct([df.age, df.name]).alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.5) +def greatest(*cols): + """ + Returns the greatest value of the list of column names, skipping null values. + This function takes at least 2 parameters. It will return null iff all parameters are null. + + >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() + [Row(greatest=4)] + """ + if len(cols) < 2: + raise ValueError("greatest should take at least two columns") + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def least(*cols): + """ + Returns the least value of the list of column names, skipping null values. + This function takes at least 2 parameters. It will return null iff all parameters are null. + + >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() + [Row(least=1)] + """ + if len(cols) < 2: + raise ValueError("least should take at least two columns") + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column))) + + +@since(1.4) +def when(condition, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() + [Row(age=3), Row(age=4)] + + >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect() + [Row(age=3), Row(age=None)] + """ + sc = SparkContext._active_spark_context + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = sc._jvm.functions.when(condition._jc, v) + return Column(jc) + + +@since(1.5) +def log(arg1, arg2=None): + """Returns the first argument-based logarithm of the second argument. + + If there is only one argument, then this takes the natural logarithm of the argument. + + >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + ['0.30102', '0.69897'] + + >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() + ['0.69314', '1.60943'] + """ + sc = SparkContext._active_spark_context + if arg2 is None: + jc = sc._jvm.functions.log(_to_java_column(arg1)) + else: + jc = sc._jvm.functions.log(arg1, _to_java_column(arg2)) + return Column(jc) + + +@since(1.5) +def log2(col): + """Returns the base-2 logarithm of the argument. + + >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() + [Row(log2=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.log2(_to_java_column(col))) + + +@since(1.5) +@ignore_unicode_prefix +def conv(col, fromBase, toBase): + """ + Convert a number in a string column from one base to another. + + >>> df = sqlContext.createDataFrame([("010101",)], ['n']) + >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() + [Row(hex=u'15')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase)) + + +@since(1.5) +def factorial(col): + """ + Computes the factorial of the given value. + + >>> df = sqlContext.createDataFrame([(5,)], ['n']) + >>> df.select(factorial(df.n).alias('f')).collect() + [Row(f=120)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.factorial(_to_java_column(col))) + + +# --------------- Window functions ------------------------ + +@since(1.4) +def lag(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows before the current row, and + `defaultValue` if there is less than `offset` rows before the current row. For example, + an `offset` of one will return the previous row at any given point in the window partition. + + This is equivalent to the LAG function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lag(_to_java_column(col), count, default)) + + +@since(1.4) +def lead(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows after the current row, and + `defaultValue` if there is less than `offset` rows after the current row. For example, + an `offset` of one will return the next row at any given point in the window partition. + + This is equivalent to the LEAD function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lead(_to_java_column(col), count, default)) + + +@since(1.4) +def ntile(n): + """ + Window function: returns the ntile group id (from 1 to `n` inclusive) + in an ordered window partition. For example, if `n` is 4, the first + quarter of the rows will get value 1, the second quarter will get 2, + the third quarter will get 3, and the last quarter will get 4. + + This is equivalent to the NTILE function in SQL. + + :param n: an integer + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.ntile(int(n))) + + +# ---------------------- Date/Timestamp functions ------------------------------ + +@since(1.5) +def current_date(): + """ + Returns the current date as a date column. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.current_date()) + + +def current_timestamp(): + """ + Returns the current timestamp as a timestamp column. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.current_timestamp()) + + +@ignore_unicode_prefix +@since(1.5) +def date_format(date, format): + """ + Converts a date/timestamp/string to a value of string in the format specified by the date + format given by the second argument. + + A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + pattern letters of the Java class `java.text.SimpleDateFormat` can be used. + + NOTE: Use when ever possible specialized functions like `year`. These benefit from a + specialized implementation. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() + [Row(date=u'04/08/2015')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_format(_to_java_column(date), format)) + + +@since(1.5) +def year(col): + """ + Extract the year of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(year('a').alias('year')).collect() + [Row(year=2015)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.year(_to_java_column(col))) + + +@since(1.5) +def quarter(col): + """ + Extract the quarter of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(quarter('a').alias('quarter')).collect() + [Row(quarter=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.quarter(_to_java_column(col))) + + +@since(1.5) +def month(col): + """ + Extract the month of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(month('a').alias('month')).collect() + [Row(month=4)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.month(_to_java_column(col))) + + +@since(1.5) +def dayofmonth(col): + """ + Extract the day of the month of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(dayofmonth('a').alias('day')).collect() + [Row(day=8)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.dayofmonth(_to_java_column(col))) + + +@since(1.5) +def dayofyear(col): + """ + Extract the day of the year of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(dayofyear('a').alias('day')).collect() + [Row(day=98)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.dayofyear(_to_java_column(col))) + + +@since(1.5) +def hour(col): + """ + Extract the hours of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df.select(hour('a').alias('hour')).collect() + [Row(hour=13)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.hour(_to_java_column(col))) + + +@since(1.5) +def minute(col): + """ + Extract the minutes of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df.select(minute('a').alias('minute')).collect() + [Row(minute=8)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.minute(_to_java_column(col))) + + +@since(1.5) +def second(col): + """ + Extract the seconds of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df.select(second('a').alias('second')).collect() + [Row(second=15)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.second(_to_java_column(col))) + + +@since(1.5) +def weekofyear(col): + """ + Extract the week number of a given date as integer. + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df.select(weekofyear(df.a).alias('week')).collect() + [Row(week=15)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.weekofyear(_to_java_column(col))) + + +@since(1.5) +def date_add(start, days): + """ + Returns the date that is `days` days after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_add(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 9))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) + + +@since(1.5) +def date_sub(start, days): + """ + Returns the date that is `days` days before `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_sub(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 7))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) + + +@since(1.5) +def datediff(end, start): + """ + Returns the number of days from `start` to `end`. + + >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() + [Row(diff=32)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start))) + + +@since(1.5) +def add_months(start, months): + """ + Returns the date that is `months` months after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(add_months(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 5, 8))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) + + +@since(1.5) +def months_between(date1, date2): + """ + Returns the number of months between date1 and date2. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) + >>> df.select(months_between(df.t, df.d).alias('months')).collect() + [Row(months=3.9495967...)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + + +@since(1.5) +def to_date(col): + """ + Converts the column of StringType or TimestampType into DateType. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_date(_to_java_column(col))) + + +@since(1.5) +def trunc(date, format): + """ + Returns date truncated to the unit specified by the format. + + :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + + >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df.select(trunc(df.d, 'year').alias('year')).collect() + [Row(year=datetime.date(1997, 1, 1))] + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + [Row(month=datetime.date(1997, 2, 1))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) + + +@since(1.5) +def next_day(date, dayOfWeek): + """ + Returns the first date which is later than the value of the date column. + + Day of the week parameter is case insensitive, and accepts: + "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + + >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d']) + >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() + [Row(date=datetime.date(2015, 8, 2))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek)) + + +@since(1.5) +def last_day(date): + """ + Returns the last day of the month which the given date belongs to. + + >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d']) + >>> df.select(last_day(df.d).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.last_day(_to_java_column(date))) + + +@since(1.5) +def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): + """ + Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + representing the timestamp of that moment in the current system time zone in the given + format. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) + + +@since(1.5) +def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): + """ + Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default) + to Unix time stamp (in seconds), using the default timezone and the default + locale, return null if fail. + + if `timestamp` is None, then it returns current timestamp. + """ + sc = SparkContext._active_spark_context + if timestamp is None: + return Column(sc._jvm.functions.unix_timestamp()) + return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format)) + + +@since(1.5) +def from_utc_timestamp(timestamp, tz): + """ + Assumes given timestamp is UTC and converts to given timezone. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() + [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) + + +@since(1.5) +def to_utc_timestamp(timestamp, tz): + """ + Assumes given timestamp is in given timezone and converts to UTC. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() + [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) + + +# ---------------------------- misc functions ---------------------------------- + +@since(1.5) +@ignore_unicode_prefix +def crc32(col): + """ + Calculates the cyclic redundancy check value (CRC32) of a binary column and + returns the value as a bigint. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() + [Row(crc32=2743272264)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.crc32(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def md5(col): + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.md5(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + +# ---------------------- String/Binary functions ------------------------------ + +_string_functions = { + 'ascii': 'Computes the numeric value of the first character of the string column.', + 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.', + 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.', + 'initcap': 'Returns a new string column by converting the first letter of each word to ' + + 'uppercase. Words are delimited by whitespace.', + 'lower': 'Converts a string column to lower case.', + 'upper': 'Converts a string column to upper case.', + 'reverse': 'Reverses the string column and returns it as a new string column.', + 'ltrim': 'Trim the spaces from right end for the specified string value.', + 'rtrim': 'Trim the spaces from right end for the specified string value.', + 'trim': 'Trim the spaces from both ends for the specified string column.', +} + + +for _name, _doc in _string_functions.items(): + globals()[_name] = since(1.5)(_create_function(_name, _doc)) +del _name, _doc + + +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input string columns together into a single string column. + + >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +@ignore_unicode_prefix +def concat_ws(sep, *cols): + """ + Concatenates multiple input string columns together into a single string column, + using the given separator. + + >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + [Row(s=u'abcd-123')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def decode(col, charset): + """ + Computes the first argument into a string from a binary using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.decode(_to_java_column(col), charset)) + + +@since(1.5) +def encode(col, charset): + """ + Computes the first argument into a binary from a string using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.encode(_to_java_column(col), charset)) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """ + Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places, + and returns the result as a string. + + :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) + + +@ignore_unicode_prefix +@since(1.5) +def format_string(format, *cols): + """ + Formats the arguments in printf-style and returns the result as a string column. + + :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + + >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b']) + >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect() + [Row(v=u'5 hello')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column))) + + +@since(1.5) +def instr(str, substr): + """ + Locate the position of the first occurrence of substr column in the given string. + Returns null if either of the arguments are null. + + NOTE: The position is not zero based, but 1 based index, returns 0 if substr + could not be found in str. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(instr(df.s, 'b').alias('s')).collect() + [Row(s=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.instr(_to_java_column(str), substr)) + + +@since(1.5) +@ignore_unicode_prefix +def substring(str, pos, len): + """ + Substring starts at `pos` and is of length `len` when str is String type or + returns the slice of byte array that starts at `pos` in byte and is of length `len` + when str is Binary type + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(substring(df.s, 1, 2).alias('s')).collect() + [Row(s=u'ab')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len)) + + +@since(1.5) +@ignore_unicode_prefix +def substring_index(str, delim, count): + """ + Returns the substring from string str before count occurrences of the delimiter delim. + If count is positive, everything the left of the final delimiter (counting from left) is + returned. If count is negative, every to the right of the final delimiter (counting from the + right) is returned. substring_index performs a case-sensitive match when searching for delim. + + >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s']) + >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() + [Row(s=u'a.b')] + >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() + [Row(s=u'b.c.d')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count)) + + +@ignore_unicode_prefix +@since(1.5) +def levenshtein(left, right): + """Computes the Levenshtein distance of the two given strings. + + >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + [Row(d=3)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) + return Column(jc) + + +@since(1.5) +def locate(substr, str, pos=0): + """ + Locate the position of the first occurrence of substr in a string column, after position pos. + + NOTE: The position is not zero based, but 1 based index. returns 0 if substr + could not be found in str. + + :param substr: a string + :param str: a Column of StringType + :param pos: start position (zero based) + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(locate('b', df.s, 1).alias('s')).collect() + [Row(s=2)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos)) + + +@since(1.5) +@ignore_unicode_prefix +def lpad(col, len, pad): + """ + Left-pad the string column to width `len` with `pad`. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() + [Row(s=u'##abcd')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad)) + + +@since(1.5) +@ignore_unicode_prefix +def rpad(col, len, pad): + """ + Right-pad the string column to width `len` with `pad`. + + >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() + [Row(s=u'abcd##')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad)) + + +@since(1.5) +@ignore_unicode_prefix +def repeat(col, n): + """ + Repeats a string column n times, and returns it as a new string column. + + >>> df = sqlContext.createDataFrame([('ab',)], ['s',]) + >>> df.select(repeat(df.s, 3).alias('s')).collect() + [Row(s=u'ababab')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.repeat(_to_java_column(col), n)) + + +@since(1.5) +@ignore_unicode_prefix +def split(str, pattern): + """ + Splits str around pattern (pattern is a regular expression). + + NOTE: pattern is a string represent the regular expression. + + >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',]) + >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() + [Row(s=[u'ab', u'cd'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.split(_to_java_column(str), pattern)) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_extract(str, pattern, idx): + """Extract a specific(idx) group identified by a java regex, from the specified string column. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + [Row(d=u'100')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_replace(str, pattern, replacement): + """Replace all substrings of the specified string value that match regexp with rep. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect() + [Row(d=u'-----')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def initcap(col): + """Translate the first letter of each word to upper case in the sentence. + + >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() + [Row(v=u'Ab Cd')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.initcap(_to_java_column(col))) + + +@since(1.5) +@ignore_unicode_prefix +def soundex(col): + """ + Returns the SoundEx encoding for a string + + >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name']) + >>> df.select(soundex(df.name).alias("soundex")).collect() + [Row(soundex=u'P362'), Row(soundex=u'U612')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.soundex(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def bin(col): + """Returns the string representation of the binary value of the given column. + + >>> df.select(bin(df.age).alias('c')).collect() + [Row(c=u'10'), Row(c=u'101')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.bin(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. + + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unhex(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def length(col): + """Calculates the length of a string or binary expression. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() + [Row(length=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.length(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def translate(srcCol, matching, replace): + """A function translate any character in the `srcCol` by a character in `matching`. + The characters in `replace` is corresponding to the characters in `matching`. + The translate will happen when any character in the string matching with the character + in the `matching`. + + >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\ + .alias('r')).collect() + [Row(r=u'1a2s3ae')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace)) + + +# ---------------------- Collection functions ------------------------------ + +@since(1.4) +def array(*cols): + """Creates a new array column. + + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. + + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.5) +def array_contains(col, value): + """ + Collection function: returns True if the array contains the given value. The collection + elements and value must be of the same type. + + :param col: name of column containing array + :param value: value to check for in array + + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) + + +@since(1.4) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + +@since(1.5) +def size(col): + """ + Collection function: returns the length of the array or map stored in the column. + + :param col: name of column or expression + + >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) + >>> df.select(size(df.data)).collect() + [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.size(_to_java_column(col))) + + +@since(1.5) +def sort_array(col, asc=True): + """ + Collection function: sorts the input array for the given column in ascending order. + + :param col: name of column or expression + + >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df.select(sort_array(df.data).alias('r')).collect() + [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() + [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) + + +# ---------------------------- User Defined Function ---------------------------------- + +class UserDefinedFunction(object): + """ + User defined function in Python + + .. versionadded:: 1.3 + """ + def __init__(self, func, returnType, name=None): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf(name) + + def _create_judf(self, name): + f, returnType = self.func, self.returnType # put them in closure `func` + func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + if name is None: + name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, + sc.pythonExec, sc.pythonVer, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jc = self._judf.apply(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.3) +def udf(f, returnType=StringType()): + """Creates a :class:`Column` expression representing a user defined function (UDF). + + >>> from pyspark.sql.types import IntegerType + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + +blacklist = ['map', 'since', 'ignore_unicode_prefix'] +__all__ = [k for k, v in globals().items() + if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] +__all__.sort() + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.functions + globs = pyspark.sql.functions.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + (failure_count, test_count) = doctest.testmod( + pyspark.sql.functions, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py new file mode 100644 index 000000000000..04594d5a836c --- /dev/null +++ b/python/pyspark/sql/group.py @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql import since +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import * + +__all__ = ["GroupedData"] + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + name = f.__name__ + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + @ignore_unicode_prefix + @since(1.3) + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jdf = self._jdf.agg(exprs[0]) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + @since(1.3) + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + @since(1.3) + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(avg(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(avg(age)=3.5, avg(height)=82.5)] + """ + + @df_varargs_api + @since(1.3) + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(avg(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(avg(age)=3.5, avg(height)=82.5)] + """ + + @df_varargs_api + @since(1.3) + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(max(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(max(age)=5, max(height)=85)] + """ + + @df_varargs_api + @since(1.3) + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(min(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(min(age)=2, min(height)=80)] + """ + + @df_varargs_api + @since(1.3) + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().sum('age').collect() + [Row(sum(age)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(sum(age)=7, sum(height)=165)] + """ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.group + globs = pyspark.sql.group.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.group, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py new file mode 100644 index 000000000000..78247c8fa737 --- /dev/null +++ b/python/pyspark/sql/readwriter.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import JavaClass + +from pyspark.sql import since +from pyspark.sql.column import _to_seq +from pyspark.sql.types import * + +__all__ = ["DataFrameReader", "DataFrameWriter"] + + +def to_str(value): + """ + A wrapper over str(), but convert bool values to lower case string + """ + if isinstance(value, bool): + return str(value).lower() + else: + return str(value) + + +class DataFrameReader(object): + """ + Interface used to load a :class:`DataFrame` from external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read` + to access this. + + ::Note: Experimental + + .. versionadded:: 1.4 + """ + + def __init__(self, sqlContext): + self._jreader = sqlContext._ssql_ctx.read() + self._sqlContext = sqlContext + + def _df(self, jdf): + from pyspark.sql.dataframe import DataFrame + return DataFrame(jdf, self._sqlContext) + + @since(1.4) + def format(self, source): + """Specifies the input data source format. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + self._jreader = self._jreader.format(source) + return self + + @since(1.4) + def schema(self, schema): + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(1.5) + def option(self, key, value): + """Adds an input option for the underlying data source. + """ + self._jreader = self._jreader.option(key, to_str(value)) + return self + + @since(1.4) + def options(self, **options): + """Adds input options for the underlying data source. + """ + for k in options: + self._jreader = self._jreader.option(k, to_str(options[k])) + return self + + @since(1.4) + def load(self, path=None, format=None, schema=None, **options): + """Loads data from a data source and returns it as a :class`DataFrame`. + + :param path: optional string for file-system backed data sources. + :param format: optional string for format of the data source. Default to 'parquet'. + :param schema: optional :class:`StructType` for the input schema. + :param options: all other string options + + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True, + ... opt2=1, opt3='str') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + return self._df(self._jreader.load(path)) + else: + return self._df(self._jreader.load()) + + @since(1.4) + def json(self, path, schema=None): + """ + Loads a JSON file (one object per line) and returns the result as + a :class`DataFrame`. + + If the ``schema`` parameter is not specified, this function goes + through the input once to determine the input schema. + + :param path: string, path to the JSON dataset. + :param schema: an optional :class:`StructType` for the input schema. + + >>> df = sqlContext.read.json('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + if schema is not None: + self.schema(schema) + return self._df(self._jreader.json(path)) + + @since(1.4) + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + :param tableName: string, name of the table. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.registerTempTable('tmpTable') + >>> sqlContext.read.table('tmpTable').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + return self._df(self._jreader.table(tableName)) + + @since(1.4) + def parquet(self, *paths): + """Loads a Parquet file, returning the result as a :class:`DataFrame`. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths))) + + @since(1.5) + def orc(self, path): + """ + Loads an ORC file, returning the result as a :class:`DataFrame`. + + ::Note: Currently ORC support is only available together with + :class:`HiveContext`. + + >>> df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') + >>> df.dtypes + [('a', 'bigint'), ('b', 'int'), ('c', 'int')] + """ + return self._df(self._jreader.orc(path)) + + @since(1.4) + def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, + predicates=None, properties=None): + """ + Construct a :class:`DataFrame` representing the database table accessible + via JDBC URL `url` named `table` and connection `properties`. + + The `column` parameter could be used to partition the table, then it will + be retrieved in parallel based on the parameters passed to this function. + + The `predicates` parameter gives a list expressions suitable for inclusion + in WHERE clauses; each one defines one partition of the :class:`DataFrame`. + + ::Note: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. + + :param url: a JDBC URL + :param table: name of table + :param column: the column used to partition + :param lowerBound: the lower bound of partition column + :param upperBound: the upper bound of the partition column + :param numPartitions: the number of partitions + :param predicates: a list of expressions + :param properties: JDBC database connection arguments, a list of arbitrary string + tag/value. Normally at least a "user" and "password" property + should be included. + :return: a DataFrame + """ + if properties is None: + properties = dict() + jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + for k in properties: + jprop.setProperty(k, properties[k]) + if column is not None: + if numPartitions is None: + numPartitions = self._sqlContext._sc.defaultParallelism + return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), + int(numPartitions), jprop)) + if predicates is not None: + arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) + return self._df(self._jreader.jdbc(url, table, arr, jprop)) + return self._df(self._jreader.jdbc(url, table, jprop)) + + +class DataFrameWriter(object): + """ + Interface used to write a [[DataFrame]] to external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write` + to access this. + + ::Note: Experimental + + .. versionadded:: 1.4 + """ + def __init__(self, df): + self._df = df + self._sqlContext = df.sql_ctx + self._jwrite = df._jdf.write() + + @since(1.4) + def mode(self, saveMode): + """Specifies the behavior when data or table already exists. + + Options include: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + # At the JVM side, the default value of mode is already set to "error". + # So, if the given saveMode is None, we will not call JVM-side's mode method. + if saveMode is not None: + self._jwrite = self._jwrite.mode(saveMode) + return self + + @since(1.4) + def format(self, source): + """Specifies the underlying output data source. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(1.5) + def option(self, key, value): + """Adds an output option for the underlying data source. + """ + self._jwrite = self._jwrite.option(key, value) + return self + + @since(1.4) + def options(self, **options): + """Adds output options for the underlying data source. + """ + for k in options: + self._jwrite = self._jwrite.option(k, options[k]) + return self + + @since(1.4) + def partitionBy(self, *cols): + """Partitions the output by the given columns on the file system. + + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + :param cols: name of columns + + >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + return self + + @since(1.4) + def save(self, path=None, format=None, mode=None, partitionBy=None, **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + :param path: the path in a Hadoop supported file system + :param format: the format used to save + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + :param options: all other string options + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + if path is None: + self._jwrite.save() + else: + self._jwrite.save(path) + + @since(1.4) + def insertInto(self, tableName, overwrite=False): + """Inserts the content of the :class:`DataFrame` to the specified table. + + It requires that the schema of the class:`DataFrame` is the same as the + schema of the table. + + Optionally overwriting any existing data. + """ + self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) + + @since(1.4) + def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options): + """Saves the content of the :class:`DataFrame` as the specified table. + + In the case the table already exists, behavior of this function depends on the + save mode, specified by the `mode` function (default to throwing an exception). + When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + the same as that of the existing table. + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + + :param name: the table name + :param format: the format used to save + :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param partitionBy: names of partitioning columns + :param options: all other string options + """ + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + self._jwrite.saveAsTable(name) + + @since(1.4) + def json(self, path, mode=None): + """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode)._jwrite.json(path) + + @since(1.4) + def parquet(self, path, mode=None, partitionBy=None): + """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + + >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode) + if partitionBy is not None: + self.partitionBy(partitionBy) + self._jwrite.parquet(path) + + def orc(self, path, mode=None, partitionBy=None): + """Saves the content of the :class:`DataFrame` in ORC format at the specified path. + + ::Note: Currently ORC support is only available together with + :class:`HiveContext`. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + + >>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') + >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode) + if partitionBy is not None: + self.partitionBy(partitionBy) + self._jwrite.orc(path) + + @since(1.4) + def jdbc(self, url, table, mode=None, properties=None): + """Saves the content of the :class:`DataFrame` to a external database table via JDBC. + + .. note:: Don't create too many partitions in parallel on a large cluster;\ + otherwise Spark might crash your external database systems. + + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` + :param table: Name of the table in the external database. + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param properties: JDBC database connection arguments, a list of + arbitrary string tag/value. Normally at least a + "user" and "password" property should be included. + """ + if properties is None: + properties = dict() + jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + for k in properties: + jprop.setProperty(k, properties[k]) + self._jwrite.mode(mode).jdbc(url, table, jprop) + + +def _test(): + import doctest + import os + import tempfile + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext, HiveContext + import pyspark.sql.readwriter + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.sql.readwriter.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + + globs['tempfile'] = tempfile + globs['os'] = os + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['hiveContext'] = HiveContext(sc) + globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.readwriter, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py new file mode 100644 index 000000000000..aacfb34c7761 --- /dev/null +++ b/python/pyspark/sql/tests.py @@ -0,0 +1,1157 @@ +# -*- encoding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile +import pickle +import functools +import time +import datetime + +import py4j + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql.types import * +from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction, sha2 +from pyspark.sql.window import Window +from pyspark.sql.utils import AnalysisException, IllegalArgumentException + + +class UTC(datetime.tzinfo): + """UTC""" + ZERO = datetime.timedelta(0) + + def utcoffset(self, dt): + return self.ZERO + + def tzname(self, dt): + return "UTC" + + def dst(self, dt): + return self.ZERO + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.sql.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, self.__class__) and \ + other.x == self.x and other.y == self.y + + +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEquals(lt, lt2) + + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData, 2) + cls.df = rdd.toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_row_should_be_read_only(self): + row = Row(a=1, b=2) + self.assertEqual(1, row.a) + + def foo(): + row.a = 3 + self.assertRaises(Exception, foo) + + row2 = self.sqlCtx.range(10).first() + self.assertEqual(0, row2.id) + + def foo2(): + row2.id = 2 + self.assertRaises(Exception, foo2) + + def test_range(self): + self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) + self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) + self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.sqlCtx.range(-2).count(), 0) + self.assertEqual(self.sqlCtx.range(3).count(), 3) + + def test_duplicated_column_names(self): + df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"]) + row = df.select('*').first() + self.assertEqual(1, row[0]) + self.assertEqual(2, row[1]) + self.assertEqual("Row(c=1, c=2)", str(row)) + # Cannot access columns + self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) + self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) + self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) + + def test_explode(self): + from pyspark.sql.functions import explode + d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=list(range(3)), d={"key": list(range(5))})] + rdd = self.sc.parallelize(d) + self.sqlCtx.createDataFrame(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(list(range(3)), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}, s=None), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.sqlCtx.inferSchema(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.sqlCtx.inferSchema(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.sqlCtx.inferSchema(rdd) + self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + + def test_select_null_literal(self): + df = self.sqlCtx.sql("select null as col") + self.assertEquals(Row(col=None), df.first()) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2", ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.sqlCtx.createDataFrame(rdd, schema) + results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + df.registerTempTable("table2") + r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3])]) + abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" + schema = _parse_schema_abstract(abstract) + typedSchema = _infer_schema_type(rdd.first(), schema) + df = self.sqlCtx.createDataFrame(rdd, typedSchema) + r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) + self.assertEqual(r, tuple(df.first())) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + df = self.sc.parallelize(d).toDF() + k, v = list(df.head().m.items())[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + df = self.sc.parallelize([row]).toDF() + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + + def test_udf_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) + self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + + def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.write.parquet(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_freqItems(self): + vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] + df = self.sc.parallelize(vals).toDF() + items = df.stat.freqItems(("a", "b"), 0.4).collect()[0] + self.assertTrue(1 in items[0]) + self.assertTrue(-2.0 in items[1]) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + + def test_corr(self): + import math + df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() + corr = df.stat.corr("a", "b") + self.assertTrue(abs(corr - 0.95734012) < 1e-6) + + def test_cov(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + cov = df.stat.cov("a", "b") + self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + + def test_crosstab(self): + df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() + ct = df.stat.crosstab("a", "b").collect() + ct = sorted(ct, key=lambda x: x[0]) + for i, row in enumerate(ct): + self.assertEqual(row[0], str(i)) + self.assertTrue(row[1], 1) + self.assertTrue(row[2], 1) + + def test_math_functions(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + from pyspark.sql import functions + import math + + def get_values(l): + return [j[0] for j in l] + + def assert_close(a, b): + c = get_values(b) + diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)] + return sum(diff) == len(a) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos(df.a)).collect()) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos("a")).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df.a)).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df['a'])).collect()) + assert_close([math.pow(i, 2 * i) for i in range(10)], + df.select(functions.pow(df.a, df.b)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2.0)).collect()) + assert_close([math.hypot(i, 2 * i) for i in range(10)], + df.select(functions.hypot(df.a, df.b)).collect()) + + def test_rand_functions(self): + df = self.df + from pyspark.sql import functions + rnd = df.select('key', functions.rand()).collect() + for row in rnd: + assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1] + rndn = df.select('key', functions.randn(5)).collect() + for row in rndn: + assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + + # If the specified seed is 0, we should use it. + # https://issues.apache.org/jira/browse/SPARK-9691 + rnd1 = df.select('key', functions.rand(0)).collect() + rnd2 = df.select('key', functions.rand(0)).collect() + self.assertEqual(sorted(rnd1), sorted(rnd2)) + + rndn1 = df.select('key', functions.randn(0)).collect() + rndn2 = df.select('key', functions.randn(0)).collect() + self.assertEqual(sorted(rndn1), sorted(rndn2)) + + def test_between_function(self): + df = self.sc.parallelize([ + Row(a=1, b=2, c=3), + Row(a=2, b=1, c=3), + Row(a=4, b=1, c=4)]).toDF() + self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], + df.filter(df.a.between(df.b, df.c)).collect()) + + def test_struct_type(self): + from pyspark.sql.types import StructType, StringType, StructField + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + try: + struct1 = StructType().add("name") + self.assertEqual(1, 0) + except ValueError: + self.assertEqual(1, 1) + + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + + df.write.json(tmpPath, "overwrite") + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.save(format="json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.read.load(format="json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_save_and_load_builder(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + + df.write.mode("overwrite").json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ + .option("noUse", "this option will not be used in save.")\ + .format("json").save(path=tmpPath) + actual =\ + self.sqlCtx.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_column_name_with_non_ascii(self): + df = self.sqlCtx.createDataFrame([(1,)], ["数量"]) + self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema) + self.assertEqual("DataFrame[数量: bigint]", str(df)) + self.assertEqual([("数量", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("数量").first()[0]) + self.assertEqual(1, df.select(df["数量"]).first()[0]) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + + def test_field_accessor(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) + self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + df = self.sc.parallelize(longrow).toDF() + self.assertEqual(df.schema.fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. + output_dir = os.path.join(self.tempdir.name, "infer_long_type") + df.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + self.assertEquals('a', df1.first().f1) + self.assertEquals(100000000000000, df1.first().f2) + + self.assertEqual(_infer_type(1), LongType()) + self.assertEqual(_infer_type(2**10), LongType()) + self.assertEqual(_infer_type(2**20), LongType()) + self.assertEqual(_infer_type(2**31 - 1), LongType()) + self.assertEqual(_infer_type(2**31), LongType()) + self.assertEqual(_infer_type(2**61), LongType()) + self.assertEqual(_infer_type(2**71), LongType()) + + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + + def test_time_with_timezone(self): + day = datetime.date.today() + now = datetime.datetime.now() + ts = time.mktime(now.timetuple()) + # class in __main__ is not serializable + from pyspark.sql.tests import UTC + utc = UTC() + utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds + # add microseconds to utcnow (keeping year,month,day,hour,minute,second) + utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) + df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) + day1, now1, utcnow1 = df.first() + self.assertEqual(day1, day) + self.assertEqual(now, now1) + self.assertEqual(now, utcnow1) + + def test_decimal(self): + from decimal import Decimal + schema = StructType([StructField("decimal", DecimalType(10, 5))]) + df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema) + row = df.select(df.decimal + 1).first() + self.assertEqual(row[0], Decimal("4.14159")) + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.parquet(tmpPath) + df2 = self.sqlCtx.read.parquet(tmpPath) + row = df2.first() + self.assertEqual(row[0], Decimal("3.14159")) + + def test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # fillna shouldn't change non-null values + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.1) + + # fillna with string + row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + + def test_bitwise_operations(self): + from pyspark.sql import functions + row = Row(a=170, b=75) + df = self.sqlCtx.createDataFrame([row]) + result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() + self.assertEqual(170 & 75, result['(a & b)']) + result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() + self.assertEqual(170 | 75, result['(a | b)']) + result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict() + self.assertEqual(170 ^ 75, result['(a ^ b)']) + result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() + self.assertEqual(~75, result['~b']) + + def test_expr(self): + from pyspark.sql import functions + row = Row(a="length string", b=75) + df = self.sqlCtx.createDataFrame([row]) + result = df.select(functions.expr("length(a)")).collect()[0].asDict() + self.assertEqual(13, result["'length(a)"]) + + def test_replace(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # replace with int + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() + self.assertEqual(row.age, 20) + self.assertEqual(row.height, 20.0) + + # replace with double + row = self.sqlCtx.createDataFrame( + [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first() + self.assertEqual(row.age, 82) + self.assertEqual(row.height, 82.1) + + # replace with string + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first() + self.assertEqual(row.name, u"Ann") + self.assertEqual(row.age, 10) + + # replace with subset specified by a string of a column name w/ actual change + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first() + self.assertEqual(row.age, 20) + + # replace with subset specified by a string of a column name w/o actual change + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first() + self.assertEqual(row.age, 10) + + # replace with subset specified with one column replaced, another column not in subset + # stays unchanged. + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first() + self.assertEqual(row.name, u'Alice') + self.assertEqual(row.age, 20) + self.assertEqual(row.height, 10.0) + + # replace with subset specified but no column will be replaced + row = self.sqlCtx.createDataFrame( + [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first() + self.assertEqual(row.name, u'Alice') + self.assertEqual(row.age, 10) + self.assertEqual(row.height, None) + + def test_capture_analysis_exception(self): + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) + # RuntimeException should not be captured + self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + + def test_capture_illegalargument_exception(self): + self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", + lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1")) + df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"]) + self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", + lambda: df.select(sha2(df.a, 1024)).collect()) + + def test_with_column_with_existing_name(self): + keys = self.df.withColumn("key", self.df.key).select("key").collect() + self.assertEqual([r.key for r in keys], list(range(100))) + + +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") + except TypeError: + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") + os.unlink(cls.tempdir.name) + _scala_HiveContext =\ + cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) + cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json") + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_window_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.partitionBy("value").orderBy("key") + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 1, 1, 1, 1, 1), + ("2", 1, 1, 1, 3, 1, 1, 1, 1), + ("2", 1, 2, 1, 3, 2, 1, 1, 1), + ("2", 2, 2, 2, 3, 3, 3, 2, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_without_partitionBy(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py new file mode 100644 index 000000000000..ed4e5b594bd6 --- /dev/null +++ b/python/pyspark/sql/types.py @@ -0,0 +1,1310 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import decimal +import time +import datetime +import calendar +import json +import re +import base64 +from array import array + +if sys.version >= "3": + long = int + unicode = str + +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass + +from pyspark.serializers import CloudPickleSerializer + +__all__ = [ + "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", + "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] + + +class DataType(object): + """Base class for data types.""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def simpleString(self): + return self.typeName() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + + def needConversion(self): + """ + Does this type need to conversion between Python object and internal SQL object. + + This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. + """ + return False + + def toInternal(self, obj): + """ + Converts a Python object into an internal SQL object. + """ + return obj + + def fromInternal(self, obj): + """ + Converts an internal SQL object into a native Python object. + """ + return obj + + +# This singleton pattern does not work with pickle, you will get +# another object after pickle and unpickle +class DataTypeSingleton(type): + """Metaclass for DataType""" + + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(DataTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class NullType(DataType): + """Null type. + + The data type representing None, used for the types that cannot be inferred. + """ + + __metaclass__ = DataTypeSingleton + + +class AtomicType(DataType): + """An internal type used to represent everything that is not + null, UDTs, arrays, structs, and maps.""" + + +class NumericType(AtomicType): + """Numeric data types. + """ + + +class IntegralType(NumericType): + """Integral data types. + """ + + __metaclass__ = DataTypeSingleton + + +class FractionalType(NumericType): + """Fractional data types. + """ + + +class StringType(AtomicType): + """String data type. + """ + + __metaclass__ = DataTypeSingleton + + +class BinaryType(AtomicType): + """Binary (byte array) data type. + """ + + __metaclass__ = DataTypeSingleton + + +class BooleanType(AtomicType): + """Boolean data type. + """ + + __metaclass__ = DataTypeSingleton + + +class DateType(AtomicType): + """Date (datetime.date) data type. + """ + + __metaclass__ = DataTypeSingleton + + EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def needConversion(self): + return True + + def toInternal(self, d): + return d and d.toordinal() - self.EPOCH_ORDINAL + + def fromInternal(self, v): + return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + + +class TimestampType(AtomicType): + """Timestamp (datetime.datetime) data type. + """ + + __metaclass__ = DataTypeSingleton + + def needConversion(self): + return True + + def toInternal(self, dt): + if dt is not None: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e6 + dt.microsecond) + + def fromInternal(self, ts): + if ts is not None: + # using int to avoid precision loss in float + return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) + + +class DecimalType(FractionalType): + """Decimal (decimal.Decimal) data type. + + The DecimalType must have fixed precision (the maximum total number of digits) + and scale (the number of digits on the right of dot). For example, (5, 2) can + support the value from [-999.99 to 999.99]. + + The precision can be up to 38, the scale must less or equal to precision. + + When create a DecimalType, the default precision and scale is (10, 0). When infer + schema from decimal.Decimal objects, it will be DecimalType(38, 18). + + :param precision: the maximum total number of digits (default: 10) + :param scale: the number of digits on right side of dot. (default: 0) + """ + + def __init__(self, precision=10, scale=0): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = True # this is public API + + def simpleString(self): + return "decimal(%d,%d)" % (self.precision, self.scale) + + def jsonValue(self): + return "decimal(%d,%d)" % (self.precision, self.scale) + + def __repr__(self): + return "DecimalType(%d,%d)" % (self.precision, self.scale) + + +class DoubleType(FractionalType): + """Double data type, representing double precision floats. + """ + + __metaclass__ = DataTypeSingleton + + +class FloatType(FractionalType): + """Float data type, representing single precision floats. + """ + + __metaclass__ = DataTypeSingleton + + +class ByteType(IntegralType): + """Byte data type, i.e. a signed integer in a single byte. + """ + def simpleString(self): + return 'tinyint' + + +class IntegerType(IntegralType): + """Int data type, i.e. a signed 32-bit integer. + """ + def simpleString(self): + return 'int' + + +class LongType(IntegralType): + """Long data type, i.e. a signed 64-bit integer. + + If the values are beyond the range of [-9223372036854775808, 9223372036854775807], + please use :class:`DecimalType`. + """ + def simpleString(self): + return 'bigint' + + +class ShortType(IntegralType): + """Short data type, i.e. a signed 16-bit integer. + """ + def simpleString(self): + return 'smallint' + + +class ArrayType(DataType): + """Array data type. + + :param elementType: :class:`DataType` of each element in the array. + :param containsNull: boolean, whether the array can contain null (None) values. + """ + + def __init__(self, elementType, containsNull=True): + """ + >>> ArrayType(StringType()) == ArrayType(StringType(), True) + True + >>> ArrayType(StringType(), False) == ArrayType(StringType()) + False + """ + assert isinstance(elementType, DataType), "elementType should be DataType" + self.elementType = elementType + self.containsNull = containsNull + + def simpleString(self): + return 'array<%s>' % self.elementType.simpleString() + + def __repr__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + + def needConversion(self): + return self.elementType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.toInternal(v) for v in obj] + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.fromInternal(v) for v in obj] + + +class MapType(DataType): + """Map data type. + + :param keyType: :class:`DataType` of the keys in the map. + :param valueType: :class:`DataType` of the values in the map. + :param valueContainsNull: indicates whether values can contain null (None) values. + + Keys in a map data type are not allowed to be null (None). + """ + + def __init__(self, keyType, valueType, valueContainsNull=True): + """ + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) + True + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) + False + """ + assert isinstance(keyType, DataType), "keyType should be DataType" + assert isinstance(valueType, DataType), "valueType should be DataType" + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def simpleString(self): + return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString()) + + def __repr__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + + def needConversion(self): + return self.keyType.needConversion() or self.valueType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) + for k, v in obj.items()) + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) + for k, v in obj.items()) + + +class StructField(DataType): + """A field in :class:`StructType`. + + :param name: string, name of the field. + :param dataType: :class:`DataType` of the field. + :param nullable: boolean, whether the field can be null (None) or not. + :param metadata: a dict from string to simple type that can be toInternald to JSON automatically + """ + + def __init__(self, name, dataType, nullable=True, metadata=None): + """ + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) + True + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) + False + """ + assert isinstance(dataType, DataType), "dataType should be DataType" + if not isinstance(name, str): + name = name.encode('utf-8') + self.name = name + self.dataType = dataType + self.nullable = nullable + self.metadata = metadata or {} + + def simpleString(self): + return '%s:%s' % (self.name, self.dataType.simpleString()) + + def __repr__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) + + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"], + json["metadata"]) + + def needConversion(self): + return self.dataType.needConversion() + + def toInternal(self, obj): + return self.dataType.toInternal(obj) + + def fromInternal(self, obj): + return self.dataType.fromInternal(obj) + + +class StructType(DataType): + """Struct type, consisting of a list of :class:`StructField`. + + This is the data type representing a :class:`Row`. + """ + def __init__(self, fields=None): + """ + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) + >>> struct1 == struct2 + False + """ + if not fields: + self.fields = [] + self.names = [] + else: + self.fields = fields + self.names = [f.name for f in fields] + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a + DataType object. + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + self.names.append(field.name) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + return self + + def simpleString(self): + return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + + def __repr__(self): + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) + + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} + + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) + + def needConversion(self): + # We need convert Row()/namedtuple into tuple() + return True + + def toInternal(self, obj): + if obj is None: + return + + if self._needSerializeAnyField: + if isinstance(obj, dict): + return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) + elif isinstance(obj, (tuple, list)): + return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + else: + if isinstance(obj, dict): + return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, (list, tuple)): + return tuple(obj) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + + def fromInternal(self, obj): + if obj is None: + return + if isinstance(obj, Row): + # it's already converted by pickler + return obj + if self._needSerializeAnyField: + values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + else: + values = obj + return _create_row(self.names, values) + + +class UserDefinedType(DataType): + """User-defined type (UDT). + + .. note:: WARN: Spark Internal Use Only + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT (could be '', if there + is no corresponding one). + """ + return '' + + def needConversion(self): + return True + + @classmethod + def _cachedSqlType(cls): + """ + Cache the sqlType() into class, because it's heavy used in `toInternal`. + """ + if not hasattr(cls, "_cached_sql_type"): + cls._cached_sql_type = cls.sqlType() + return cls._cached_sql_type + + def toInternal(self, obj): + return self._cachedSqlType().toInternal(self.serialize(obj)) + + def fromInternal(self, obj): + return self.deserialize(self._cachedSqlType().fromInternal(obj)) + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement toInternal().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement fromInternal().") + + def simpleString(self): + return 'udt' + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + if self.scalaUDT(): + assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + else: + ser = CloudPickleSerializer() + b = ser.dumps(type(self)) + schema = { + "type": "udt", + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "serializedClass": base64.b64encode(b).decode('utf8'), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = str(json["pyClass"]) # convert unicode to str + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass]) + if not hasattr(m, pyClass): + s = base64.b64decode(json['serializedClass'].encode('utf-8')) + UDT = CloudPickleSerializer().loads(s) + else: + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + +_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType] +_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) + + +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. + >>> import pickle + >>> def check_datatype(datatype): + ... pickled = pickle.loads(pickle.dumps(datatype)) + ... assert datatype == pickled + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) + ... assert datatype == python_datatype + >>> for cls in _all_atomic_types.values(): + ... check_datatype(cls()) + + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) + >>> check_datatype(complex_structtype) + + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) + >>> check_datatype(complex_maptype) + """ + return _parse_datatype_json_value(json.loads(json_string)) + + +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + +def _parse_datatype_json_value(json_value): + if not isinstance(json_value, dict): + if json_value in _all_atomic_types.keys(): + return _all_atomic_types[json_value]() + elif json_value == 'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) + else: + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) + + +# Mapping Python types to Spark SQL DataType +_type_mappings = { + type(None): NullType, + bool: BooleanType, + int: LongType, + float: DoubleType, + str: StringType, + bytearray: BinaryType, + decimal.Decimal: DecimalType, + datetime.date: DateType, + datetime.datetime: TimestampType, + datetime.time: TimestampType, +} + +if sys.version < "3": + _type_mappings.update({ + unicode: StringType, + long: LongType, + }) + + +def _infer_type(obj): + """Infer the DataType from obj + """ + if obj is None: + return NullType() + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + + dataType = _type_mappings.get(type(obj)) + if dataType is DecimalType: + # the precision and scale of `obj` may be different from row to row. + return DecimalType(38, 18) + elif dataType is not None: + return dataType() + + if isinstance(obj, dict): + for key, value in obj.items(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) + elif isinstance(obj, (list, array)): + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) + else: + try: + return _infer_schema(obj) + except TypeError: + raise TypeError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, (tuple, list)): + if hasattr(row, "__fields__"): # Row + items = zip(row.__fields__, tuple(row)) + elif hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + else: + names = ['_%d' % i for i in range(1, len(row) + 1)] + items = zip(names, row) + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise TypeError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + +def _create_converter(dataType): + """Create an converter to drop the names of fields in obj """ + if not _need_converter(dataType): + return lambda x: x + + if isinstance(dataType, ArrayType): + conv = _create_converter(dataType.elementType) + return lambda row: [conv(v) for v in row] + + elif isinstance(dataType, MapType): + kconv = _create_converter(dataType.keyType) + vconv = _create_converter(dataType.valueType) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items()) + + elif isinstance(dataType, NullType): + return lambda x: None + + elif not isinstance(dataType, StructType): + return lambda x: x + + # dataType must be StructType + names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, (tuple, list)): + if convert_fields: + return tuple(conv(v) for v, conv in zip(obj, converters)) + else: + return tuple(obj) + + if isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ + else: + raise TypeError("Unexpected obj type: %s" % type(obj)) + + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) + + return convert_struct + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,NullType,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,NullType,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(NullType,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(NullType,ArrayType(NullType,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, NullType(), True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,NullType,true))),true) + """ + s = s.strip() + if not s: + return NullType() + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(NullType(), _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types inferred from obj + + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) + >>> _infer_schema_type(row, schema) + StructType...LongType...DoubleType...StringType...DateType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,LongType... + """ + if isinstance(dataType, NullType): + return _infer_type(obj) + + if not obj: + return NullType() + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = next(iter(obj.items())) + return MapType(_infer_schema_type(k, dataType.keyType), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise TypeError("Unexpected dataType: %s" % type(dataType)) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + BinaryType: (bytearray,), + DateType: (datetime.date, datetime.datetime), + TimestampType: (datetime.datetime,), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, LongType()) + >>> _verify_type(list(range(3)), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + # StringType can work with any types + if isinstance(dataType, StringType): + return + + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.toInternal(obj), dataType.sqlType()) + return + + _type = type(dataType) + assert _type in _acceptable_types, "unknown datatype: %s" % dataType + + if _type is StructType: + if not isinstance(obj, (tuple, list)): + raise TypeError("StructType can not accept object in type %s" % type(obj)) + else: + # subclass of them can not be fromInternald in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.items(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + + +# This is used to unpickle a Row from JVM +def _create_row_inbound_converter(dataType): + return lambda *a: dataType.fromInternal(a) + + +def _create_row(fields, values): + row = Row(*values) + row.__fields__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + row = tuple.__new__(self, [kwargs[n] for n in names]) + row.__fields__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self, recursive=False): + """ + Return as an dict + + :param recursive: turns the nested Row as dict (default: False). + + >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + True + >>> row = Row(key=1, value=Row(name='a', age=2)) + >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')} + True + >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + True + """ + if not hasattr(self, "__fields__"): + raise TypeError("Cannot convert a Row class into dict") + + if recursive: + def conv(obj): + if isinstance(obj, Row): + return obj.asDict(True) + elif isinstance(obj, list): + return [conv(o) for o in obj] + elif isinstance(obj, dict): + return dict((k, conv(v)) for k, v in obj.items()) + else: + return obj + return dict(zip(self.__fields__, (conv(o) for o in self))) + else: + return dict(zip(self.__fields__, self)) + + # let object acts like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__fields__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + except ValueError: + raise AttributeError(item) + + def __setattr__(self, key, value): + if key != '__fields__': + raise Exception("Row is read-only") + self.__dict__[key] = value + + def __reduce__(self): + """Returns a tuple so Python knows how to pickle Row.""" + if hasattr(self, "__fields__"): + return (_create_row, (self.__fields__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + """Printable representation of Row used in Python REPL.""" + if hasattr(self, "__fields__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__fields__, tuple(self))) + else: + return "" % ", ".join(self) + + +class DateConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.date) + + def convert(self, obj, gateway_client): + Date = JavaClass("java.sql.Date", gateway_client) + return Date.valueOf(obj.strftime("%Y-%m-%d")) + + +class DatetimeConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.datetime) + + def convert(self, obj, gateway_client): + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) + return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) + + +# datetime is a subclass of date, we should register DatetimeConverter first +register_input_converter(DatetimeConverter()) +register_input_converter(DateConverter()) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py new file mode 100644 index 000000000000..0f795ca35b38 --- /dev/null +++ b/python/pyspark/sql/utils.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j + + +class AnalysisException(Exception): + """ + Failed to analyze a SQL query plan. + """ + + +class IllegalArgumentException(Exception): + """ + Passed an illegal or inappropriate argument. + """ + + +def capture_sql_exception(f): + def deco(*a, **kw): + try: + return f(*a, **kw) + except py4j.protocol.Py4JJavaError as e: + s = e.java_exception.toString() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + raise AnalysisException(s.split(': ', 1)[1]) + if s.startswith('java.lang.IllegalArgumentException: '): + raise IllegalArgumentException(s.split(': ', 1)[1]) + raise + return deco + + +def install_exception_handler(): + """ + Hook an exception handler into Py4j, which could capture some SQL exceptions in Java. + + When calling Java API, it will call `get_return_value` to parse the returned object. + If any exception happened in JVM, the result will be Java exception object, it raise + py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one that + could capture the Java exception and throw a Python one (with the same error message). + + It's idempotent, could be called multiple times. + """ + original = py4j.protocol.get_return_value + # The original `get_return_value` is not patched, it's idempotent. + patched = capture_sql_exception(original) + # only patch the one used in in py4j.java_gateway (call Java API) + py4j.java_gateway.get_return_value = patched diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py new file mode 100644 index 000000000000..eaf4d7e98620 --- /dev/null +++ b/python/pyspark/sql/window.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext +from pyspark.sql import since +from pyspark.sql.column import _to_seq, _to_java_column + +__all__ = ["Window", "WindowSpec"] + + +def _to_java_cols(cols): + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return _to_seq(sc, cols, _to_java_column) + + +class Window(object): + """ + Utility functions for defining window in DataFrames. + + For example: + + >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0) + + >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING + >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + @staticmethod + @since(1.4) + def partitionBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + @staticmethod + @since(1.4) + def orderBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + +class WindowSpec(object): + """ + A window specification that defines the partitioning, ordering, + and frame boundaries. + + Use the static methods in :class:`Window` to create a :class:`WindowSpec`. + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + + _JAVA_MAX_LONG = (1 << 63) - 1 + _JAVA_MIN_LONG = - (1 << 63) + + def __init__(self, jspec): + self._jspec = jspec + + @since(1.4) + def partitionBy(self, *cols): + """ + Defines the partitioning columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols))) + + @since(1.4) + def orderBy(self, *cols): + """ + Defines the ordering columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.orderBy(_to_java_cols(cols))) + + @since(1.4) + def rowsBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rowsBetween(start, end)) + + @since(1.4) + def rangeBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rangeBetween(start, end)) + + +def _test(): + import doctest + SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 1e597d64e03f..0fee3b209682 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -30,8 +30,10 @@ class StatCounter(object): - def __init__(self, values=[]): - self.n = 0L # Running count of our values + def __init__(self, values=None): + if values is None: + values = list() + self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) self.maxValue = float("-inf") @@ -87,7 +89,7 @@ def copy(self): return copy.deepcopy(self) def count(self): - return self.n + return int(self.n) def mean(self): return self.mu diff --git a/python/pyspark/status.py b/python/pyspark/status.py new file mode 100644 index 000000000000..a6fa7dd3144d --- /dev/null +++ b/python/pyspark/status.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple + +__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"] + + +class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")): + """ + Exposes information about Spark Jobs. + """ + + +class SparkStageInfo(namedtuple("SparkStageInfo", + "stageId currentAttemptId name numTasks numActiveTasks " + "numCompletedTasks numFailedTasks")): + """ + Exposes information about Spark Stages. + """ + + +class StatusTracker(object): + """ + Low-level status reporting APIs for monitoring job and stage progress. + + These APIs intentionally provide very weak consistency semantics; + consumers of these APIs should be prepared to handle empty / missing + information. For example, a job's stage ids may be known but the status + API may not have any information about the details of those stages, so + `getStageInfo` could potentially return `None` for a valid stage id. + + To limit memory usage, these APIs only provide information on recent + jobs / stages. These APIs will provide information for the last + `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs. + """ + def __init__(self, jtracker): + self._jtracker = jtracker + + def getJobIdsForGroup(self, jobGroup=None): + """ + Return a list of all known jobs in a particular job group. If + `jobGroup` is None, then returns all known jobs that are not + associated with a job group. + + The returned list may contain running, failed, and completed jobs, + and may vary across invocations of this method. This method does + not guarantee the order of the elements in its result. + """ + return list(self._jtracker.getJobIdsForGroup(jobGroup)) + + def getActiveStageIds(self): + """ + Returns an array containing the ids of all active stages. + """ + return sorted(list(self._jtracker.getActiveStageIds())) + + def getActiveJobsIds(self): + """ + Returns an array containing the ids of all active jobs. + """ + return sorted((list(self._jtracker.getActiveJobIds()))) + + def getJobInfo(self, jobId): + """ + Returns a :class:`SparkJobInfo` object, or None if the job info + could not be found or was garbage collected. + """ + job = self._jtracker.getJobInfo(jobId) + if job is not None: + return SparkJobInfo(jobId, job.stageIds(), str(job.status())) + + def getStageInfo(self, stageId): + """ + Returns a :class:`SparkStageInfo` object, or None if the stage + info could not be found or was garbage collected. + """ + stage = self._jtracker.getStageInfo(stageId) + if stage is not None: + # TODO: fetch them in batch for better performance + attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]] + return SparkStageInfo(stageId, *attrs) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index d48f3598e33b..4069d7a14998 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from __future__ import print_function + import os import sys -from py4j.java_collections import ListConverter from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf -from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream @@ -84,6 +86,9 @@ class StreamingContext(object): """ _transformerSerializer = None + # Reference to a currently active StreamingContext + _activeContext = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -140,34 +145,84 @@ def getOrCreate(cls, checkpointPath, setupFunc): Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be recreated from the checkpoint data. If the data does not exist, then the provided setupFunc - will be used to create a JavaStreamingContext. + will be used to create a new context. - @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier streaming program + @param setupFunc: Function to create a new context and setup DStreams """ - # TODO: support checkpoint in HDFS - if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + cls._ensure_initialized() + gw = SparkContext._gateway + + # Check whether valid checkpoint information exists in the given path + if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - cls._ensure_initialized() - gw = SparkContext._gateway - try: jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: - print >>sys.stderr, "failed to load StreamingContext from checkpoint" + print("failed to load StreamingContext from checkpoint", file=sys.stderr) raise - jsc = jssc.sparkContext() - conf = SparkConf(_jconf=jsc.getConf()) - sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # If there is already an active instance of Python SparkContext use it, or create a new one + if not SparkContext._active_spark_context: + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + SparkContext(conf=conf, gateway=gw, jsc=jsc) + + sc = SparkContext._active_spark_context + # update ctx in serializer - SparkContext._active_spark_context = sc cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) + @classmethod + def getActive(cls): + """ + Return either the currently active StreamingContext (i.e., if there is a context started + but not stopped) or None. + """ + activePythonContext = cls._activeContext + if activePythonContext is not None: + # Verify that the current running Java StreamingContext is active and is the same one + # backing the supposedly active Python context + activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode() + activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive() + + if activeJvmContextOption.isEmpty(): + cls._activeContext = None + elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId: + cls._activeContext = None + raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext " + "backing the action Python StreamingContext. This is unexpected.") + return cls._activeContext + + @classmethod + def getActiveOrCreate(cls, checkpointPath, setupFunc): + """ + Either return the active StreamingContext (i.e. currently started but not stopped), + or recreate a StreamingContext from checkpoint data or create a new StreamingContext + using the provided setupFunc function. If the checkpointPath is None or does not contain + valid checkpoint data, then setupFunc will be called to create a new context and setup + DStreams. + + @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be + None if the intention is to always create a new context when there + is no active context. + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + + if setupFunc is None: + raise Exception("setupFunc cannot be None") + activeContext = cls.getActive() + if activeContext is not None: + return activeContext + elif checkpointPath is not None: + return cls.getOrCreate(checkpointPath, setupFunc) + else: + return setupFunc() + @property def sparkContext(self): """ @@ -180,6 +235,7 @@ def start(self): Start the execution of the streams. """ self._jssc.start() + StreamingContext._activeContext = self def awaitTermination(self, timeout=None): """ @@ -189,7 +245,16 @@ def awaitTermination(self, timeout=None): if timeout is None: self._jssc.awaitTermination() else: - self._jssc.awaitTermination(int(timeout * 1000)) + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) + + def awaitTerminationOrTimeout(self, timeout): + """ + Wait for the execution to stop. Return `true` if it's stopped; or + throw the reported error during the execution; or `false` if the + waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds + """ + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ @@ -201,6 +266,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) + StreamingContext._activeContext = None if stopSparkContext: self._sc.stop() @@ -251,6 +317,20 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + def binaryRecordsStream(self, directory, recordLength): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as flat binary files with records of + fixed length. Files must be written to the monitored directory by "moving" + them from another location within the same file system. + File names starting with . are ignored. + + @param directory: Directory to load data from + @param recordLength: Length of each record in bytes + """ + return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self, + NoOpSerializer()) + def _check_serializers(self, rdds): # make sure they have same serializer if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: @@ -279,9 +359,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): rdds = [self._sc.parallelize(input) for input in rdds] self._check_serializers(rdds) - jrdds = ListConverter().convert([r._jrdd for r in rdds], - SparkContext._gateway._gateway_client) - queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) @@ -296,8 +374,7 @@ def transform(self, dstreams, transformFunc): the transform function parameter will be the same as the order of corresponding DStreams in the list. """ - jdstreams = ListConverter().convert([d._jdstream for d in dstreams], - SparkContext._gateway._gateway_client) + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, lambda t, *rdds: transformFunc(rdds).map(lambda x: x), @@ -320,6 +397,5 @@ def union(self, *dstreams): if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") first = dstreams[0] - jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], - SparkContext._gateway._gateway_client) + jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 2fe39392ff08..698336cfce18 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -15,11 +15,15 @@ # limitations under the License. # -from itertools import chain, ifilter, imap +import sys import operator import time +from itertools import chain from datetime import datetime +if sys.version < "3": + from itertools import imap as map, ifilter as filter + from py4j.protocol import Py4JJavaError from pyspark import RDD @@ -76,7 +80,7 @@ def filter(self, f): Return a new DStream containing only the elements that satisfy predicate. """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def flatMap(self, f, preservesPartitioning=False): @@ -85,7 +89,7 @@ def flatMap(self, f, preservesPartitioning=False): this DStream, and then flattening the results """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def map(self, f, preservesPartitioning=False): @@ -93,7 +97,7 @@ def map(self, f, preservesPartitioning=False): Return a new DStream by applying a function to each element of DStream. """ def func(iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitions(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -150,7 +154,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: old_func = func func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) @@ -165,14 +169,14 @@ def pprint(self, num=10): """ def takeAndPrint(time, rdd): taken = rdd.take(num + 1) - print "-------------------------------------------" - print "Time: %s" % time - print "-------------------------------------------" + print("-------------------------------------------") + print("Time: %s" % time) + print("-------------------------------------------") for record in taken[:num]: - print record + print(record) if len(taken) > num: - print "..." - print + print("...") + print("") self.foreachRDD(takeAndPrint) @@ -181,7 +185,7 @@ def mapValues(self, f): Return a new DStream by applying a map function to the value of each key-value pairs in this DStream without changing the key. """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): @@ -189,7 +193,7 @@ def flatMapValues(self, f): Return a new DStream by applying a flatmap function to the value of each key-value pairs in this DStream without changing the key. """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def glom(self): @@ -286,10 +290,10 @@ def transform(self, func): `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: oldfunc = func func = lambda t, rdd: oldfunc(rdd) - assert func.func_code.co_argcount == 2, "func should take one or two arguments" + assert func.__code__.co_argcount == 2, "func should take one or two arguments" return TransformedDStream(self, func) def transformWith(self, func, other, keepSerializer=False): @@ -300,10 +304,10 @@ def transformWith(self, func, other, keepSerializer=False): `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three arguments of (`time`, `rdd_a`, `rdd_b`) """ - if func.func_code.co_argcount == 2: + if func.__code__.co_argcount == 2: oldfunc = func func = lambda t, a, b: oldfunc(a, b) - assert func.func_code.co_argcount == 3, "func should take two or three arguments" + assert func.__code__.co_argcount == 3, "func should take two or three arguments" jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) @@ -460,7 +464,7 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio keyed = self.map(lambda x: (1, x)) reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) - return reduced.map(lambda (k, v): v) + return reduced.map(lambda kv: kv[1]) def countByWindow(self, windowDuration, slideDuration): """ @@ -489,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda (k, v): v > 0).count() + return counted.filter(lambda kv: kv[1] > 0).count() def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ @@ -548,7 +552,8 @@ def reduceFunc(t, a, b): def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: @@ -578,10 +583,10 @@ def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: - g = a.cogroup(b, numPartitions) - g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) - state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) - return state.filter(lambda (k, v): v is not None) + g = a.cogroup(b.partitionBy(numPartitions), numPartitions) + g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None)) + state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1])) + return state.filter(lambda k_v: k_v[1] is not None) jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) @@ -605,7 +610,10 @@ def __init__(self, prev, func): self.is_checkpointed = False self._jdstream_val = None - if (isinstance(prev, TransformedDStream) and + # Using type() to avoid folding the functions and compacting the DStreams which is not + # not strictly a object of TransformedDStream. + # Changed here is to avoid bug in KafkaTransformedDStream when calling offsetRanges(). + if (type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func self.func = lambda t, rdd: func(t, prev_func(t, rdd)) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py new file mode 100644 index 000000000000..c0cdc50d8d42 --- /dev/null +++ b/python/pyspark/streaming/flume.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version >= "3": + from io import BytesIO +else: + from StringIO import StringIO +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int +from pyspark.streaming import DStream + +__all__ = ['FlumeUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + if s is None: + return None + return s.decode('utf-8') + + +class FlumeUtils(object): + + @staticmethod + def createStream(ssc, hostname, port, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + enableDecompression=False, + bodyDecoder=utf8_decoder): + """ + Create an input stream that pulls events from Flume. + + :param ssc: StreamingContext object + :param hostname: Hostname of the slave machine to which the flume data will be sent + :param port: Port of the slave machine to which the flume data will be sent + :param storageLevel: Storage level to use for storing the received objects + :param enableDecompression: Should netty server decompress input stream + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def createPollingStream(ssc, addresses, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + maxBatchSize=1000, + parallelism=5, + bodyDecoder=utf8_decoder): + """ + Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + This stream will poll the sink for data and will pull events as they are available. + + :param ssc: StreamingContext object + :param addresses: List of (host, port)s on which the Spark Sink is running. + :param storageLevel: Storage level to use for storing the received objects + :param maxBatchSize: The maximum number of events to be pulled from the Spark sink + in a single RPC call + :param parallelism: Number of concurrent requests this stream should send to the sink. + Note that having a higher number of requests concurrently being pulled + will result in this stream using more threads + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + hosts = [] + ports = [] + for (host, port) in addresses: + hosts.append(host) + ports.append(port) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createPollingStream( + ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def _toPythonDStream(ssc, jstream, bodyDecoder): + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + + def func(event): + headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) + headers = {} + strSer = UTF8Deserializer() + for i in range(0, read_int(headersBytes)): + key = strSer.loads(headersBytes) + value = strSer.loads(headersBytes) + headers[key] = value + body = bodyDecoder(event[1]) + return (headers, body) + return stream.map(func) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Flume libraries not found in class path. Try one of the following. + + 1. Include the Flume library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 19ad71f99d4d..8a814c64c042 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,25 +15,29 @@ # limitations under the License. # -from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JError +from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream +from pyspark.streaming.dstream import TransformedDStream +from pyspark.streaming.util import TransformFunction -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KafkaUtils(object): @staticmethod - def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ @@ -50,8 +54,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ - java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") - + if kafkaParams is None: + kafkaParams = dict() kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, @@ -59,25 +63,295 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, }) if not isinstance(topics, dict): raise TypeError("topics should be dict") - jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) - jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - def getClassByName(name): - return ssc._jvm.org.apache.spark.util.Utils.classForName(name) - try: - array = getClassByName("[B") - decoder = getClassByName("kafka.serializer.DefaultDecoder") - jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder, - jparam, jtopics, jlevel) - except Py4JError, e: + # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) + except Py4JJavaError as e: # TODO: use --jar once it also work on driver - if not e.message or 'call a package' in e.message: - print "No kafka package, please put the assembly jar into classpath:" - print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \ - "scala-*/spark-streaming-kafka-assembly-*.jar" + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) - return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v))) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if fromOffsets is None: + fromOffsets = dict() + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) \ + .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders=None, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. + :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if leaders is None: + leaders = dict() + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Kafka libraries not found in class path. Try one of the following. + + 1. Include the Kafka library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kafka:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kafka-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self.topic = topic + self.partition = partition + self.fromOffset = fromOffset + self.untilOffset = untilOffset + + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self.topic == other.topic + and self.partition == other.partition + and self.fromOffset == other.fromOffset + and self.untilOffset == other.untilOffset) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \ + % (self.topic, self.partition, self.fromOffset, self.untilOffset) + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self.topic, self.partition, self.fromOffset, + self.untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. + """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) + + +class KafkaRDD(RDD): + """ + A Python wrapper of KafkaRDD, to provide additional information on normal RDD. + """ + + def __init__(self, jrdd, ctx, jrdd_deserializer): + RDD.__init__(self, jrdd, ctx, jrdd_deserializer) + + def offsetRanges(self): + """ + Get the OffsetRange of specific KafkaRDD. + :return: A list of OffsetRange + """ + try: + helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd()) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(self.ctx) + raise e + + ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset()) + for o in joffsetRanges] + return ranges + + +class KafkaDStream(DStream): + """ + A Python wrapper of KafkaDStream + """ + + def __init__(self, jdstream, ssc, jrdd_deserializer): + DStream.__init__(self, jdstream, ssc, jrdd_deserializer) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.__code__.co_argcount == 1: + old_func = func + func = lambda r, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.__code__.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.__code__.co_argcount == 2, "func should take one or two arguments" + + return KafkaTransformedDStream(self, func) + + +class KafkaTransformedDStream(TransformedDStream): + """ + Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. + """ + + def __init__(self, prev, func): + TransformedDStream.__init__(self, prev, func) + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py new file mode 100644 index 000000000000..34be5880e170 --- /dev/null +++ b/python/pyspark/streaming/kinesis.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.storagelevel import StorageLevel +from pyspark.streaming import DStream + +__all__ = ['KinesisUtils', 'InitialPositionInStream', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + if s is None: + return None + return s.decode('utf-8') + + +class KinesisUtils(object): + + @staticmethod + def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + """ + Create an input stream that pulls messages from a Kinesis stream. This uses the + Kinesis Client Library (KCL) to pull messages from Kinesis. + + Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is + enabled. Make sure that your checkpoint directory is secure. + + :param ssc: StreamingContext object + :param kinesisAppName: Kinesis application name used by the Kinesis Client Library (KCL) to + update DynamoDB + :param streamName: Kinesis stream name + :param endpointUrl: Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + :param regionName: Name of region used by the Kinesis Client Library (KCL) to update + DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + :param initialPositionInStream: In the absence of Kinesis checkpoint info, this is the + worker's initial starting position in the stream. The + values are either the beginning of the stream per Kinesis' + limit of 24 hours (InitialPositionInStream.TRIM_HORIZON) or + the tip of the stream (InitialPositionInStream.LATEST). + :param checkpointInterval: Checkpoint interval for Kinesis checkpointing. See the Kinesis + Spark Streaming documentation for more details on the different + types of checkpoints. + :param storageLevel: Storage level to use for storing the received objects (default is + StorageLevel.MEMORY_AND_DISK_2) + :param awsAccessKeyId: AWS AccessKeyId (default is None. If None, will use + DefaultAWSCredentialsProviderChain) + :param awsSecretKey: AWS SecretKey (default is None. If None, will use + DefaultAWSCredentialsProviderChain) + :param decoder: A function used to decode value (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + jduration = ssc._jduration(checkpointInterval) + + try: + # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, jduration, jlevel, + awsAccessKeyId, awsSecretKey) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KinesisUtils._printErrorMsg(ssc.sparkContext) + raise e + stream = DStream(jstream, ssc, NoOpSerializer()) + return stream.map(lambda v: decoder(v)) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Kinesis libraries not found in class path. Try one of the following. + + 1. Include the Kinesis library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kinesis-asl-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) + + +class InitialPositionInStream(object): + LATEST, TRIM_HORIZON = (0, 1) diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 000000000000..f06598971c54 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object + :param brokerUrl: Url of remote mqtt publisher + :param topic: topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + MQTTUtils._printErrorMsg(ssc.sparkContext) + raise e + + return DStream(jstream, ssc, UTF8Deserializer()) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's MQTT libraries not found in class path. Try one of the following. + + 1. Include the MQTT library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... +________________________________________________________________________________________________ +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a8d876d0fa3b..cfea95b0dec7 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,39 +15,73 @@ # limitations under the License. # +import glob import os +import sys from itertools import chain import time import operator -import unittest import tempfile +import random +import struct +import shutil +from functools import reduce + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition +from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.mqtt import MQTTUtils +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream class PySparkStreamingTestCase(unittest.TestCase): timeout = 10 # seconds - duration = 1 + duration = .5 - def setUp(self): - class_name = self.__class__.__name__ + @classmethod + def setUpClass(cls): + class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) - self.sc = SparkContext(appName=class_name, conf=conf) - self.sc.setCheckpointDir("/tmp") - # TODO: decrease duration to speed up tests + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir("/tmp") + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + + def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop() + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) def wait_for(self, result, n): start_time = time.time() while len(result) < n and time.time() - start_time < self.timeout: time.sleep(0.01) if len(result) < n: - print "timeout after", self.timeout + print("timeout after", self.timeout) def _take(self, dstream, n): """ @@ -127,7 +161,7 @@ def test_map(self): def func(dstream): return dstream.map(str) - expected = map(lambda x: map(str, x), input) + expected = [list(map(str, x)) for x in input] self._test_func(input, func, expected) def test_flatMap(self): @@ -136,8 +170,8 @@ def test_flatMap(self): def func(dstream): return dstream.flatMap(lambda x: (x, x * 2)) - expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), - input) + expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x)))) + for x in input] self._test_func(input, func, expected) def test_filter(self): @@ -146,7 +180,7 @@ def test_filter(self): def func(dstream): return dstream.filter(lambda x: x % 2 == 0) - expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + expected = [[y for y in x if y % 2 == 0] for x in input] self._test_func(input, func, expected) def test_count(self): @@ -155,7 +189,7 @@ def test_count(self): def func(dstream): return dstream.count() - expected = map(lambda x: [len(x)], input) + expected = [[len(x)] for x in input] self._test_func(input, func, expected) def test_reduce(self): @@ -164,7 +198,7 @@ def test_reduce(self): def func(dstream): return dstream.reduce(operator.add) - expected = map(lambda x: [reduce(operator.add, x)], input) + expected = [[reduce(operator.add, x)] for x in input] self._test_func(input, func, expected) def test_reduceByKey(self): @@ -181,27 +215,27 @@ def func(dstream): def test_mapValues(self): """Basic operation test for DStream.mapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 2), (3, 3)], + [(0, 4), (1, 1), (2, 2), (3, 3)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.mapValues(lambda x: x + 10) expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], - [("", 14), (1, 11), (2, 12), (3, 13)], + [(0, 14), (1, 11), (2, 12), (3, 13)], [(1, 11), (2, 11), (3, 11), (4, 11)]] self._test_func(input, func, expected, sort=True) def test_flatMapValues(self): """Basic operation test for DStream.flatMapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 1), (3, 1)], + [(0, 4), (1, 1), (2, 1), (3, 1)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.flatMapValues(lambda x: (x, x + 10)) expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), ("c", 1), ("c", 11), ("d", 1), ("d", 11)], - [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] self._test_func(input, func, expected) @@ -229,7 +263,7 @@ def f(iterator): def test_countByValue(self): """Basic operation test for DStream.countByValue.""" - input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]] def func(dstream): return dstream.countByValue() @@ -281,7 +315,7 @@ def test_union(self): def func(d1, d2): return d1.union(d2) - expected = [range(6), range(6), range(6)] + expected = [list(range(6)), list(range(6)), list(range(6))] self._test_func(input1, func, expected, input2=input2) def test_cogroup(self): @@ -360,13 +394,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 20 + timeout = 15 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(3, 1).count() + return dstream.window(1.5, .5).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -375,7 +409,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(3, 1) + return dstream.countByWindow(1.5, .5) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -384,7 +418,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(5, 1) + return dstream.countByWindow(2.5, .5) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -393,7 +427,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(5, 1) + return dstream.countByValueAndWindow(2.5, .5) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -402,7 +436,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -418,9 +452,10 @@ def test_reduce_by_invalid_window(self): class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 + setupCalled = False def _add_input_stream(self): - inputs = map(lambda x: range(1, x), range(101)) + inputs = [range(1, x) for x in range(101)] stream = self.ssc.queueStream(inputs) self._collect(stream, 1, block=False) @@ -433,11 +468,11 @@ def test_stop_only_streaming_context(self): def test_stop_multiple_times(self): self._add_input_stream() self.ssc.start() - self.ssc.stop() - self.ssc.stop() + self.ssc.stop(False) + self.ssc.stop(False) def test_queue_stream(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) result = self._collect(dstream, 3) self.assertEqual(input, result) @@ -453,10 +488,24 @@ def test_text_file_stream(self): with open(os.path.join(d, name), "w") as f: f.writelines(["%d\n" % i for i in range(10)]) self.wait_for(result, 2) - self.assertEqual([range(10), range(10)], result) + self.assertEqual([list(range(10)), list(range(10))], result) + + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", bytes(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) def test_union(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) @@ -477,13 +526,89 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) -class CheckpointTests(PySparkStreamingTestCase): + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) - def setUp(self): - pass + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc - def test_get_or_create(self): + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + +class CheckpointTests(unittest.TestCase): + + setupCalled = False + + @staticmethod + def tearDownClass(): + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(True) + if self.sc is not None: + self.sc.stop() + if self.cpd is not None: + shutil.rmtree(self.cpd) + + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -498,15 +623,20 @@ def setup(): wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") wc.checkpoint(.5) + self.setupCalled = True return ssc - cpd = tempfile.mkdtemp("test_streaming_cps") - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + # Verify that getOrCreate() calls setup() in absence of checkpoint files + self.cpd = tempfile.mkdtemp("test_streaming_cps") + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + + self.ssc.start() def check_output(n): while not os.listdir(outputd): - time.sleep(0.1) + time.sleep(0.01) time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) @@ -517,7 +647,7 @@ def check_output(n): # not finished time.sleep(0.01) continue - ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: time.sleep(0.01) @@ -533,13 +663,638 @@ def check_output(n): check_output(1) check_output(2) - ssc.stop(True, True) + # Verify the getOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) time.sleep(1) - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() check_output(3) + # Verify that getOrCreate() uses existing SparkContext + self.ssc.stop(True, True) + time.sleep(1) + sc = SparkContext(SparkConf()) + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.assertTrue(self.ssc.sparkContext == sc) + + # Verify the getActiveOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() + check_output(4) + + # Verify that getActiveOrCreate() returns active context + self.setupCalled = False + self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() uses existing SparkContext + self.ssc.stop(True, True) + time.sleep(1) + self.sc = SparkContext(SparkConf()) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.assertTrue(self.ssc.sparkContext == sc) + + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files + self.ssc.stop(True, True) + shutil.rmtree(self.cpd) # delete checkpoint directory + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertTrue(self.setupCalled) + + # Stop everything + self.ssc.stop(True, True) + + +class KafkaStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(KafkaStreamTests, self).setUp() + + kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") + self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils.setup() + + def tearDown(self): + if self._kafkaTestUtils is not None: + self._kafkaTestUtils.teardown() + self._kafkaTestUtils = None + + super(KafkaStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + self.assertEqual(sendData, result) + + def test_kafka_stream(self): + """Test the Python Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 3, "b": 5, "c": 10} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), + "test-streaming-consumer", {topic: 1}, + {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) + + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_get_offsetRanges(self): + """Test Python direct Kafka RDD get OffsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 3, "b": 4, "c": 5} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self.assertEqual(offsetRanges, rdd.offsetRanges()) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_foreach_get_offsetRanges(self): + """Test the Python direct Kafka stream foreachRDD get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def getOffsetRanges(_, rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + + stream.foreachRDD(getOffsetRanges) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_get_offsetRanges(self): + """Test the Python direct Kafka stream transform get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + # Test whether it is ok mixing KafkaTransformedDStream and TransformedDStream together, + # only the TransformedDstreams can be folded together. + stream.transform(transformWithOffsetRanges).map(lambda kv: kv[1]).count().pprint() + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + +class FlumeStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(FlumeStreamTests, self).setUp() + + utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + super(FlumeStreamTests, self).tearDown() + + def _startContext(self, n, compressed): + # Start the StreamingContext and also collect the result + dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), + enableDecompression=compressed) + result = [] + + def get_output(_, rdd): + for event in rdd.collect(): + if len(result) < n: + result.append(event) + dstream.foreachRDD(get_output) + self.ssc.start() + return result + + def _validateResult(self, input, result): + # Validate both the header and the body + header = {"test": "header"} + self.assertEqual(len(input), len(result)) + for i in range(0, len(input)): + self.assertEqual(header, result[i][0]) + self.assertEqual(input[i], result[i][1]) + + def _writeInput(self, input, compressed): + # Try to write input to the receiver until success or timeout + start_time = time.time() + while True: + try: + self._utils.writeInput(input, compressed) + break + except: + if time.time() - start_time < self.timeout: + time.sleep(0.01) + else: + raise + + def test_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), False) + self._writeInput(input, False) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + def test_compressed_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), True) + self._writeInput(input, True) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + +class FlumePollingStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + maxAttempts = 5 + + def setUp(self): + utilsClz = \ + self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + def _writeAndVerify(self, ports): + # Set up the streaming context and input streams + ssc = StreamingContext(self.sc, self.duration) + try: + addresses = [("localhost", port) for port in ports] + dstream = FlumeUtils.createPollingStream( + ssc, + addresses, + maxBatchSize=self._utils.eventsPerBatch(), + parallelism=5) + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + dstream.foreachRDD(get_output) + ssc.start() + self._utils.sendDatAndEnsureAllDataHasBeenReceived() + + self.wait_for(outputBuffer, self._utils.getTotalEvents()) + outputHeaders = [event[0] for event in outputBuffer] + outputBodies = [event[1] for event in outputBuffer] + self._utils.assertOutput(outputHeaders, outputBodies) + finally: + ssc.stop(False) + + def _testMultipleTimes(self, f): + attempt = 0 + while True: + try: + f() + break + except: + attempt += 1 + if attempt >= self.maxAttempts: + raise + else: + import traceback + traceback.print_exc() + + def _testFlumePolling(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def _testFlumePollingMultipleHosts(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def test_flume_polling(self): + self._testMultipleTimes(self._testFlumePolling) + + def test_flume_polling_multiple_hosts(self): + self._testMultipleTimes(self._testFlumePollingMultipleHosts) + + +class MQTTStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(MQTTStreamTests, self).setUp() + + MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") + self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() + self._MQTTTestUtils.setup() + + def tearDown(self): + if self._MQTTTestUtils is not None: + self._MQTTTestUtils.teardown() + self._MQTTTestUtils = None + + super(MQTTStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _startContext(self, topic): + # Start the StreamingContext and also collect the result + stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) + result = [] + + def getOutput(_, rdd): + for data in rdd.collect(): + result.append(data) + + stream.foreachRDD(getOutput) + self.ssc.start() + return result + + def test_mqtt_stream(self): + """Test the Python MQTT stream API.""" + sendData = "MQTT demo for spark streaming" + topic = self._randomTopic() + result = self._startContext(topic) + + def retry(): + self._MQTTTestUtils.publishData(topic, sendData) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + self.assertEqual(sendData, result[0]) + + # Retry it because we don't know when the receiver will start. + self._retry_or_timeout(retry) + + def _retry_or_timeout(self, test_func): + start_time = time.time() + while True: + try: + test_func() + break + except: + if time.time() - start_time > self.timeout: + raise + time.sleep(0.01) + + +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + kinesisStream1 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + kinesisStream2 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + if not are_kinesis_tests_enabled: + sys.stderr.write( + "Skipped test_kinesis_stream (enable by setting environment variable %s=1" + % kinesis_test_environ_var) + return + + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtilsClz = \ + self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils") + kinesisTestUtils = kinesisTestUtilsClz.newInstance() + try: + kinesisTestUtils.createStream() + aWSCredentials = kinesisTestUtils.getAWSCredentials() + stream = KinesisUtils.createStream( + self.ssc, kinesisAppName, kinesisTestUtils.streamName(), + kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), + InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) + + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + stream.foreachRDD(get_output) + self.ssc.start() + + testData = [i for i in range(1, 11)] + expectedOutput = set([str(i) for i in testData]) + start_time = time.time() + while time.time() - start_time < 120: + kinesisTestUtils.pushData(testData) + if expectedOutput == set(outputBuffer): + break + time.sleep(10) + self.assertEqual(expectedOutput, set(outputBuffer)) + except: + import traceback + traceback.print_exc() + raise + finally: + self.ssc.stop(False) + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + +# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because +# the artifact jars are in different directories. +def search_jar(dir, name_prefix): + # We should ignore the following jars + ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") + jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build + glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build + return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)] + + +def search_kafka_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") + jars = search_jar(kafka_assembly_dir, "spark-streaming-kafka-assembly") + if not jars: + raise Exception( + ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/mvn package' before running this test.") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +def search_flume_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") + jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly") + if not jars: + raise Exception( + ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " + "'build/mvn package' before running this test.") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +def search_mqtt_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") + jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly") + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +def search_mqtt_test_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") + jars = glob.glob( + os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +def search_kinesis_asl_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") + jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + if not jars: + return None + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +# Must be same as the variable and condition defined in KinesisTestUtils.scala +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' if __name__ == "__main__": - unittest.main() + kafka_assembly_jar = search_kafka_assembly_jar() + flume_assembly_jar = search_flume_assembly_jar() + mqtt_assembly_jar = search_mqtt_assembly_jar() + mqtt_test_jar = search_mqtt_test_jar() + kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() + + if kinesis_asl_assembly_jar is None: + kinesis_jar_present = False + jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar) + else: + kinesis_jar_present = True + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar, kinesis_asl_assembly_jar) + + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, + KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests] + + if kinesis_jar_present is True: + testcases.append(KinesisStreamTests) + elif are_kinesis_tests_enabled is False: + sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled into a JAR. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + raise Exception( + ("Failed to find Spark Streaming Kinesis assembly jar in %s. " + % kinesis_asl_assembly_dir) + + "You need to build Spark with 'build/sbt -Pkinesis-asl " + "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "or 'build/mvn -Pkinesis-asl package' before running this test.") + + sys.stderr.write("Running tests: %s \n" % (str(testcases))) + for testcase in testcases: + sys.stderr.write("[Running %s]\n" % (testcase)) + tests = unittest.TestLoader().loadTestsFromTestCase(testcase) + unittest.TextTestRunner(verbosity=3).run(tests) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 86ee5aa04f25..b20613b1283b 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,6 +37,11 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers + self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + + def rdd_wrapper(self, func): + self._rdd_wrapper = func + return self def call(self, milliseconds, jrdds): try: @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) @@ -91,9 +96,9 @@ def dumps(self, id): except Exception: traceback.print_exc() - def loads(self, bytes): + def loads(self, data): try: - f, deserializers = self.serializer.loads(str(bytes)) + f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() @@ -116,7 +121,7 @@ def rddToFileName(prefix, suffix, timestamp): """ if isinstance(timestamp, datetime): seconds = time.mktime(timestamp.timetuple()) - timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + timestamp = int(seconds * 1000) + timestamp.microsecond // 1000 if suffix is None: return prefix + "-" + str(timestamp) else: @@ -125,4 +130,6 @@ def rddToFileName(prefix, suffix, timestamp): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c7d0622d65f2..8bfed074c905 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,11 +19,10 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ + from array import array -from fileinput import input from glob import glob import os -import pydoc import re import shutil import subprocess @@ -35,6 +34,8 @@ import threading import hashlib +from py4j.protocol import Py4JJavaError + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -43,6 +44,14 @@ sys.exit(1) else: import unittest + if sys.version_info[0] >= 3: + xrange = range + basestring = str + +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO from pyspark.conf import SparkConf @@ -50,10 +59,10 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ + PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ + FlattenedValuesSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -79,9 +88,9 @@ class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 14 + self.N = 1 << 12 self.l = [i for i in xrange(self.N)] - self.data = zip(self.l, self.l) + self.data = list(zip(self.l, self.l)) self.agg = Aggregator(lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) @@ -89,84 +98,113 @@ def setUp(self): def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 20) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10, partitions=3) - m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.items()), self.N * 10) m._cleanup() + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1024) - self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): - l = range(1024) + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit + l = list(range(1024)) random.shuffle(l) - sorter = ExternalSorter(1) - self.assertEquals(sorted(l), list(sorter.sorted(l))) + sorter = CustomizedSorter(1) + self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") sc = SparkContext(conf=conf) - l = range(10240) + l = list(range(10240)) random.shuffle(l) - rdd = sc.parallelize(l, 10) - self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect()) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) sc.stop() @@ -174,11 +212,11 @@ class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): from collections import namedtuple - from cPickle import dumps, loads + from pickle import dumps, loads P = namedtuple("P", "x y") p1 = P(1, 3) p2 = loads(dumps(p1, 2)) - self.assertEquals(p1, p2) + self.assertEqual(p1, p2) def test_itemgetter(self): from operator import itemgetter @@ -220,7 +258,7 @@ def test_pickling_file_handles(self): ser = CloudPickleSerializer() out1 = sys.stderr out2 = ser.loads(ser.dumps(out1)) - self.assertEquals(out1, out2) + self.assertEqual(out1, out2) def test_func_globals(self): @@ -237,19 +275,48 @@ def __reduce__(self): def foo(): sys.exit(0) - self.assertTrue("exit" in foo.func_code.co_names) + self.assertTrue("exit" in foo.__code__.co_names) ser.dumps(foo) def test_compressed_serializer(self): ser = CompressedSerializer(PickleSerializer()) - from StringIO import StringIO + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO io = StringIO() ser.dump_stream(["abc", u"123", range(5)], io) io.seek(0) self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) ser.dump_stream(range(1000), io) io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io))) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) + hash(FlattenedValuesSerializer(PickleSerializer())) + + +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) class PySparkTestCase(unittest.TestCase): @@ -314,7 +381,7 @@ def test_checkpoint_and_restore(self): self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), flatMappedRDD._jrdd_deserializer) - self.assertEquals([1, 2, 3, 4], recovered.collect()) + self.assertEqual([1, 2, 3, 4], recovered.collect()) class AddFileTests(PySparkTestCase): @@ -323,16 +390,11 @@ def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: # disable logging in log4j temporarily - log4j = self.sc._jvm.org.apache.log4j - old_level = log4j.LogManager.getRootLogger().getLevel() - log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) - def func(x): from userlibrary import UserClass return UserClass().hello() - self.assertRaises(Exception, - self.sc.parallelize(range(2)).map(func).first) - log4j.LogManager.getRootLogger().setLevel(old_level) + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") @@ -346,7 +408,7 @@ def test_add_file_locally(self): download_path = SparkFiles.get("hello.txt") self.assertNotEqual(path, download_path) with open(download_path) as test_file: - self.assertEquals("Hello World!\n", test_file.readline()) + self.assertEqual("Hello World!\n", test_file.readline()) def test_add_py_file_locally(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -355,7 +417,7 @@ def func(): from userlibrary import UserClass self.assertRaises(ImportError, func) path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addFile(path) + self.sc.addPyFile(path) from userlibrary import UserClass self.assertEqual("Hello World!", UserClass().hello()) @@ -365,7 +427,7 @@ def test_add_egg_file_locally(self): def func(): from userlib import UserClass self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg") + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") self.sc.addPyFile(path) from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) @@ -385,6 +447,11 @@ def func(x): class RDDTests(ReusedPySparkTestCase): + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + def test_id(self): rdd = self.sc.parallelize(range(10)) id = rdd.id() @@ -394,6 +461,14 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -401,8 +476,9 @@ def test_save_as_textfile_with_unicode(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) def test_save_as_textfile_with_utf8(self): x = u"\u00A1Hola, mundo!" @@ -410,19 +486,20 @@ def test_save_as_textfile_with_utf8(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) rdd2 = self.sc.parallelize([3, 4]) cart = rdd1.cartesian(rdd2) - result = cart.map(lambda (x, y): x + y).collect() + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() def test_transforming_pickle_file(self): # Regression test for SPARK-2601 - data = self.sc.parallelize(["Hello", "World!"]) + data = self.sc.parallelize([u"Hello", u"World!"]) tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsPickleFile(tempFile.name) @@ -435,26 +512,144 @@ def test_cartesian_on_textfile(self): a = self.sc.textFile(path) result = a.cartesian(a).collect() (x, y) = result[0] - self.assertEqual("Hello World!", x.strip()) - self.assertEqual("Hello World!", y.strip()) + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name) filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) def test_sampling_default_seed(self): # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(range(1000), 1) + data = self.sc.parallelize(xrange(1000), 1) subset = data.takeSample(False, 10) self.assertEqual(len(subset), 10) + def test_aggregate_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregate and treeAggregate to build dict + # representing a counter of ints + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + # Show that single or multiple partitions work + data1 = self.sc.range(10, numSlices=1) + data2 = self.sc.range(10, numSlices=2) + + def seqOp(x, y): + x[y] += 1 + return x + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) + counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) + counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + + ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) + self.assertEqual(counts1, ground_truth) + self.assertEqual(counts2, ground_truth) + self.assertEqual(counts3, ground_truth) + self.assertEqual(counts4, ground_truth) + + def test_aggregate_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that + # contains lists of all values for each key in the original RDD + + # list(range(...)) for Python 3.x compatibility (can't use * operator + # on a range object) + # list(zip(...)) for Python 3.x compatibility (want to parallelize a + # collection, not a zip object) + tuples = list(zip(list(range(10))*2, [1]*20)) + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def seqOp(x, y): + x.append(y) + return x + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.aggregateByKey([], seqOp, comboOp).collect() + values2 = data2.aggregateByKey([], seqOp, comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + ground_truth = [(i, [1]*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_fold_mutable_zero_value(self): + # Test for SPARK-9021; uses fold to merge an RDD of dict counters into + # a single dict + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + counts1 = defaultdict(int, dict((i, 1) for i in range(10))) + counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) + counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) + counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) + all_counts = [counts1, counts2, counts3, counts4] + # Show that single or multiple partitions work + data1 = self.sc.parallelize(all_counts, 1) + data2 = self.sc.parallelize(all_counts, 2) + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + fold1 = data1.fold(defaultdict(int), comboOp) + fold2 = data2.fold(defaultdict(int), comboOp) + + ground_truth = defaultdict(int) + for counts in all_counts: + for key, val in counts.items(): + ground_truth[key] += val + self.assertEqual(fold1, ground_truth) + self.assertEqual(fold2, ground_truth) + + def test_fold_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains + # lists of all values for each key in the original RDD + + tuples = [(i, range(i)) for i in range(10)]*2 + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.foldByKey([], comboOp).collect() + values2 = data2.foldByKey([], comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + # list(range(...)) for Python 3.x compatibility + ground_truth = [(i, list(range(i))*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + def test_aggregate_by_key(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) @@ -484,21 +679,21 @@ def test_namedtuple_in_rdd(self): jon = Person(1, "Jon", "Doe") jane = Person(2, "Jane", "Doe") theDoes = self.sc.parallelize([jon, jane]) - self.assertEquals([jon, jane], theDoes.collect()) + self.assertEqual([jon, jane], theDoes.collect()) def test_large_broadcast(self): - N = 100000 + N = 10000 data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 270MB + bdata = self.sc.broadcast(data) # 27MB m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEquals(N, m) + self.assertEqual(N, m) def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = range(1 << 15) + r = list(range(1 << 15)) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -509,7 +704,7 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -520,14 +715,12 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) def test_large_closure(self): - N = 1000000 + N = 200000 data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) @@ -546,29 +739,36 @@ def test_zip_with_different_serializers(self): # regression test for bug in _reserializer() self.assertEqual(cnt, t.zip(rdd).count()) + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + def test_zip_with_different_number_of_items(self): a = self.sc.parallelize(range(5), 2) # different number of partitions b = self.sc.parallelize(range(100, 106), 3) self.assertRaises(ValueError, lambda: a.zip(b)) - # different number of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEquals(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): - rdd = self.sc.parallelize(range(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050) + rdd = self.sc.parallelize(xrange(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) self.assertTrue(18 < rdd.countApproxDistinct() < 22) @@ -577,64 +777,63 @@ def test_count_approx_distinct(self): self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) - self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5)) def test_histogram(self): # empty rdd = self.sc.parallelize([]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) self.assertRaises(ValueError, lambda: rdd.histogram(1)) # out of range rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) # in range with one bucket rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals([4], rdd.histogram([0, 10])[1]) - self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) # in range with one bucket exact match - self.assertEquals([4], rdd.histogram([1, 4])[1]) + self.assertEqual([4], rdd.histogram([1, 4])[1]) # out of range with two buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) # out of range with two uneven buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) # in range with two buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two bucket and None rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two uneven buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) # mixed range with two uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) # mixed range with four uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # mixed range with uneven buckets and NaN rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # out of range with infinite buckets rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) # invalid buckets self.assertRaises(ValueError, lambda: rdd.histogram([])) @@ -644,25 +843,25 @@ def test_histogram(self): # without buckets rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 4], [4]), rdd.histogram(1)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) # without buckets single element rdd = self.sc.parallelize([1]) - self.assertEquals(([1, 1], [1]), rdd.histogram(1)) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) # without bucket no range rdd = self.sc.parallelize([1] * 4) - self.assertEquals(([1, 1], [4]), rdd.histogram(1)) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) # without buckets basic two rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) # without buckets with more requested than elements rdd = self.sc.parallelize([1, 2]) buckets = [1 + 0.2 * i for i in range(6)] hist = [1, 0, 0, 0, 1] - self.assertEquals((buckets, hist), rdd.histogram(5)) + self.assertEqual((buckets, hist), rdd.histogram(5)) # invalid RDDs rdd = self.sc.parallelize([1, float('inf')]) @@ -672,15 +871,8 @@ def test_histogram(self): # string rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - # mixed RDD - rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2) - self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1]) - self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals(([1, "b"], [5]), rdd.histogram(1)) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) def test_repartitionAndSortWithinPartitions(self): @@ -688,16 +880,31 @@ def test_repartitionAndSortWithinPartitions(self): repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) partitions = repartitioned.glom().collect() - self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEquals(rdd.getNumPartitions(), 10) - self.assertEquals(rdd.distinct().count(), 3) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) result = rdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) + + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "1m") + N = 200001 + kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda kv: kv[0] == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -722,7 +929,7 @@ def test_null_in_rdd(self): rdd = RDD(jrdd, self.sc, UTF8Deserializer()) self.assertEqual([u"a", None, u"b"], rdd.collect()) rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual(["a", None, "b"], rdd.collect()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) def test_multiple_python_java_RDD_conversions(self): # Regression test for SPARK-5361 @@ -730,7 +937,6 @@ def test_multiple_python_java_RDD_conversions(self): (u'1', {u'director': u'David Lean'}), (u'2', {u'director': u'Andrew Dominik'}) ] - from pyspark.rdd import RDD data_rdd = self.sc.parallelize(data) data_java_rdd = data_rdd._to_java_object_rdd() data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) @@ -743,6 +949,72 @@ def test_multiple_python_java_RDD_conversions(self): converted_rdd = RDD(data_python_rdd, self.sc) self.assertEqual(2, converted_rdd.count()) + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + tracker = self.sc.statusTracker() + + self.sc.setJobGroup("test1", "test", True) + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + + def test_pipe_functions(self): + data = ['1', '2', '3'] + rdd = self.sc.parallelize(data) + with QuietTest(self.sc): + self.assertEqual([], rdd.pipe('cc').collect()) + self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) + result = rdd.pipe('cat').collect() + result.sort() + for x, y in zip(data, result): + self.assertEqual(x, y) + self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) + self.assertEqual([], rdd.pipe('grep 4').collect()) + class ProfilerTests(PySparkTestCase): @@ -764,7 +1036,12 @@ def test_profiler(self): func_names = [func_name for fname, n, func_name in stat_list] self.assertTrue("heavy_foo" in func_names) + old_stdout = sys.stdout + sys.stdout = io = StringIO() self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + d = tempfile.gettempdir() self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) @@ -788,271 +1065,13 @@ def show(self, id): def do_computation(self): def heavy_foo(x): - for i in range(1 << 20): + for i in range(1 << 18): x = 1 rdd = self.sc.parallelize(range(100)) rdd.foreach(heavy_foo) -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - self.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(self.testData) - self.df = self.sqlCtx.inferSchema(rdd) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - df.count() - df.collect() - df.schema() - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist() - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.applySchema(rdd, df.schema()) - self.assertEqual(10, df3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], df.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) - df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(df.schema(), df2.schema()) - self.assertEqual({}, df2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) - df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - k, v = df.head().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - schema = df.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.applySchema(rdd, schema) - point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) - point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_column_operators(self): - from pyspark.sql import Column, LongType - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbit)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import Aggregator as Agg - self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) - self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0]) - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - class InputFormatTests(ReusedPySparkTestCase): @classmethod @@ -1067,6 +1086,7 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", @@ -1115,15 +1135,16 @@ def test_sequencefiles(self): en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] self.assertEqual(nulls, en) - maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: u'aa'}), (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] - self.assertEqual(maps, em) + for v in maps: + self.assertTrue(v in em) # arrays get pickled to tuples by default tuples = sorted(self.sc.sequenceFile( @@ -1250,8 +1271,8 @@ def test_converters(self): def test_binary_files(self): path = os.path.join(self.tempdir.name, "binaryfiles") os.mkdir(path) - data = "short binary data" - with open(os.path.join(path, "part-0000"), 'w') as f: + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: f.write(data) [(p, d)] = self.sc.binaryFiles(path).collect() self.assertTrue(p.endswith("part-0000")) @@ -1264,7 +1285,7 @@ def test_binary_records(self): for i in range(100): f.write('%04d' % i) result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(range(100), result) + self.assertEqual(list(range(100)), result) class OutputFormatTests(ReusedPySparkTestCase): @@ -1276,6 +1297,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir.name, ignore_errors=True) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] @@ -1316,8 +1338,9 @@ def test_sequencefiles(self): (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) - self.assertEqual(maps, em) + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1329,12 +1352,13 @@ def test_oldhadoop(self): "org.apache.hadoop.mapred.SequenceFileOutputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable") - result = sorted(self.sc.hadoopFile( + result = self.sc.hadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) - self.assertEqual(result, dict_data) + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) conf = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -1344,12 +1368,13 @@ def test_oldhadoop(self): } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) input_conf = {"mapred.input.dir": basepath + "/olddataset/"} - old_dataset = sorted(self.sc.hadoopRDD( + result = self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect()) - self.assertEqual(old_dataset, dict_data) + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) def test_newhadoop(self): basepath = self.tempdir.name @@ -1384,6 +1409,7 @@ def test_newhadoop(self): conf=input_conf).collect()) self.assertEqual(new_dataset, data) + @unittest.skipIf(sys.version >= "3", "serialize of array") def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -1464,7 +1490,7 @@ def test_reserialization(self): basepath = self.tempdir.name x = range(1, 5) y = range(1001, 1005) - data = zip(x, y) + data = list(zip(x, y)) rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) @@ -1515,7 +1541,7 @@ def connect(self, port): sock = socket(AF_INET, SOCK_STREAM) sock.connect(('127.0.0.1', port)) # send a split index of -1 to shutdown the worker - sock.send("\xFF\xFF\xFF\xFF") + sock.send(b"\xFF\xFF\xFF\xFF") sock.close() return True @@ -1525,7 +1551,8 @@ def do_termination_test(self, terminator): # start daemon daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) # read the port number port = read_int(daemon.stdout) @@ -1555,8 +1582,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class WorkerTests(PySparkTestCase): - +class WorkerTests(ReusedPySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() @@ -1571,7 +1597,10 @@ def sleep(x): # start job in background thread def run(): - self.sc.parallelize(range(1)).foreach(sleep) + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass import threading t = threading.Thread(target=run) t.daemon = True @@ -1580,7 +1609,8 @@ def run(): daemon_pid, worker_pid = 0, 0 while True: if os.path.exists(path): - data = open(path).read().split(' ') + with open(path) as f: + data = f.read().split(' ') daemon_pid, worker_pid = map(int, data) break time.sleep(0.1) @@ -1604,42 +1634,44 @@ def run(): self.fail("daemon had been killed") # run a normal job - rdd = self.sc.parallelize(range(100), 1) + rdd = self.sc.parallelize(xrange(100), 1) self.assertEqual(100, rdd.map(str).count()) def test_after_exception(self): def raise_exception(_): raise Exception() - rdd = self.sc.parallelize(range(100), 1) - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + rdd = self.sc.parallelize(xrange(100), 1) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) self.assertEqual(100, rdd.map(str).count()) def test_after_jvm_exception(self): tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name, 1) filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) - rdd = self.sc.parallelize(range(100), 1) + rdd = self.sc.parallelize(xrange(100), 1) self.assertEqual(100, rdd.map(str).count()) def test_accumulator_when_reuse_worker(self): from pyspark.accumulators import INT_ACCUMULATOR_PARAM acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) self.assertEqual(sum(range(100)), acc1.value) acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) self.assertEqual(sum(range(100)), acc2.value) self.assertEqual(sum(range(100)), acc1.value) def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(range(100000), 1) + rdd = self.sc.parallelize(xrange(100000), 1) self.assertEqual(0, rdd.first()) def count(): @@ -1655,6 +1687,17 @@ def count(): self.assertTrue(not t.isAlive()) self.assertEqual(100000, rdd.count()) + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" + try: + with QuietTest(self.sc): + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + self.sc.pythonVer = version + class SparkSubmitTests(unittest.TestCase): @@ -1665,43 +1708,71 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.programDir) - def createTempFile(self, name, content): + def createTempFile(self, name, content, dir=None): """ Create a temp file with the given name and content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) with open(path, "w") as f: f.write(content) return path - def createFileInZip(self, name, content): + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): """ Create a zip archive containing a file with the given content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name + ".zip") + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) zip = zipfile.ZipFile(path, 'w') zip.writestr(name, content) zip.close() return path + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + def test_single_script(self): """Submit and test a single script file""" script = self.createTempFile("test.py", """ |from pyspark import SparkContext | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) def test_script_with_local_functions(self): """Submit and test a single script file calling a global function""" @@ -1712,12 +1783,12 @@ def test_script_with_local_functions(self): | return x * 3 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) def test_module_dependency(self): """Submit and test a script with a dependency on another module""" @@ -1726,7 +1797,7 @@ def test_module_dependency(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1736,7 +1807,7 @@ def test_module_dependency(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_module_dependency_on_cluster(self): """Submit and test a script with a dependency on another module on a cluster""" @@ -1745,18 +1816,51 @@ def test_module_dependency_on_cluster(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): | return x + 1 """) proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", - "local-cluster[1,1,512]", script], + "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", + "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_single_script_on_cluster(self): """Submit and test a single script on a cluster""" @@ -1767,16 +1871,16 @@ def test_single_script_on_cluster(self): | return x * 2 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], + [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) class ContextTests(unittest.TestCase): @@ -1811,6 +1915,46 @@ def test_with_stop(self): sc.stop() self.assertEqual(SparkContext._active_spark_context, None) + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + + def test_startTime(self): + with SparkContext() as sc: + self.assertGreater(sc.startTime, 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -1820,7 +1964,7 @@ class SciPyTests(PySparkTestCase): def test_serialize(self): from scipy.special import gammaln x = range(1, 5) - expected = map(gammaln, x) + expected = list(map(gammaln, x)) observed = self.sc.parallelize(x).map(gammaln).collect() self.assertEqual(expected, observed) @@ -1841,11 +1985,11 @@ def test_statcounter_array(self): if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: - print "NOTE: Skipping NumPy tests as it does not seem to be installed" + print("NOTE: Skipping NumPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") if not _have_numpy: - print "NOTE: NumPy tests were skipped as it does not seem to be installed" + print("NOTE: NumPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8a93c320ec5d..42c2f8b75933 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -18,6 +18,7 @@ """ Worker that receives input from Piped RDD. """ +from __future__ import print_function import os import sys import time @@ -37,9 +38,9 @@ def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) - write_long(1000 * boot, outfile) - write_long(1000 * init, outfile) - write_long(1000 * finish, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) def add_path(path): @@ -56,6 +57,12 @@ def main(infile, outfile): if split_index == -1: # for unit tests exit(-1) + version = utf8_deserializer.loads(infile) + if version != "%d.%d" % sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + ("%d.%d" % sys.version_info[:2], version)) + # initialize global state shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 @@ -72,6 +79,9 @@ def main(infile, outfile): for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) add_path(os.path.join(spark_files_dir, filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) @@ -88,7 +98,7 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, profiler, deserializer, serializer) = command + func, profiler, deserializer, serializer = command init_time = time.time() def process(): @@ -102,14 +112,14 @@ def process(): except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) + write_with_length(traceback.format_exc().encode("utf-8"), outfile) except IOError: # JVM close the socket pass except Exception: # Write the error to stderr if it happened while serializing - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() + print("PySpark worker failed with exception:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) @@ -136,5 +146,5 @@ def process(): java_port = int(sys.stdin.readline()) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("a+", 65536) + sock_file = sock.makefile("rwb", 65536) main(sock_file, sock_file) diff --git a/python/run-tests b/python/run-tests index e91f1a875d35..24949657ed7a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -18,117 +18,7 @@ # -# Figure out where the Spark framework is installed -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" -# CD into the python directory to find things on the right path -cd "$FWDIR/python" - -FAILED=0 -LOG_FILE=unit-tests.log - -rm -f $LOG_FILE - -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL -rm -rf metastore warehouse - -function run_test() { - echo "Running test: $1" | tee -a $LOG_FILE - - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE - - FAILED=$((PIPESTATUS[0]||$FAILED)) - - # Fail and exit on the first test failure. - if [[ $FAILED != 0 ]]; then - cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. - echo -en "\033[31m" # Red - echo "Had test failures; see logs." - echo -en "\033[0m" # No color - exit -1 - fi -} - -function run_core_tests() { - echo "Run core tests ..." - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" -} - -function run_sql_tests() { - echo "Run sql tests ..." - run_test "pyspark/sql.py" -} - -function run_mllib_tests() { - echo "Run mllib tests ..." - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/feature.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/rand.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat/_statistics.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/tests.py" -} - -function run_ml_tests() { - echo "Run ml tests ..." - run_test "pyspark/ml/feature.py" - run_test "pyspark/ml/classification.py" - run_test "pyspark/ml/tests.py" -} - -function run_streaming_tests() { - echo "Run streaming tests ..." - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" -} - -echo "Running PySpark tests. Output is in python/$LOG_FILE." - -export PYSPARK_PYTHON="python" - -# Try to test with Python 2.6, since that's the minimum version that we support: -if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="python2.6" -fi - -echo "Testing with Python version:" -$PYSPARK_PYTHON --version - -run_core_tests -run_sql_tests -run_mllib_tests -run_ml_tests -run_streaming_tests - -# Try to test with PyPy -if [ $(which pypy) ]; then - export PYSPARK_PYTHON="pypy" - echo "Testing with PyPy version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_streaming_tests -fi - -if [[ $FAILED == 0 ]]; then - echo -en "\033[32m" # Green - echo "Tests passed." - echo -en "\033[0m" # No color -fi - -# TODO: in the long-run, it would be nice to use a test runner like `nose`. -# The doctest fixtures are the current barrier to doing this. +exec python -u ./python/run-tests.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py new file mode 100755 index 000000000000..fd56c7ab6e0e --- /dev/null +++ b/python/run-tests.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import logging +from optparse import OptionParser +import os +import re +import subprocess +import sys +import tempfile +from threading import Thread, Lock +import time +if sys.version < '3': + import Queue +else: + import queue as Queue +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + + +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/")) + + +from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) +from sparktestsupport.shellutils import which # noqa +from sparktestsupport.modules import all_modules # noqa + + +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') + + +def print_red(text): + print('\033[31m' + text + '\033[0m') + + +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") +FAILURE_REPORTING_LOCK = Lock() +LOGGER = logging.getLogger() + + +def run_individual_python_test(test_name, pyspark_python): + env = dict(os.environ) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) + LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) + start_time = time.time() + try: + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() + except: + LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(1) + duration = time.time() - start_time + # Exit on the first failure. + if retcode != 0: + try: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output) + per_test_output.seek(0) + for line in per_test_output: + decoded_line = line.decode() + if not re.match('[0-9]+', decoded_line): + print(decoded_line, end='') + per_test_output.close() + except: + LOGGER.exception("Got an exception while trying to print failed test output") + finally: + print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + else: + per_test_output.close() + LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + + +def get_default_python_executables(): + python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] + if "python2.6" not in python_execs: + LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") + python_execs.insert(0, "python") + return python_execs + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "--python-executables", type="string", default=','.join(get_default_python_executables()), + help="A comma-separated list of Python executables to test against (default: %default)" + ) + parser.add_option( + "--modules", type="string", + default=",".join(sorted(python_modules.keys())), + help="A comma-separated list of Python modules to test (default: %default)" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + parser.add_option( + "--verbose", action="store_true", + help="Enable additional debug logging" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + +def main(): + opts = parse_opts() + if (opts.verbose): + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") + LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) + if os.path.exists(LOG_FILE): + os.remove(LOG_FILE) + python_execs = opts.python_executables.split(',') + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module %s" % module_name) + sys.exit(-1) + LOGGER.info("Will test against the following Python executables: %s", python_execs) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + + task_queue = Queue.Queue() + for python_exec in python_execs: + python_implementation = subprocess_check_output( + [python_exec, "-c", "import platform; print(platform.python_implementation())"], + universal_newlines=True).strip() + LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) + LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( + [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + for test_goal in module.python_test_goals: + task_queue.put((python_exec, test_goal)) + + def process_queue(task_queue): + while True: + try: + (python_exec, test_goal) = task_queue.get_nowait() + except Queue.Empty: + break + try: + run_individual_python_test(test_goal, python_exec) + finally: + task_queue.task_done() + + start_time = time.time() + for _ in range(opts.parallelism): + worker = Thread(target=process_queue, args=(task_queue,)) + worker.daemon = True + worker.start() + try: + task_queue.join() + except (KeyboardInterrupt, SystemExit): + print_red("Exiting due to interrupt") + sys.exit(-1) + total_duration = time.time() - start_time + LOGGER.info("Tests passed in %i seconds", total_duration) + + +if __name__ == "__main__": + main() diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc new file mode 100644 index 000000000000..3b7b044936a8 Binary files /dev/null and b/python/test_support/sql/orc_partitioned/._SUCCESS.crc differ diff --git a/python/test_support/sql/orc_partitioned/_SUCCESS b/python/test_support/sql/orc_partitioned/_SUCCESS new file mode 100755 index 000000000000..e69de29bb2d1 diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc new file mode 100644 index 000000000000..834cf0b7f227 Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc differ diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc new file mode 100755 index 000000000000..494380187335 Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc differ diff --git a/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc new file mode 100644 index 000000000000..693dceeee3ef Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc differ diff --git a/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc new file mode 100755 index 000000000000..4cbb95ae0242 Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc differ diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata new file mode 100644 index 000000000000..7ef2320651de Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata new file mode 100644 index 000000000000..78a1ca7d3827 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc new file mode 100644 index 000000000000..e93f42ed6f35 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet new file mode 100644 index 000000000000..461c382937ec Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc new file mode 100644 index 000000000000..b63c4d6d1e1d Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc new file mode 100644 index 000000000000..5bc0ebd71356 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet new file mode 100644 index 000000000000..62a63915beac Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet new file mode 100644 index 000000000000..67665a7b55da Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc new file mode 100644 index 000000000000..ae94a15d08c8 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet new file mode 100644 index 000000000000..6cb8538aa890 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc new file mode 100644 index 000000000000..58d9bb5fc588 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet new file mode 100644 index 000000000000..9b00805481e7 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json new file mode 100644 index 000000000000..50a859cbd7ee --- /dev/null +++ b/python/test_support/sql/people.json @@ -0,0 +1,3 @@ +{"name":"Michael"} +{"name":"Andy", "age":30} +{"name":"Justin", "age":19} diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg deleted file mode 100644 index 1674c9cb2227..000000000000 Binary files a/python/test_support/userlib-0.1-py2.7.egg and /dev/null differ diff --git a/python/test_support/userlib-0.1.zip b/python/test_support/userlib-0.1.zip new file mode 100644 index 000000000000..496e1349aa96 Binary files /dev/null and b/python/test_support/userlib-0.1.zip differ diff --git a/repl/pom.xml b/repl/pom.xml index bd39b90fd871..a5a0f1fc2c85 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -33,22 +33,22 @@ repl - /usr/share/spark - root scala-2.10/src/main/scala scala-2.10/src/test/scala - ${jline.groupid} - jline - ${jline.version} + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} org.apache.spark spark-core_${scala.binary.version} ${project.version} + test-jar + test org.apache.spark @@ -66,7 +66,6 @@ org.apache.spark spark-sql_${scala.binary.version} ${project.version} - test org.scala-lang @@ -87,6 +86,11 @@ scalacheck_${scala.binary.version} test + + org.mockito + mockito-core + test + @@ -129,7 +133,6 @@ - src/main/scala ${extra.source.dir} @@ -142,7 +145,6 @@ - src/test/scala ${extra.testsource.dir} @@ -152,6 +154,20 @@ + + scala-2.10 + + !scala-2.11 + + + + ${jline.groupid} + jline + ${jline.version} + + + + scala-2.11 diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 6480e2d24e04..24fbbc12c08d 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala @@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings) } def this(args: List[String]) { + // scalastyle:off println this(args, str => Console.println("Error: " + str)) + // scalastyle:on println } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 72c1a989999b..304b1e8cdbed 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -45,6 +45,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging import org.apache.spark.SparkConf import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** The Scala interactive shell. It provides a read-eval-print loop @@ -130,6 +131,7 @@ class SparkILoop( // NOTE: Must be public for visibility @DeveloperApi var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ override def echoCommandMessage(msg: String) { intp.reporter printMessage msg @@ -204,7 +206,8 @@ class SparkILoop( // e.g. file:/C:/my/path.jar -> C:/my/path.jar SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") } } else { - SparkILoop.getAddedJars + // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). + SparkILoop.getAddedJars.map { jar => new URI(jar).getPath } } // work around for Scala bug val totalClassPath = addedJars.foldLeft( @@ -978,7 +981,7 @@ class SparkILoop( // which spins off a separate thread, then print the prompt and try // our best to look ready. The interlocking lazy vals tend to // inter-deadlock, so we break the cycle with a single asynchronous - // message to an actor. + // message to an rpcEndpoint. if (isAsync) { intp initialize initializedCallback() createAsyncListener() // listens for signal to run postInitialization @@ -1005,9 +1008,9 @@ class SparkILoop( val jars = SparkILoop.getAddedJars val conf = new SparkConf() .setMaster(getMaster()) - .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", intp.classServerUri) + .setIfMissing("spark.app.name", "Spark shell") if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1016,6 +1019,23 @@ class SparkILoop( sparkContext } + @DeveloperApi + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } + catch { + case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster(): String = { val master = this.master match { case Some(m) => m @@ -1045,15 +1065,16 @@ class SparkILoop( private def main(settings: Settings): Unit = process(settings) } -object SparkILoop { +object SparkILoop extends Logging { implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp private def echo(msg: String) = Console println msg def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") - val propJars = sys.props.get("spark.jars").flatMap { p => - if (p == "") None else Some(p) + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") } + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) } @@ -1080,7 +1101,9 @@ object SparkILoop { val s = super.readLine() // helping out by printing the line being interpreted. if (s != null) + // scalastyle:off println output.println(s) + // scalastyle:on println s } } @@ -1089,7 +1112,7 @@ object SparkILoop { if (settings.classpath.isDefault) settings.classpath.value = sys.props("java.class.path") - getAddedJars.foreach(settings.classpath.append(_)) + getAddedJars.map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) repl process settings } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 99bd777c04fd..bd3314d94eed 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit { if (!initIsComplete) withLock { while (!initIsComplete) initLoopCondition.await() } if (initError != null) { + // scalastyle:off println println(""" |Failed to initialize the REPL due to an unexpected error. |This is a bug, please, report it along with the error diagnostics printed below. |%s.""".stripMargin.format(initError) ) + // scalastyle:on println false } else true } @@ -127,7 +129,17 @@ private[repl] trait SparkILoopInit { _sc } """) + command(""" + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) command("import org.apache.spark.SparkContext._") + command("import sqlContext.implicits._") + command("import sqlContext.sql") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 35fb62564502..4ee605fd7f11 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1079,8 +1079,10 @@ import org.apache.spark.annotation.DeveloperApi throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex) private def load(path: String): Class[_] = { + // scalastyle:off classforname try Class.forName(path, true, classLoader) catch { case ex: Throwable => evalError(path, unwrap(ex)) } + // scalastyle:on classforname } lazy val evalClass = load(evalPath) @@ -1761,7 +1763,9 @@ object SparkIMain { if (intp.totalSilence) () else super.printMessage(msg) } + // scalastyle:off println else Console.println(msg) + // scalastyle:on println } } } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e594ad868ea1..5674dcd669be 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,13 +22,12 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" @@ -121,9 +120,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |var v = 7 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -137,7 +136,7 @@ class ReplSuite extends FunSuite { |class C { |def foo = 5 |} - |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -148,7 +147,7 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |def double(x: Int) = x + x - |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -160,9 +159,9 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -178,9 +177,9 @@ class ReplSuite extends FunSuite { """ |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -212,18 +211,18 @@ class ReplSuite extends FunSuite { } test("local-cluster mode") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -234,7 +233,7 @@ class ReplSuite extends FunSuite { } test("SPARK-1199 two instances of same class don't type check.") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |case class Sum(exp: String, exp2: String) |val a = Sum("A", "B") @@ -255,19 +254,30 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createDataFrame.") { + test("SPARK-2576 importing SQLContext.implicits._") { // We need to use local-cluster to test this case. - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createDataFrame + |import sqlContext.implicits._ |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect() + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + test("SPARK-8461 SQL with codegen") { + val output = runInterpreter("local", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |sqlContext.setConf("spark.sql.codegen", "true") + |sqlContext.range(0, 100).filter('id > 50).count() + """.stripMargin) + assertContains("Long = 49", output) + assertDoesNotContain("java.lang.ClassNotFoundException", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ @@ -275,26 +285,26 @@ class ReplSuite extends FunSuite { |val t = new TestClass |import t.testMethod |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } - if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -309,10 +319,22 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local[2]", """ |case class Foo(i: Int) - |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) assertContains("ret: Array[Foo] = Array(Foo(1),", output) } + + test("collecting objects of class defined in repl - shuffling") { + val output = runInterpreter("local-cluster[1,1,1024]", + """ + |case class Foo(i: Int) + |val list = List((1, Foo(1)), (1, Foo(2))) + |val ret = sc.parallelize(list).groupByKey().collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) + } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 69e44d4f916e..be31eb2eda54 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -17,11 +17,13 @@ package org.apache.spark.repl -import org.apache.spark.util.Utils -import org.apache.spark._ +import java.io.File import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.SparkILoop + +import org.apache.spark.util.Utils +import org.apache.spark._ +import org.apache.spark.sql.SQLContext object Main extends Logging { @@ -31,9 +33,11 @@ object Main extends Logging { val outputDir = Utils.createTempDir(rootDir) val s = new Settings() s.processArguments(List("-Yrepl-class-based", - "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator)), true) val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. def main(args: Array[String]) { @@ -46,9 +50,11 @@ object Main extends Logging { Option(sparkContext).map(_.stop) } - def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") + } val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) @@ -59,9 +65,9 @@ object Main extends Logging { val jars = getAddedJars val conf = new SparkConf() .setMaster(getMaster) - .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", classServer.uri) + .setIfMissing("spark.app.name", "Spark shell") logInfo("Spark class server started at " + classServer.uri) if (execUri != null) { conf.set("spark.executor.uri", execUri) @@ -74,6 +80,21 @@ object Main extends Logging { sparkContext } + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } catch { + case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster: String = { val master = { val envMaster = sys.env.get("MASTER") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala deleted file mode 100644 index 8e519fa67f64..000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package scala.tools.nsc -package interpreter - -import scala.tools.nsc.ast.parser.Tokens.EOF - -trait SparkExprTyper { - val repl: SparkIMain - - import repl._ - import global.{ reporter => _, Import => _, _ } - import naming.freshInternalVarName - - def symbolOfLine(code: String): Symbol = { - def asExpr(): Symbol = { - val name = freshInternalVarName() - // Typing it with a lazy val would give us the right type, but runs - // into compiler bugs with things like existentials, so we compile it - // behind a def and strip the NullaryMethodType which wraps the expr. - val line = "def " + name + " = " + code - - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - // drop NullaryMethodType - sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) - case _ => NoSymbol - } - } - def asDefn(): Symbol = { - val old = repl.definedSymbolList.toSet - - interpretSynthetic(code) match { - case IR.Success => - repl.definedSymbolList filterNot old match { - case Nil => NoSymbol - case sym :: Nil => sym - case syms => NoSymbol.newOverloaded(NoPrefix, syms) - } - case _ => NoSymbol - } - } - def asError(): Symbol = { - interpretSynthetic(code) - NoSymbol - } - beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() - } - - private var typeOfExpressionDepth = 0 - def typeOfExpression(expr: String, silent: Boolean = true): Type = { - if (typeOfExpressionDepth > 2) { - repldbg("Terminating typeOfExpression recursion for expression: " + expr) - return NoType - } - typeOfExpressionDepth += 1 - // Don't presently have a good way to suppress undesirable success output - // while letting errors through, so it is first trying it silently: if there - // is an error, and errors are desired, then it re-evaluates non-silently - // to induce the error message. - try beSilentDuring(symbolOfLine(expr).tpe) match { - case NoType if !silent => symbolOfLine(expr).tpe // generate error - case tpe => tpe - } - finally typeOfExpressionDepth -= 1 - } - - // This only works for proper types. - def typeOfTypeString(typeString: String): Type = { - def asProperType(): Option[Type] = { - val name = freshInternalVarName() - val line = "def %s: %s = ???" format (name, typeString) - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - Some(sym0.asMethod.returnType) - case _ => None - } - } - beSilentDuring(asProperType()) getOrElse NoType - } -} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 250727305970..bf609ff0f65f 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1,78 +1,64 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Alexander Spoon +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -package scala -package tools.nsc -package interpreter +package org.apache.spark.repl -import scala.language.{ implicitConversions, existentials } -import scala.annotation.tailrec -import Predef.{ println => _, _ } -import interpreter.session._ -import StdReplTags._ -import scala.reflect.api.{Mirror, Universe, TypeCreator} -import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } -import scala.reflect.{ClassTag, classTag} -import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } -import ScalaClassLoader._ -import scala.reflect.io.{ File, Directory } -import scala.tools.util._ -import scala.collection.generic.Clearable -import scala.concurrent.{ ExecutionContext, Await, Future, future } -import ExecutionContext.Implicits._ -import java.io.{ BufferedReader, FileReader } +import java.io.{BufferedReader, FileReader} -/** The Scala interactive shell. It provides a read-eval-print loop - * around the Interpreter class. - * After instantiation, clients should call the main() method. - * - * If no in0 is specified, then input will come from the console, and - * the class will attempt to provide input editing feature such as - * input history. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - * @version 1.2 - */ -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) - extends AnyRef - with LoopCommands -{ - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) -// -// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp -// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i - - var in: InteractiveReader = _ // the input stream from which commands come - var settings: Settings = _ - var intp: SparkIMain = _ +import Predef.{println => _, _} +import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName} - var globalFuture: Future[Boolean] = _ +import scala.tools.nsc.interpreter.{JPrintWriter, ILoop} +import scala.tools.nsc.Settings +import scala.tools.nsc.util.stringFromStream - protected def asyncMessage(msg: String) { - if (isReplInfo || isReplPower) - echoAndRefresh(msg) - } +/** + * A Spark-specific interactive shell. + */ +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) + extends ILoop(in0, out) { + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) def initializeSpark() { intp.beQuietDuring { - command( """ + processLine(""" @transient val sc = { val _sc = org.apache.spark.repl.Main.createSparkContext() println("Spark context available as sc.") _sc } - """) - command("import org.apache.spark.SparkContext._") + """) + processLine(""" + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) + processLine("import org.apache.spark.SparkContext._") + processLine("import sqlContext.implicits._") + processLine("import sqlContext.sql") + processLine("import org.apache.spark.sql.functions._") } } /** Print a welcome message */ - def printWelcome() { + override def printWelcome() { import org.apache.spark.SPARK_VERSION echo("""Welcome to ____ __ @@ -88,875 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) echo("Type :help for more information.") } - override def echoCommandMessage(msg: String) { - intp.reporter printUntruncatedMessage msg - } - - // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - def history = in.history - - // classpath entries added via :cp - var addedClasspath: String = "" - - /** A reverse list of commands to replay if the user requests a :replay */ - var replayCommandStack: List[String] = Nil - - /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandStack.reverse - - /** Record a command for replay should the user request a :replay */ - def addReplay(cmd: String) = replayCommandStack ::= cmd - - def savingReplayStack[T](body: => T): T = { - val saved = replayCommandStack - try body - finally replayCommandStack = saved - } - def savingReader[T](body: => T): T = { - val saved = in - try body - finally in = saved - } - - /** Close the interpreter and set the var to null. */ - def closeInterpreter() { - if (intp ne null) { - intp.close() - intp = null - } - } - - class SparkILoopInterpreter extends SparkIMain(settings, out) { - outer => - - override lazy val formatting = new Formatting { - def prompt = SparkILoop.this.prompt - } - override protected def parentClassLoader = - settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) - } - - /** Create a new interpreter. */ - def createInterpreter() { - if (addedClasspath != "") - settings.classpath append addedClasspath - - intp = new SparkILoopInterpreter - } - - /** print a friendly help message */ - def helpCommand(line: String): Result = { - if (line == "") helpSummary() - else uniqueCommand(line) match { - case Some(lc) => echo("\n" + lc.help) - case _ => ambiguousError(line) - } - } - private def helpSummary() = { - val usageWidth = commands map (_.usageMsg.length) max - val formatStr = "%-" + usageWidth + "s %s" - - echo("All commands can be abbreviated, e.g. :he instead of :help.") - - commands foreach { cmd => - echo(formatStr.format(cmd.usageMsg, cmd.help)) - } - } - private def ambiguousError(cmd: String): Result = { - matchingCommands(cmd) match { - case Nil => echo(cmd + ": no such command. Type :help for help.") - case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") - } - Result(keepRunning = true, None) - } - private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) - private def uniqueCommand(cmd: String): Option[LoopCommand] = { - // this lets us add commands willy-nilly and only requires enough command to disambiguate - matchingCommands(cmd) match { - case List(x) => Some(x) - // exact match OK even if otherwise appears ambiguous - case xs => xs find (_.name == cmd) - } - } - - /** Show the history */ - lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { - override def usage = "[num]" - def defaultLines = 20 - - def apply(line: String): Result = { - if (history eq NoHistory) - return "No history available." - - val xs = words(line) - val current = history.index - val count = try xs.head.toInt catch { case _: Exception => defaultLines } - val lines = history.asStrings takeRight count - val offset = current - lines.size + 1 - - for ((line, index) <- lines.zipWithIndex) - echo("%3d %s".format(index + offset, line)) - } - } - - // When you know you are most likely breaking into the middle - // of a line being typed. This softens the blow. - protected def echoAndRefresh(msg: String) = { - echo("\n" + msg) - in.redrawLine() - } - protected def echo(msg: String) = { - out println msg - out.flush() - } - - /** Search the history */ - def searchHistory(_cmdline: String) { - val cmdline = _cmdline.toLowerCase - val offset = history.index - history.size + 1 - - for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) - echo("%d %s".format(index + offset, line)) - } - - private val currentPrompt = Properties.shellPromptString - - /** Prompt to print when awaiting input */ - def prompt = currentPrompt - import LoopCommand.{ cmd, nullary } - /** Standard commands **/ - lazy val standardCommands = List( - cmd("cp", "", "add a jar or directory to the classpath", addClasspath), - cmd("edit", "|", "edit history", editCommand), - cmd("help", "[command]", "print this summary or command-specific help", helpCommand), - historyCommand, - cmd("h?", "", "search the history", searchHistory), - cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), - //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), - cmd("javap", "", "disassemble a file or class name", javapCommand), - cmd("line", "|", "place line(s) at the end of history", lineCommand), - cmd("load", "", "interpret lines in a file", loadCommand), - cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), - // nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), - nullary("replay", "reset execution and replay all previous commands", replay), - nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), - cmd("save", "", "save replayable session to a file", saveCommand), - shCommand, - cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings), - nullary("silent", "disable/enable automatic printing of results", verbosity), -// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), -// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand), - nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) - ) - - /** Power user commands */ -// lazy val powerCommands: List[LoopCommand] = List( -// cmd("phase", "", "set the implicit phase for power commands", phaseCommand) -// ) - - private def importsCommand(line: String): Result = { - val tokens = words(line) - val handlers = intp.languageWildcardHandlers ++ intp.importHandlers - - handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { - case (handler, idx) => - val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) - val imps = handler.implicitSymbols - val found = tokens filter (handler importsSymbolNamed _) - val typeMsg = if (types.isEmpty) "" else types.size + " types" - val termMsg = if (terms.isEmpty) "" else terms.size + " terms" - val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" - val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") - val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - - intp.reporter.printMessage("%2d) %-30s %s%s".format( - idx + 1, - handler.importString, - statsMsg, - foundMsg - )) - } - } - - private def findToolsJar() = PathResolver.SupplementalLocations.platformTools - - private def addToolsJarToLoader() = { - val cl = findToolsJar() match { - case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) - case _ => intp.classLoader - } - if (Javap.isAvailable(cl)) { - repldbg(":javap available.") - cl - } - else { - repldbg(":javap unavailable: no tools.jar at " + jdkHome) - intp.classLoader - } - } -// -// protected def newJavap() = -// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) -// -// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) - - // Still todo: modules. -// private def typeCommand(line0: String): Result = { -// line0.trim match { -// case "" => ":type [-v] " -// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") -// } -// } - -// private def kindCommand(expr: String): Result = { -// expr.trim match { -// case "" => ":kind [-v] " -// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") -// } -// } - - private def warningsCommand(): Result = { - if (intp.lastWarnings.isEmpty) - "Can't find any cached warnings." - else - intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } - } + private val blockedCommands = Set("implicits", "javap", "power", "type", "kind") - private def changeSettings(args: String): Result = { - def showSettings() = { - for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) - } - def updateSettings() = { - // put aside +flag options - val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) - val tmps = new Settings - val (ok, leftover) = tmps.processArguments(rest, processAll = true) - if (!ok) echo("Bad settings request.") - else if (leftover.nonEmpty) echo("Unprocessed settings.") - else { - // boolean flags set-by-user on tmp copy should be off, not on - val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) - val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) - // update non-flags - settings.processArguments(nonbools, processAll = true) - // also snag multi-value options for clearing, e.g. -Ylog: and -language: - for { - s <- settings.userSetSettings - if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] - if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) - } s match { - case c: Clearable => c.clear() - case _ => - } - def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { - for (b <- bs) - settings.lookupSetting(name(b)) match { - case Some(s) => - if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) - else echo(s"Not a boolean flag: $b") - case _ => - echo(s"Not an option: $b") - } - } - update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off - update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on - } - } - if (args.isEmpty) showSettings() else updateSettings() - } - - private def javapCommand(line: String): Result = { -// if (javap == null) -// ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) -// else if (line == "") -// ":javap [-lcsvp] [path1 path2 ...]" -// else -// javap(words(line)) foreach { res => -// if (res.isError) return "Failed: " + res.value -// else res.show() -// } - } - - private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" - - private def phaseCommand(name: String): Result = { -// val phased: Phased = power.phased -// import phased.NoPhaseName -// -// if (name == "clear") { -// phased.set(NoPhaseName) -// intp.clearExecutionWrapper() -// "Cleared active phase." -// } -// else if (name == "") phased.get match { -// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" -// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) -// } -// else { -// val what = phased.parse(name) -// if (what.isEmpty || !phased.set(what)) -// "'" + name + "' does not appear to represent a valid phase." -// else { -// intp.setExecutionWrapper(pathToPhaseWrapper) -// val activeMessage = -// if (what.toString.length == name.length) "" + what -// else "%s (%s)".format(what, name) -// -// "Active phase is now: " + activeMessage -// } -// } - } + /** Standard commands **/ + lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] = + standardCommands.filter(cmd => !blockedCommands(cmd.name)) /** Available commands */ - def commands: List[LoopCommand] = standardCommands ++ ( - // if (isReplPower) - // powerCommands - // else - Nil - ) - - val replayQuestionMessage = - """|That entry seems to have slain the compiler. Shall I replay - |your session? I can re-run each line except the last one. - |[y/n] - """.trim.stripMargin - - private val crashRecovery: PartialFunction[Throwable, Boolean] = { - case ex: Throwable => - val (err, explain) = ( - if (intp.isInitializeComplete) - (intp.global.throwableAsString(ex), "") - else - (ex.getMessage, "The compiler did not initialize.\n") - ) - echo(err) - - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true - } - - // return false if repl should exit - def processLine(line: String): Boolean = { - import scala.concurrent.duration._ - Await.ready(globalFuture, 60.seconds) - - (line ne null) && (command(line) match { - case Result(false, _) => false - case Result(_, Some(line)) => addReplay(line) ; true - case _ => true - }) - } - - private def readOneLine() = { - out.flush() - in readLine prompt - } - - /** The main read-eval-print loop for the repl. It calls - * command() for each line of input, and stops when - * command() returns false. - */ - @tailrec final def loop() { - if ( try processLine(readOneLine()) catch crashRecovery ) - loop() - } - - /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { - savingReader { - savingReplayStack { - file applyReader { reader => - in = SimpleReader(reader, out, interactive = false) - echo("Loading " + file + "...") - loop() - } - } - } - } - - /** create a new interpreter and replay the given commands */ - def replay() { - reset() - if (replayCommandStack.isEmpty) - echo("Nothing to replay.") - else for (cmd <- replayCommands) { - echo("Replaying: " + cmd) // flush because maybe cmd will have its own output - command(cmd) - echo("") - } - } - def resetCommand() { - echo("Resetting interpreter state.") - if (replayCommandStack.nonEmpty) { - echo("Forgetting this session history:\n") - replayCommands foreach echo - echo("") - replayCommandStack = Nil - } - if (intp.namedDefinedTerms.nonEmpty) - echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) - if (intp.definedTypes.nonEmpty) - echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) - - reset() - } - def reset() { - intp.reset() - unleashAndSetPhase() - } - - def lineCommand(what: String): Result = editCommand(what, None) - - // :edit id or :edit line - def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) - - def editCommand(what: String, editor: Option[String]): Result = { - def diagnose(code: String) = { - echo("The edited code is incomplete!\n") - val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") - if (errless) echo("The compiler reports no errors.") - } - def historicize(text: String) = history match { - case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true - case _ => false - } - def edit(text: String): Result = editor match { - case Some(ed) => - val tmp = File.makeTemp() - tmp.writeAll(text) - try { - val pr = new ProcessResult(s"$ed ${tmp.path}") - pr.exitCode match { - case 0 => - tmp.safeSlurp() match { - case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") - case Some(edited) => - echo(edited.lines map ("+" + _) mkString "\n") - val res = intp interpret edited - if (res == IR.Incomplete) diagnose(edited) - else { - historicize(edited) - Result(lineToRecord = Some(edited), keepRunning = true) - } - case None => echo("Can't read edited text. Did you delete it?") - } - case x => echo(s"Error exit from $ed ($x), ignoring") - } - } finally { - tmp.delete() - } - case None => - if (historicize(text)) echo("Placing text in recent history.") - else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") - } - - // if what is a number, use it as a line number or range in history - def isNum = what forall (c => c.isDigit || c == '-' || c == '+') - // except that "-" means last value - def isLast = (what == "-") - if (isLast || !isNum) { - val name = if (isLast) intp.mostRecentVar else what - val sym = intp.symbolOfIdent(name) - intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { - case Some(req) => edit(req.line) - case None => echo(s"No symbol in scope: $what") - } - } else try { - val s = what - // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) - val (start, len) = - if ((s indexOf '+') > 0) { - val (a,b) = s splitAt (s indexOf '+') - (a.toInt, b.drop(1).toInt) - } else { - (s indexOf '-') match { - case -1 => (s.toInt, 1) - case 0 => val n = s.drop(1).toInt ; (history.index - n, n) - case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) - case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) - } - } - import scala.collection.JavaConverters._ - val index = (start - 1) max 0 - val text = history match { - case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" - case _ => history.asStrings.slice(index, index + len) mkString "\n" - } - edit(text) - } catch { - case _: NumberFormatException => echo(s"Bad range '$what'") - echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") - } - } - - /** fork a shell and run a command */ - lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { - override def usage = "" - def apply(line: String): Result = line match { - case "" => showUsage() - case _ => - val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" - intp interpret toRun - () - } - } - - def withFile[A](filename: String)(action: File => A): Option[A] = { - val res = Some(File(filename)) filter (_.exists) map action - if (res.isEmpty) echo("That file does not exist") // courtesy side-effect - res - } - - def loadCommand(arg: String) = { - var shouldReplay: Option[String] = None - withFile(arg)(f => { - interpretAllFrom(f) - shouldReplay = Some(":load " + arg) - }) - Result(keepRunning = true, shouldReplay) - } - - def saveCommand(filename: String): Result = ( - if (filename.isEmpty) echo("File name is required.") - else if (replayCommandStack.isEmpty) echo("No replay commands in session") - else File(filename).printlnAll(replayCommands: _*) - ) - - def addClasspath(arg: String): Unit = { - val f = File(arg).normalize - if (f.exists) { - addedClasspath = ClassPath.join(addedClasspath, f.path) - val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) - replay() - } - else echo("The path '" + f + "' doesn't seem to exist.") - } - - def powerCmd(): Result = { - if (isReplPower) "Already in power mode." - else enablePowerMode(isDuringInit = false) - } - def enablePowerMode(isDuringInit: Boolean) = { - replProps.power setValue true - unleashAndSetPhase() - // asyncEcho(isDuringInit, power.banner) - } - private def unleashAndSetPhase() { - if (isReplPower) { - // power.unleash() - // Set the phase to "typer" - // intp beSilentDuring phaseCommand("typer") - } - } - - def asyncEcho(async: Boolean, msg: => String) { - if (async) asyncMessage(msg) - else echo(msg) - } - - def verbosity() = { - val old = intp.printResults - intp.printResults = !old - echo("Switched " + (if (old) "off" else "on") + " result printing.") - } - - /** Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. */ - def command(line: String): Result = { - if (line startsWith ":") { - val cmd = line.tail takeWhile (x => !x.isWhitespace) - uniqueCommand(cmd) match { - case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) - case _ => ambiguousError(cmd) - } - } - else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler - else Result(keepRunning = true, interpretStartingWith(line)) - } - - private def readWhile(cond: String => Boolean) = { - Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) - } - - def pasteCommand(arg: String): Result = { - var shouldReplay: Option[String] = None - def result = Result(keepRunning = true, shouldReplay) - val (raw, file) = - if (arg.isEmpty) (false, None) - else { - val r = """(-raw)?(\s+)?([^\-]\S*)?""".r - arg match { - case r(flag, sep, name) => - if (flag != null && name != null && sep == null) - echo(s"""I assume you mean "$flag $name"?""") - (flag != null, Option(name)) - case _ => - echo("usage: :paste -raw file") - return result - } - } - val code = file match { - case Some(name) => - withFile(name)(f => { - shouldReplay = Some(s":paste $arg") - val s = f.slurp.trim - if (s.isEmpty) echo(s"File contains no code: $f") - else echo(s"Pasting file $f...") - s - }) getOrElse "" - case None => - echo("// Entering paste mode (ctrl-D to finish)\n") - val text = (readWhile(_ => true) mkString "\n").trim - if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") - else echo("\n// Exiting paste mode, now interpreting.\n") - text - } - def interpretCode() = { - val res = intp interpret code - // if input is incomplete, let the compiler try to say why - if (res == IR.Incomplete) { - echo("The pasted code is incomplete!\n") - // Remembrance of Things Pasted in an object - val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") - if (errless) echo("...but compilation found no error? Good luck with that.") - } - } - def compileCode() = { - val errless = intp compileSources new BatchSourceFile("", code) - if (!errless) echo("There were compilation errors!") - } - if (code.nonEmpty) { - if (raw) compileCode() else interpretCode() - } - result - } - - private object paste extends Pasted { - val ContinueString = " | " - val PromptString = "scala> " - - def interpret(line: String): Unit = { - echo(line.trim) - intp interpret line - echo("") - } - - def transcript(start: String) = { - echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") - apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) - } - } - import paste.{ ContinueString, PromptString } - - /** Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ - def interpretStartingWith(code: String): Option[String] = { - // signal completion non-completion input has been received - in.completion.resetVerbosity() - - def reallyInterpret = { - val reallyResult = intp.interpret(code) - (reallyResult, reallyResult match { - case IR.Error => None - case IR.Success => Some(code) - case IR.Incomplete => - if (in.interactive && code.endsWith("\n\n")) { - echo("You typed two blank lines. Starting a new command.") - None - } - else in.readLine(ContinueString) match { - case null => - // we know compilation is going to fail since we're at EOF and the - // parser thinks the input is still incomplete, but since this is - // a file being read non-interactively we want to fail. So we send - // it straight to the compiler for the nice error message. - intp.compileString(code) - None - - case line => interpretStartingWith(code + "\n" + line) - } - }) - } - - /** Here we place ourselves between the user and the interpreter and examine - * the input they are ostensibly submitting. We intervene in several cases: - * - * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. - * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation - * on the previous result. - * 3) If the Completion object's execute returns Some(_), we inject that value - * and avoid the interpreter, as it's likely not valid scala code. - */ - if (code == "") None - else if (!paste.running && code.trim.startsWith(PromptString)) { - paste.transcript(code) - None - } - else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { - interpretStartingWith(intp.mostRecentVar + code) - } - else if (code.trim startsWith "//") { - // line comment, do nothing - None - } - else - reallyInterpret._2 - } - - // runs :load `file` on any files passed via -i - def loadFiles(settings: Settings) = settings match { - case settings: GenericRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - addReplay(cmd) - echo("") - } - case _ => - } - - /** Tries to create a JLineReader, falling back to SimpleReader: - * unless settings or properties are such that it should start - * with SimpleReader. - */ - def chooseReader(settings: Settings): InteractiveReader = { - if (settings.Xnojline || Properties.isEmacsShell) - SimpleReader() - else try new JLineReader( - if (settings.noCompletion) NoCompletion - else new SparkJLineCompletion(intp) - ) - catch { - case ex @ (_: Exception | _: NoClassDefFoundError) => - echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") - SimpleReader() - } - } - protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = - u.TypeTag[T]( - m, - new TypeCreator { - def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = - m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] - }) - - private def loopPostInit() { - // Bind intp somewhere out of the regular namespace where - // we can get at it in generated code. - intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) - // Auto-run code via some setting. - ( replProps.replAutorunCode.option - flatMap (f => io.File(f).safeSlurp()) - foreach (intp quietRun _) - ) - // classloader and power mode setup - intp.setContextClassLoader() - if (isReplPower) { - // replProps.power setValue true - // unleashAndSetPhase() - // asyncMessage(power.banner) - } - // SI-7418 Now, and only now, can we enable TAB completion. - in match { - case x: JLineReader => x.consoleReader.postInit - case _ => - } - } - def process(settings: Settings): Boolean = savingContextLoader { - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) - globalFuture = future { - intp.initializeSynchronous() - loopPostInit() - !intp.reporter.hasErrors - } - import scala.concurrent.duration._ - Await.ready(globalFuture, 10 seconds) - printWelcome() + override def commands: List[LoopCommand] = sparkStandardCommands + + /** + * We override `loadFiles` because we need to initialize Spark *before* the REPL + * sees any files, so that the Spark context is visible in those files. This is a bit of a + * hack, but there isn't another hook available to us at this point. + */ + override def loadFiles(settings: Settings): Unit = { initializeSpark() - loadFiles(settings) - - try loop() - catch AbstractOrMissingHandler() - finally closeInterpreter() - - true + super.loadFiles(settings) } - - @deprecated("Use `process` instead", "2.9.0") - def main(settings: Settings): Unit = process(settings) //used by sbt } object SparkILoop { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - // Designed primarily for use by test code: take a String with a - // bunch of code, and prints out a transcript of what it would look - // like if you'd just typed it into the repl. - def runForTranscript(code: String, settings: Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { - override def write(str: String) = { - // completely skip continuation lines - if (str forall (ch => ch.isWhitespace || ch == '|')) () - else super.write(str) - } - } - val input = new BufferedReader(new StringReader(code.trim + "\n")) { - override def readLine(): String = { - val s = super.readLine() - // helping out by printing the line being interpreted. - if (s != null) - output.println(s) - s - } - } - val repl = new SparkILoop(input, output) - if (settings.classpath.isDefault) - settings.classpath.value = sys.props("java.class.path") - - repl process settings - } - } - } - - /** Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ + /** + * Creates an interpreter loop with default settings and feeds + * the given code to it as input. + */ def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) if (sets.classpath.isDefault) sets.classpath.value = sys.props("java.class.path") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala deleted file mode 100644 index 1bb62c84abdd..000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ /dev/null @@ -1,1319 +0,0 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Martin Odersky - */ - -package scala -package tools.nsc -package interpreter - -import PartialFunction.cond -import scala.language.implicitConversions -import scala.beans.BeanProperty -import scala.collection.mutable -import scala.concurrent.{ Future, ExecutionContext } -import scala.reflect.runtime.{ universe => ru } -import scala.reflect.{ ClassTag, classTag } -import scala.reflect.internal.util.{ BatchSourceFile, SourceFile } -import scala.tools.util.PathResolver -import scala.tools.nsc.io.AbstractFile -import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings } -import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps } -import scala.tools.nsc.util.Exceptional.unwrap -import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable} - -/** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accomodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - */ -class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings, - protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports { - imain => - - setBindings(createBindings, ScriptContext.ENGINE_SCOPE) - object replOutput extends ReplOutput(settings.Yreploutdir) { } - - @deprecated("Use replOutput.dir instead", "2.11.0") - def virtualDirectory = replOutput.dir - // Used in a test case. - def showDirectory() = replOutput.show(out) - - private[nsc] var printResults = true // whether to print result lines - private[nsc] var totalSilence = false // whether to print anything - private var _initializeComplete = false // compiler is initialized - private var _isInitialized: Future[Boolean] = null // set up initialization future - private var bindExceptions = true // whether to bind the lastException variable - private var _executionWrapper = "" // code to be wrapped around all lines - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. - */ - private var _classLoader: util.AbstractFileClassLoader = null // active classloader - private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler - - def compilerClasspath: Seq[java.net.URL] = ( - if (isInitializeComplete) global.classPath.asURLs - else new PathResolver(settings).result.asURLs // the compiler's classpath - ) - def settings = initialSettings - // Run the code body with the given boolean settings flipped to true. - def withoutWarnings[T](body: => T): T = beQuietDuring { - val saved = settings.nowarn.value - if (!saved) - settings.nowarn.value = true - - try body - finally if (!saved) settings.nowarn.value = false - } - - /** construct an interpreter that reports to Console */ - def this(settings: Settings, out: JPrintWriter) = this(null, settings, out) - def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this(factory: ScriptEngineFactory) = this(factory, new Settings()) - def this() = this(new Settings()) - - lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - lazy val reporter: SparkReplReporter = new SparkReplReporter(this) - - import formatting._ - import reporter.{ printMessage, printUntruncatedMessage } - - // This exists mostly because using the reporter too early leads to deadlock. - private def echo(msg: String) { Console println msg } - private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) - private def _initialize() = { - try { - // if this crashes, REPL will hang its head in shame - val run = new _compiler.Run() - assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") - run compileSources _initSources - _initializeComplete = true - true - } - catch AbstractOrMissingHandler() - } - private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - private val logScope = scala.sys.props contains "scala.repl.scope" - private def scopelog(msg: String) = if (logScope) Console.err.println(msg) - - // argument is a thunk to execute after init is done - def initialize(postInitSignal: => Unit) { - synchronized { - if (_isInitialized == null) { - _isInitialized = - Future(try _initialize() finally postInitSignal)(ExecutionContext.global) - } - } - } - def initializeSynchronous(): Unit = { - if (!isInitializeComplete) { - _initialize() - assert(global != null, global) - } - } - def isInitializeComplete = _initializeComplete - - lazy val global: Global = { - if (!isInitializeComplete) _initialize() - _compiler - } - - import global._ - import definitions.{ ObjectClass, termMember, dropNullaryMethod} - - lazy val runtimeMirror = ru.runtimeMirror(classLoader) - - private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol } - - def getClassIfDefined(path: String) = ( - noFatal(runtimeMirror staticClass path) - orElse noFatal(rootMirror staticClass path) - ) - def getModuleIfDefined(path: String) = ( - noFatal(runtimeMirror staticModule path) - orElse noFatal(rootMirror staticModule path) - ) - - implicit class ReplTypeOps(tp: Type) { - def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) - } - - // TODO: If we try to make naming a lazy val, we run into big time - // scalac unhappiness with what look like cycles. It has not been easy to - // reduce, but name resolution clearly takes different paths. - object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. - def freshUserTermName(): TermName = { - val name = newTermName(freshUserVarName()) - if (replScope containsName name) freshUserTermName() - else name - } - def isInternalTermName(name: Name) = isInternalVarName("" + name) - } - import naming._ - - object deconstruct extends { - val global: imain.global.type = imain.global - } with StructuredTypeStrings - - lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - /** Temporarily be quiet */ - def beQuietDuring[T](body: => T): T = { - val saved = printResults - printResults = false - try body - finally printResults = saved - } - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - /** takes AnyRef because it may be binding a Throwable or an Exceptional */ - private def withLastExceptionLock[T](body: => T, alt: => T): T = { - assert(bindExceptions, "withLastExceptionLock called incorrectly.") - bindExceptions = false - - try beQuietDuring(body) - catch logAndDiscard("withLastExceptionLock", alt) - finally bindExceptions = true - } - - def executionWrapper = _executionWrapper - def setExecutionWrapper(code: String) = _executionWrapper = code - def clearExecutionWrapper() = _executionWrapper = "" - - /** interpreter settings */ - lazy val isettings = new SparkISettings(this) - - /** Instantiate a compiler. Overridable. */ - protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = { - settings.outputDirs setSingleOutput replOutput.dir - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) with ReplGlobal { override def toString: String = "" } - } - - /** Parent classloader. Overridable. */ - protected def parentClassLoader: ClassLoader = - settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() ) - - /* A single class loader is used for all commands interpreted by this Interpreter. - It would also be possible to create a new class loader for each command - to interpret. The advantages of the current approach are: - - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" - - The main disadvantage is: - - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - def resetClassLoader() = { - repldbg("Setting new classloader: was " + _classLoader) - _classLoader = null - ensureClassLoader() - } - final def ensureClassLoader() { - if (_classLoader == null) - _classLoader = makeClassLoader() - } - def classLoader: util.AbstractFileClassLoader = { - ensureClassLoader() - _classLoader - } - - def backticked(s: String): String = ( - (s split '.').toList map { - case "_" => "_" - case s if nme.keywords(newTermName(s)) => s"`$s`" - case s => s - } mkString "." - ) - def readRootPath(readPath: String) = getModuleIfDefined(readPath) - - abstract class PhaseDependentOps { - def shift[T](op: => T): T - - def path(name: => Name): String = shift(path(symbolOfName(name))) - def path(sym: Symbol): String = backticked(shift(sym.fullName)) - def sig(sym: Symbol): String = shift(sym.defString) - } - object typerOp extends PhaseDependentOps { - def shift[T](op: => T): T = exitingTyper(op) - } - object flatOp extends PhaseDependentOps { - def shift[T](op: => T): T = exitingFlatten(op) - } - - def originalPath(name: String): String = originalPath(name: TermName) - def originalPath(name: Name): String = typerOp path name - def originalPath(sym: Symbol): String = typerOp path sym - def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName - def translatePath(path: String) = { - val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path) - sym.toOption map flatPath - } - def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath - - private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) { - /** Overridden here to try translating a simple name to the generated - * class name if the original attempt fails. This method is used by - * getResourceAsStream as well as findClass. - */ - override protected def findAbstractFile(name: String): AbstractFile = - super.findAbstractFile(name) match { - case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull - case file => file - } - } - private def makeClassLoader(): util.AbstractFileClassLoader = - new TranslatingClassLoader(parentClassLoader match { - case null => ScalaClassLoader fromURLs compilerClasspath - case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p) - }) - - // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() - - def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted) - def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted - - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) - case _ => () - } - } - None - } - - private def updateReplScope(sym: Symbol, isDefined: Boolean) { - def log(what: String) { - val mark = if (sym.isType) "t " else "v " - val name = exitingTyper(sym.nameString) - val info = cleanTypeAfterTyper(sym) - val defn = sym defStringSeenAs info - - scopelog(f"[$mark$what%6s] $name%-25s $defn%s") - } - if (ObjectClass isSubClass sym.owner) return - // unlink previous - replScope lookupAll sym.name foreach { sym => - log("unlink") - replScope unlink sym - } - val what = if (isDefined) "define" else "import" - log(what) - replScope enter sym - } - - def recordRequest(req: Request) { - if (req == null) - return - - prevRequests += req - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. - exitingTyper { - req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym => - val oldSym = replScope lookup newSym.name.companionName - if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) { - replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.") - replwarn("Companions must be defined together; you may wish to use :paste mode for this.") - } - } - } - exitingTyper { - req.imports foreach (sym => updateReplScope(sym, isDefined = false)) - req.defines foreach (sym => updateReplScope(sym, isDefined = true)) - } - } - - private[nsc] def replwarn(msg: => String) { - if (!settings.nowarnings) - printMessage(msg) - } - - def compileSourcesKeepingRun(sources: SourceFile*) = { - val run = new Run() - assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") - reporter.reset() - run compileSources sources.toList - (!reporter.hasErrors, run) - } - - /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false otherwise. - */ - def compileSources(sources: SourceFile*): Boolean = - compileSourcesKeepingRun(sources: _*)._1 - - /** Compile a string. Returns true if there are no - * compilation errors, or false otherwise. - */ - def compileString(code: String): Boolean = - compileSources(new BatchSourceFile(" + + + + // scalastyle:on + } + + private def planVisualization(metrics: Map[Long, Any], graph: SparkPlanGraph): Seq[Node] = { + val metadata = graph.nodes.flatMap { node => + val nodeId = s"plan-meta-data-${node.id}" +
    {node.desc}
    + } + +
    +
    + + {planVisualizationResources} + +
    + } + + private def jobURL(jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala new file mode 100644 index 000000000000..5779c71f64e9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import scala.collection.mutable + +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{JobExecutionStatus, Logging} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} + +private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { + + private val retainedExecutions = + sqlContext.sparkContext.conf.getInt("spark.sql.ui.retainedExecutions", 1000) + + private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() + + // Old data in the following fields must be removed in "trimExecutionsIfNecessary". + // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data + private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() + + /** + * Maintain the relation between job id and execution id so that we can get the execution id in + * the "onJobEnd" method. + */ + private val _jobIdToExecutionId = mutable.HashMap[Long, Long]() + + private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]() + + private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]() + + private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() + + def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { + _executionIdToData.toMap + } + + def jobIdToExecutionId: Map[Long, Long] = synchronized { + _jobIdToExecutionId.toMap + } + + def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { + _stageIdToStageMetrics.toMap + } + + private def trimExecutionsIfNecessary( + executions: mutable.ListBuffer[SQLExecutionUIData]): Unit = { + if (executions.size > retainedExecutions) { + val toRemove = math.max(retainedExecutions / 10, 1) + executions.take(toRemove).foreach { execution => + for (executionUIData <- _executionIdToData.remove(execution.executionId)) { + for (jobId <- executionUIData.jobs.keys) { + _jobIdToExecutionId.remove(jobId) + } + for (stageId <- executionUIData.stages) { + _stageIdToStageMetrics.remove(stageId) + } + } + } + executions.trimStart(toRemove) + } + } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + val executionIdString = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdString == null) { + // This is not a job created by SQL + return + } + val executionId = executionIdString.toLong + val jobId = jobStart.jobId + val stageIds = jobStart.stageIds + + synchronized { + activeExecutions.get(executionId).foreach { executionUIData => + executionUIData.jobs(jobId) = JobExecutionStatus.RUNNING + executionUIData.stages ++= stageIds + stageIds.foreach(stageId => + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId = 0)) + _jobIdToExecutionId(jobId) = executionId + } + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { + val jobId = jobEnd.jobId + for (executionId <- _jobIdToExecutionId.get(jobId); + executionUIData <- _executionIdToData.get(executionId)) { + jobEnd.jobResult match { + case JobSucceeded => executionUIData.jobs(jobId) = JobExecutionStatus.SUCCEEDED + case JobFailed(_) => executionUIData.jobs(jobId) = JobExecutionStatus.FAILED + } + if (executionUIData.completionTime.nonEmpty && !executionUIData.hasRunningJobs) { + // We are the last job of this execution, so mark the execution as finished. Note that + // `onExecutionEnd` also does this, but currently that can be called before `onJobEnd` + // since these are called on different threads. + markExecutionFinished(executionId) + } + } + } + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + } + } + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { + val stageId = stageSubmitted.stageInfo.stageId + val stageAttemptId = stageSubmitted.stageInfo.attemptId + // Always override metrics for old stage attempt + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskMetrics, + finishTask = true) + } + + /** + * Update the accumulator values of a task with the latest metrics for this task. This is called + * every time we receive an executor heartbeat or when a task finishes. + */ + private def updateTaskAccumulatorValues( + taskId: Long, + stageId: Int, + stageAttemptID: Int, + metrics: TaskMetrics, + finishTask: Boolean): Unit = { + if (metrics == null) { + return + } + + _stageIdToStageMetrics.get(stageId) match { + case Some(stageMetrics) => + if (stageAttemptID < stageMetrics.stageAttemptId) { + // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. + } else if (stageAttemptID > stageMetrics.stageAttemptId) { + logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + + s"what we have seen (${stageMetrics.stageAttemptId})") + } else { + // TODO We don't know the attemptId. Currently, what we can do is overriding the + // accumulator updates. However, if there are two same task are running, such as + // speculation, the accumulator updates will be overriding by different task attempts, + // the results will be weird. + stageMetrics.taskIdToMetricUpdates.get(taskId) match { + case Some(taskMetrics) => + if (finishTask) { + taskMetrics.finished = true + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + } else if (!taskMetrics.finished) { + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + } else { + // If a task is finished, we should not override with accumulator updates from + // heartbeat reports + } + case None => + // TODO Now just set attemptId to 0. Should fix here when we can get the attempt + // id from SparkListenerExecutorMetricsUpdate + stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( + attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + } + } + case None => + // This execution and its stage have been dropped + } + } + + def onExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + physicalPlanGraph: SparkPlanGraph, + time: Long): Unit = { + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + + val executionUIData = new SQLExecutionUIData(executionId, description, details, + physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + } + + def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } + } + } + + private def markExecutionFinished(executionId: Long): Unit = { + activeExecutions.remove(executionId).foreach { executionUIData => + if (executionUIData.isFailed) { + failedExecutions += executionUIData + trimExecutionsIfNecessary(failedExecutions) + } else { + completedExecutions += executionUIData + trimExecutionsIfNecessary(completedExecutions) + } + } + } + + def getRunningExecutions: Seq[SQLExecutionUIData] = synchronized { + activeExecutions.values.toSeq + } + + def getFailedExecutions: Seq[SQLExecutionUIData] = synchronized { + failedExecutions + } + + def getCompletedExecutions: Seq[SQLExecutionUIData] = synchronized { + completedExecutions + } + + def getExecution(executionId: Long): Option[SQLExecutionUIData] = synchronized { + _executionIdToData.get(executionId) + } + + /** + * Get all accumulator updates from all tasks which belong to this execution and merge them. + */ + def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized { + _executionIdToData.get(executionId) match { + case Some(executionUIData) => + val accumulatorUpdates = { + for (stageId <- executionUIData.stages; + stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; + taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; + accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { + accumulatorUpdate + } + }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } + mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => + executionUIData.accumulatorMetrics(accumulatorId).metricParam). + mapValues(_.asInstanceOf[SQLMetricValue[_]].value) + case None => + // This execution has been dropped + Map.empty + } + } + + private def mergeAccumulatorUpdates( + accumulatorUpdates: Seq[(Long, Any)], + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { + accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => + val param = paramFunc(accumulatorId) + (accumulatorId, + values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) + } + } + +} + +/** + * Represent all necessary data for an execution that will be used in Web UI. + */ +private[ui] class SQLExecutionUIData( + val executionId: Long, + val description: String, + val details: String, + val physicalPlanDescription: String, + val physicalPlanGraph: SparkPlanGraph, + val accumulatorMetrics: Map[Long, SQLPlanMetric], + val submissionTime: Long, + var completionTime: Option[Long] = None, + val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty, + val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer()) { + + /** + * Return whether there are running jobs in this execution. + */ + def hasRunningJobs: Boolean = jobs.values.exists(_ == JobExecutionStatus.RUNNING) + + /** + * Return whether there are any failed jobs in this execution. + */ + def isFailed: Boolean = jobs.values.exists(_ == JobExecutionStatus.FAILED) + + def runningJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.RUNNING }.keys.toSeq + + def succeededJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.SUCCEEDED }.keys.toSeq + + def failedJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.FAILED }.keys.toSeq +} + +/** + * Represent a metric in a SQLPlan. + * + * Because we cannot revert our changes for an "Accumulator", we need to maintain accumulator + * updates for each task. So that if a task is retried, we can simply override the old updates with + * the new updates of the new attempt task. Since we cannot add them to accumulator, we need to use + * "AccumulatorParam" to get the aggregation value. + */ +private[ui] case class SQLPlanMetric( + name: String, + accumulatorId: Long, + metricParam: SQLMetricParam[SQLMetricValue[Any], Any]) + +/** + * Store all accumulatorUpdates for all tasks in a Spark stage. + */ +private[ui] class SQLStageMetrics( + val stageAttemptId: Long, + val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty) + +/** + * Store all accumulatorUpdates for a Spark task. + */ +private[ui] class SQLTaskMetrics( + val attemptId: Long, // TODO not used yet + var finished: Boolean, + var accumulatorUpdates: Map[Long, Any]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala new file mode 100644 index 000000000000..0b0867f67eb6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.ui.{SparkUI, SparkUITab} + +private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + + val parent = sparkUI + val listener = sqlContext.listener + + attachPage(new AllExecutionsPage(this)) + attachPage(new ExecutionPage(this)) + parent.attachTab(this) + + parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql") +} + +private[sql] object SQLTab { + + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL$nextId" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala new file mode 100644 index 000000000000..ae3d752dde34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} + +/** + * A graph used for storing information of an executionPlan of DataFrame. + * + * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the + * SparkPlan tree, and each edge represents a parent-child relationship between two nodes. + */ +private[ui] case class SparkPlanGraph( + nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { + + def makeDotFile(metrics: Map[Long, Any]): String = { + val dotFile = new StringBuilder + dotFile.append("digraph G {\n") + nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n")) + edges.foreach(edge => dotFile.append(edge.makeDotEdge + "\n")) + dotFile.append("}") + dotFile.toString() + } +} + +private[sql] object SparkPlanGraph { + + /** + * Build a SparkPlanGraph from the root of a SparkPlan tree. + */ + def apply(plan: SparkPlan): SparkPlanGraph = { + val nodeIdGenerator = new AtomicLong(0) + val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() + val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() + buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + new SparkPlanGraph(nodes, edges) + } + + private def buildSparkPlanGraphNode( + plan: SparkPlan, + nodeIdGenerator: AtomicLong, + nodes: mutable.ArrayBuffer[SparkPlanGraphNode], + edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + } + val node = SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodes += node + val childrenNodes = plan.children.map( + child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) + for (child <- childrenNodes) { + edges += SparkPlanGraphEdge(child.id, node.id) + } + node + } +} + +/** + * Represent a node in the SparkPlan tree, along with its metrics. + * + * @param id generated by "SparkPlanGraph". There is no duplicate id in a graph + * @param name the name of this SparkPlan node + * @param metrics metrics that this SparkPlan node will track + */ +private[ui] case class SparkPlanGraphNode( + id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { + + def makeDotNode(metricsValue: Map[Long, Any]): String = { + val values = { + for (metric <- metrics; + value <- metricsValue.get(metric.accumulatorId)) yield { + metric.name + ": " + value + } + } + val label = if (values.isEmpty) { + name + } else { + // If there are metrics, display all metrics in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + name + "\\n \\n" + values.mkString("\\n") + } + s""" $id [label="$label"];""" + } +} + +/** + * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child + * node id. + */ +private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { + + def makeDotEdge: String = s""" $fromId->$toId;\n""" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala new file mode 100644 index 000000000000..e9b60841fc28 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +object Window { + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + spec.partitionBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + spec.partitionBy(cols : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + spec.orderBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + spec.orderBy(cols : _*) + } + + private def spec: WindowSpec = { + new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + } + +} + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala new file mode 100644 index 000000000000..c3d224629702 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.catalyst.expressions._ + + +/** + * :: Experimental :: + * A window specification that defines the partitioning, ordering, and frame boundaries. + * + * Use the static methods in [[Window]] to create a [[WindowSpec]]. + * + * @since 1.4.0 + */ +@Experimental +class WindowSpec private[sql]( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + frame: catalyst.expressions.WindowFrame) { + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + new WindowSpec(cols.map(_.expr), orderSpec, frame) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new WindowSpec(partitionSpec, sortOrder, frame) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rowsBetween(start: Long, end: Long): WindowSpec = { + between(RowFrame, start, end) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rangeBetween(start: Long, end: Long): WindowSpec = { + between(RangeFrame, start, end) + } + + private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if x < 0 => ValuePreceding(-start.toInt) + case x if x > 0 => ValueFollowing(start.toInt) + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if x < 0 => ValuePreceding(-end.toInt) + case x if x > 0 => ValueFollowing(end.toInt) + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + } + + /** + * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. + */ + private[sql] def withAggregate(aggregate: Column): Column = { + val windowExpr = aggregate.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in window operation.") + } + new Column(windowExpr) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala new file mode 100644 index 000000000000..258afadc7695 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * The base class for implementing user-defined aggregate functions (UDAF). + */ +@Experimental +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def dataType: DataType + + /** + * Returns true iff this function is deterministic, i.e. given the same input, + * always return the same output. + */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. + * + * The contract should be that applying the merge function on two initial buffers should just + * return the initial buffer itself, i.e. + * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** + * Updates the given aggregation buffer `buffer` with new input data from `input`. + * + * This is called once per input row. + */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** + * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. + * + * This is called when we merge two partially aggregated data together. + */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any + + /** + * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } + + /** + * Creates a [[Column]] for this UDAF using the distinct values of the given + * [[Column]]s as input arguments. + */ + @scala.annotation.varargs + def distinct(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = true) + Column(aggregateExpression) + } +} + +/** + * :: Experimental :: + * A [[Row]] representing an mutable aggregation buffer. + * + * This is not meant to be extended outside of Spark. + */ +@Experimental +abstract class MutableAggregationBuffer extends Row { + + /** Update the ith value of this buffer. */ + def update(i: Int, value: Any): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala new file mode 100644 index 000000000000..435e6319a64c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -0,0 +1,2548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.util.Try + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Functions available for [[DataFrame]]. + * + * @groupname udf_funcs UDF functions + * @groupname agg_funcs Aggregate functions + * @groupname datetime_funcs Date time functions + * @groupname sort_funcs Sorting functions + * @groupname normal_funcs Non-aggregate functions + * @groupname math_funcs Math functions + * @groupname misc_funcs Misc functions + * @groupname window_funcs Window functions + * @groupname string_funcs String functions + * @groupname collection_funcs Collection functions + * @groupname Ungrouped Support functions for DataFrames. + * @since 1.3.0 + */ +@Experimental +// scalastyle:off +object functions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Returns a [[Column]] based on the given column name. + * + * @group normal_funcs + * @since 1.3.0 + */ + def col(colName: String): Column = Column(colName) + + /** + * Returns a [[Column]] based on the given column name. Alias of [[col]]. + * + * @group normal_funcs + * @since 1.3.0 + */ + def column(colName: String): Column = Column(colName) + + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. + * + * @group normal_funcs + * @since 1.3.0 + */ + def lit(literal: Any): Column = { + literal match { + case c: Column => return c + case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) + case _ => // continue + } + + val literalExpr = Literal(literal) + Column(literalExpr) + } + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Sort functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns a sort expression based on ascending order of the column. + * {{{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }}} + * + * @group sort_funcs + * @since 1.3.0 + */ + def asc(columnName: String): Column = Column(columnName).asc + + /** + * Returns a sort expression based on the descending order of the column. + * {{{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }}} + * + * @group sort_funcs + * @since 1.3.0 + */ + def desc(columnName: String): Column = Column(columnName).desc + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Aggregate functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def approxCountDistinct(columnName: String, rsd: Double): Column = { + approxCountDistinct(Column(columnName), rsd) + } + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def avg(e: Column): Column = Average(e.expr) + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def avg(columnName: String): Column = avg(Column(columnName)) + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def count(e: Column): Column = e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def count(columnName: String): Column = count(Column(columnName)) + + /** + * Aggregate function: returns the number of distinct items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + @scala.annotation.varargs + def countDistinct(expr: Column, exprs: Column*): Column = + CountDistinct((expr +: exprs).map(_.expr)) + + /** + * Aggregate function: returns the number of distinct items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + @scala.annotation.varargs + def countDistinct(columnName: String, columnNames: String*): Column = + countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) + + /** + * Aggregate function: returns the first value in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def first(e: Column): Column = First(e.expr) + + /** + * Aggregate function: returns the first value of a column in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def first(columnName: String): Column = first(Column(columnName)) + + /** + * Aggregate function: returns the last value in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(e: Column): Column = Last(e.expr) + + /** + * Aggregate function: returns the last value of the column in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def last(columnName: String): Column = last(Column(columnName)) + + /** + * Aggregate function: returns the maximum value of the expression in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def max(e: Column): Column = Max(e.expr) + + /** + * Aggregate function: returns the maximum value of the column in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def max(columnName: String): Column = max(Column(columnName)) + + /** + * Aggregate function: returns the average of the values in a group. + * Alias for avg. + * + * @group agg_funcs + * @since 1.4.0 + */ + def mean(e: Column): Column = avg(e) + + /** + * Aggregate function: returns the average of the values in a group. + * Alias for avg. + * + * @group agg_funcs + * @since 1.4.0 + */ + def mean(columnName: String): Column = avg(columnName) + + /** + * Aggregate function: returns the minimum value of the expression in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def min(e: Column): Column = Min(e.expr) + + /** + * Aggregate function: returns the minimum value of the column in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def min(columnName: String): Column = min(Column(columnName)) + + /** + * Aggregate function: returns the sum of all values in the expression. + * + * @group agg_funcs + * @since 1.3.0 + */ + def sum(e: Column): Column = Sum(e.expr) + + /** + * Aggregate function: returns the sum of all values in the given column. + * + * @group agg_funcs + * @since 1.3.0 + */ + def sum(columnName: String): Column = sum(Column(columnName)) + + /** + * Aggregate function: returns the sum of distinct values in the expression. + * + * @group agg_funcs + * @since 1.3.0 + */ + def sumDistinct(e: Column): Column = SumDistinct(e.expr) + + /** + * Aggregate function: returns the sum of distinct values in the expression. + * + * @group agg_funcs + * @since 1.3.0 + */ + def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Window functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Window function: returns the cumulative distribution of values within a window partition, + * i.e. the fraction of rows that are below the current row. + * + * {{{ + * N = total number of rows in the partition + * cumeDist(x) = number of values before (and including) x / N + * }}} + * + * + * This is equivalent to the CUME_DIST function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def cumeDist(): Column = { + UnresolvedWindowFunction("cume_dist", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition, without any gaps. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the DENSE_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def denseRank(): Column = { + UnresolvedWindowFunction("dense_rank", Nil) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int): Column = { + lag(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int): Column = { + lag(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int, defaultValue: Any): Column = { + lag(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int): Column = { + lead(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int): Column = { + lead(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int, defaultValue: Any): Column = { + lead(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def ntile(n: Int): Column = { + UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) + } + + /** + * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. + * + * This is computed by: + * {{{ + * (rank of row in its partition - 1) / (number of rows in the partition - 1) + * }}} + * + * This is equivalent to the PERCENT_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def percentRank(): Column = { + UnresolvedWindowFunction("percent_rank", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rank(): Column = { + UnresolvedWindowFunction("rank", Nil) + } + + /** + * Window function: returns a sequential number starting at 1 within a window partition. + * + * This is equivalent to the ROW_NUMBER function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rowNumber(): Column = { + UnresolvedWindowFunction("row_number", Nil) + } + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Non-aggregate functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the absolute value. + * + * @group normal_funcs + * @since 1.3.0 + */ + def abs(e: Column): Column = Abs(e.expr) + + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + * @since 1.4.0 + */ + @scala.annotation.varargs + def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + * @since 1.4.0 + */ + def array(colName: String, colNames: String*): Column = { + array((colName +: colNames).map(col) : _*) + } + + /** + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ + def broadcast(df: DataFrame): DataFrame = { + DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + } + + /** + * Returns the first column that is not null, or null if all inputs are null. + * + * For example, `coalesce(a, b, c)` will return a if a is not null, + * or b if a is null and b is not null, or c if both a and b are null but c is not null. + * + * @group normal_funcs + * @since 1.3.0 + */ + @scala.annotation.varargs + def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + + /** + * Creates a string column for the file name of the current Spark task. + * + * @group normal_funcs + */ + def inputFileName(): Column = InputFileName() + + /** + * Return true iff the column is NaN. + * + * @group normal_funcs + * @since 1.5.0 + */ + def isNaN(e: Column): Column = IsNaN(e.expr) + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.4.0 + */ + def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() + + /** + * Returns col1 if it is not NaN, or col2 if col1 is NaN. + * + * Both inputs should be floating point columns (DoubleType or FloatType). + * + * @group normal_funcs + * @since 1.5.0 + */ + def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr) + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * // Scala: + * df.select( -df("amount") ) + * + * // Java: + * df.select( negate(df.col("amount")) ); + * }}} + * + * @group normal_funcs + * @since 1.3.0 + */ + def negate(e: Column): Column = -e + + /** + * Inversion of boolean expression, i.e. NOT. + * {{{ + * // Scala: select rows that are not active (isActive === false) + * df.filter( !df("isActive") ) + * + * // Java: + * df.filter( not(df.col("isActive")) ); + * }}} + * + * @group normal_funcs + * @since 1.3.0 + */ + def not(e: Column): Column = !e + + /** + * Generate a random column with i.i.d. samples from U[0.0, 1.0]. + * + * @group normal_funcs + * @since 1.4.0 + */ + def rand(seed: Long): Column = Rand(seed) + + /** + * Generate a random column with i.i.d. samples from U[0.0, 1.0]. + * + * @group normal_funcs + * @since 1.4.0 + */ + def rand(): Column = rand(Utils.random.nextLong) + + /** + * Generate a column with i.i.d. samples from the standard normal distribution. + * + * @group normal_funcs + * @since 1.4.0 + */ + def randn(seed: Long): Column = Randn(seed) + + /** + * Generate a column with i.i.d. samples from the standard normal distribution. + * + * @group normal_funcs + * @since 1.4.0 + */ + def randn(): Column = randn(Utils.random.nextLong) + + /** + * Partition ID of the Spark task. + * + * Note that this is indeterministic because it depends on data partitioning and task scheduling. + * + * @group normal_funcs + * @since 1.4.0 + */ + def sparkPartitionId(): Column = SparkPartitionID() + + /** + * Computes the square root of the specified float value. + * + * @group math_funcs + * @since 1.3.0 + */ + def sqrt(e: Column): Column = Sqrt(e.expr) + + /** + * Computes the square root of the specified float value. + * + * @group math_funcs + * @since 1.5.0 + */ + def sqrt(colName: String): Column = sqrt(Column(colName)) + + /** + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name would be remained as the StructField's name, + * otherwise, the newly generated StructField's name would be auto generated as col${index + 1}, + * i.e. col1, col2, col3, ... + * + * @group normal_funcs + * @since 1.4.0 + */ + @scala.annotation.varargs + def struct(cols: Column*): Column = { + CreateStruct(cols.map(_.expr)) + } + + /** + * Creates a new struct column that composes multiple input columns. + * + * @group normal_funcs + * @since 1.4.0 + */ + def struct(colName: String, colNames: String*): Column = { + struct((colName +: colNames).map(col) : _*) + } + + /** + * Evaluates a list of conditions and returns one of multiple possible result expressions. + * If otherwise is not defined at the end, null is returned for unmatched conditions. + * + * {{{ + * // Example: encoding gender string column into integer. + * + * // Scala: + * people.select(when(people("gender") === "male", 0) + * .when(people("gender") === "female", 1) + * .otherwise(2)) + * + * // Java: + * people.select(when(col("gender").equalTo("male"), 0) + * .when(col("gender").equalTo("female"), 1) + * .otherwise(2)) + * }}} + * + * @group normal_funcs + * @since 1.4.0 + */ + def when(condition: Column, value: Any): Column = { + CaseWhen(Seq(condition.expr, lit(value).expr)) + } + + /** + * Computes bitwise NOT. + * + * @group normal_funcs + * @since 1.4.0 + */ + def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + + /** + * Parses the expression string into the column that it represents, similar to + * DataFrame.selectExpr + * {{{ + * // get the number of words of each length + * df.groupBy(expr("length(word)")).count() + * }}} + * + * @group normal_funcs + */ + def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Math Functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + * + * @group math_funcs + * @since 1.4.0 + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + * + * @group math_funcs + * @since 1.4.0 + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group math_funcs + * @since 1.4.0 + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group math_funcs + * @since 1.4.0 + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group math_funcs + * @since 1.4.0 + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(e: Column): Column = Bin(e.expr) + + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(columnName: String): Column = bin(Column(columnName)) + + /** + * Computes the cube-root of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the ceiling of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Convert a number in a string column from one base to another. + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(num: Column, fromBase: Int, toBase: Int): Column = + Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + + /** + * Computes the cosine of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the exponential of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + * + * @group math_funcs + * @since 1.4.0 + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Computes the factorial of the given value. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(e: Column): Column = Factorial(e.expr) + + /** + * Computes the floor of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Returns the greatest value of the list of values, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def greatest(exprs: Column*): Column = { + require(exprs.length > 1, "greatest requires at least 2 arguments.") + Greatest(exprs.map(_.expr)) + } + + /** + * Returns the greatest value of the list of column names, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def greatest(columnName: String, columnNames: String*): Column = { + greatest((columnName +: columnNames).map(Column.apply): _*) + } + + /** + * Computes hex value of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(column: Column): Column = Unhex(column.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group math_funcs + * @since 1.4.0 + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Returns the least value of the list of values, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def least(exprs: Column*): Column = { + require(exprs.length > 1, "least requires at least 2 arguments.") + Least(exprs.map(_.expr)) + } + + /** + * Returns the least value of the list of column names, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def least(columnName: String, columnNames: String*): Column = { + least((columnName +: columnNames).map(Column.apply): _*) + } + + /** + * Computes the natural logarithm of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, columnName: String): Column = log(base, Column(columnName)) + + /** + * Computes the logarithm of the given value in base 10. + * + * @group math_funcs + * @since 1.4.0 + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in base 10. + * + * @group math_funcs + * @since 1.4.0 + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * + * @group math_funcs + * @since 1.4.0 + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + * + * @group math_funcs + * @since 1.4.0 + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Computes the logarithm of the given column in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(expr: Column): Column = Log2(expr.expr) + + /** + * Computes the logarithm of the given value in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(columnName: String): Column = log2(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group math_funcs + * @since 1.4.0 + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group math_funcs + * @since 1.4.0 + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Round the value of `e` to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(e: Column, numBits: Int): Column = + ShiftRightUnsigned(e.expr, lit(numBits).expr) + + /** + * Computes the signum of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the sine of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the tangent of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + * + * @group math_funcs + * @since 1.4.0 + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + * + * @group math_funcs + * @since 1.4.0 + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group math_funcs + * @since 1.4.0 + */ + def toDegrees(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group math_funcs + * @since 1.4.0 + */ + def toDegrees(columnName: String): Column = toDegrees(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group math_funcs + * @since 1.4.0 + */ + def toRadians(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group math_funcs + * @since 1.4.0 + */ + def toRadians(columnName: String): Column = toRadians(Column(columnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Misc functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(e: Column): Column = Md5(e.expr) + + /** + * Calculates the SHA-1 digest of a binary column and returns the value + * as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-2 family of hash functions of a binary column and + * returns the value as a hex string. + * + * @param e column to compute SHA-2 on. + * @param numBits one of 224, 256, 384, or 512. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(e: Column, numBits: Int): Column = { + require(Seq(0, 224, 256, 384, 512).contains(numBits), + s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") + Sha2(e.expr, lit(numBits).expr) + } + + /** + * Calculates the cyclic redundancy check value (CRC32) of a binary column and + * returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(e: Column): Column = Crc32(e.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // String functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the numeric value of the first character of the string column, and returns the + * result as a int column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ascii(e: Column): Column = Ascii(e.expr) + + /** + * Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. + * + * @group string_funcs + * @since 1.5.0 + */ + def base64(e: Column): Column = Base64(e.expr) + + /** + * Concatenates multiple input string columns together into a single string column. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + + /** + * Concatenates multiple input string columns together into a single string column, + * using the given separator. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat_ws(sep: String, exprs: Column*): Column = { + ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) + } + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + + /** + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + + /** + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string column. + * + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + + /** + * Formats the arguments in printf-style and returns the result as a string column. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def format_string(format: String, arguments: Column*): Column = { + FormatString((lit(format) +: arguments).map(_.expr): _*) + } + + /** + * Returns a new string column by converting the first letter of each word to uppercase. + * Words are delimited by whitespace. + * + * For example, "hello world" will become "Hello World". + * + * @group string_funcs + * @since 1.5.0 + */ + def initcap(e: Column): Column = InitCap(e.expr) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) + + /** + * Computes the length of a given string or binary column. + * + * @group string_funcs + * @since 1.5.0 + */ + def length(e: Column): Column = Length(e.expr) + + /** + * Converts a string column to lower case. + * + * @group string_funcs + * @since 1.3.0 + */ + def lower(e: Column): Column = Lower(e.expr) + + /** + * Computes the Levenshtein distance of the two given string columns. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + + /** + * Locate the position of the first occurrence of substr. + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: Column): Column = { + new StringLocate(lit(substr).expr, str.expr) + } + + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * + * NOTE: The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: Column, pos: Int): Column = { + StringLocate(lit(substr).expr, str.expr, lit(pos).expr) + } + + /** + * Left-pad the string column with + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Int, pad: String): Column = { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) + } + + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = StringTrimLeft(e.expr) + + /** + * Extract a specific(idx) group identified by a java regex, from the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) + } + + /** + * Replace all substrings of the specified string value that match regexp with rep. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) + } + + /** + * Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. + * + * @group string_funcs + * @since 1.5.0 + */ + def unbase64(e: Column): Column = UnBase64(e.expr) + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Int, pad: String): Column = { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) + } + + /** + * Repeats a string column n times, and returns it as a new string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, n: Int): Column = { + StringRepeat(str.expr, lit(n).expr) + } + + /** + * Reverses the string column and returns it as a new string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: Column): Column = { + StringReverse(str.expr) + } + + /** + * Trim the spaces from right end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(e: Column): Column = StringTrimRight(e.expr) + + /** + * * Return the soundex code for the specified expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def soundex(e: Column): Column = SoundEx(e.expr) + + /** + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) + } + + /** + * Substring starts at `pos` and is of length `len` when str is String type or + * returns the slice of byte array that starts at `pos` in byte and is of length `len` + * when str is Binary type + * + * @group string_funcs + * @since 1.5.0 + */ + def substring(str: Column, pos: Int, len: Int): Column = + Substring(str.expr, lit(pos).expr, lit(len).expr) + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + * + * @group string_funcs + */ + def substring_index(str: Column, delim: String, count: Int): Column = + SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + + /** + * Translate any character in the src by a character in replaceString. + * The characters in replaceString is corresponding to the characters in matchingString. + * The translate will happen when any character in the string matching with the character + * in the matchingString. + * + * @group string_funcs + * @since 1.5.0 + */ + def translate(src: Column, matchingString: String, replaceString: String): Column = + StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + + /** + * Trim the spaces from both ends for the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(e: Column): Column = StringTrim(e.expr) + + /** + * Converts a string column to upper case. + * + * @group string_funcs + * @since 1.3.0 + */ + def upper(e: Column): Column = Upper(e.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // DateTime functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns the date that is numMonths after startDate. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def add_months(startDate: Column, numMonths: Int): Column = + AddMonths(startDate.expr, Literal(numMonths)) + + /** + * Returns the current date as a date column. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_date(): Column = CurrentDate() + + /** + * Returns the current timestamp as a timestamp column. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_timestamp(): Column = CurrentTimestamp() + + /** + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. + * + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def date_format(dateExpr: Column, format: String): Column = + DateFormatClass(dateExpr.expr, Literal(format)) + + /** + * Returns the date that is `days` days after `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + + /** + * Returns the date that is `days` days before `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + + /** + * Returns the number of days from `start` to `end`. + * @group datetime_funcs + * @since 1.5.0 + */ + def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr) + + /** + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def year(e: Column): Column = Year(e.expr) + + /** + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def quarter(e: Column): Column = Quarter(e.expr) + + /** + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def month(e: Column): Column = Month(e.expr) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + + /** + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofyear(e: Column): Column = DayOfYear(e.expr) + + /** + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def hour(e: Column): Column = Hour(e.expr) + + /** + * Given a date column, returns the last day of the month which the given date belongs to. + * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the + * month in July 2015. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def last_day(e: Column): Column = LastDay(e.expr) + + /** + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def minute(e: Column): Column = Minute(e.expr) + + /* + * Returns number of months between dates `date1` and `date2`. + * @group datetime_funcs + * @since 1.5.0 + */ + def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + + /** + * Given a date column, returns the first date which is later than the value of the date column + * that is on the specified day of the week. + * + * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first + * Sunday after 2015-07-27. + * + * Day of the week parameter is case insensitive, and accepts: + * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + * + * @group datetime_funcs + * @since 1.5.0 + */ + def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + + /** + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def second(e: Column): Column = Second(e.expr) + + /** + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def weekofyear(e: Column): Column = WeekOfYear(e.expr) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + + /** + * Gets current Unix timestamp in seconds. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), + * using the default timezone and the default locale, return null if fail. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Convert time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), return null if fail. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + + /** + * Converts the column into DateType. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def to_date(e: Column): Column = ToDate(e.expr) + + /** + * Returns date truncated to the unit specified by the format. + * + * @param format: 'year', 'yyyy', 'yy' for truncate by year, + * or 'month', 'mon', 'mm' for truncate by month + * + * @group datetime_funcs + * @since 1.5.0 + */ + def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + + /** + * Assumes given timestamp is UTC and converts to given timezone. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_utc_timestamp(ts: Column, tz: String): Column = + FromUTCTimestamp(ts.expr, Literal(tz).expr) + + /** + * Assumes given timestamp is in given timezone and converts to UTC. + * @group datetime_funcs + * @since 1.5.0 + */ + def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Collection functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = + ArrayContains(column.expr, Literal(value)) + + /** + * Creates a new row for each element in the given array or map column. + * + * @group collection_funcs + * @since 1.3.0 + */ + def explode(e: Column): Column = Explode(e.expr) + + /** + * Returns length of array or map. + * + * @group collection_funcs + * @since 1.5.0 + */ + def size(e: Column): Column = Size(e.expr) + + /** + * Sorts the input array for the given column in ascending order, + * according to the natural ordering of the array elements. + * + * @group collection_funcs + * @since 1.5.0 + */ + def sort_array(e: Column): Column = sort_array(e, asc = true) + + /** + * Sorts the input array for the given column in ascending / descending order, + * according to the natural ordering of the array elements. + * + * @group collection_funcs + * @since 1.5.0 + */ + def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + + // scalastyle:off + + /* Use the following code to generate: + (0 to 10).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) + println(s""" + /** + * Defines a user-defined function of ${x} arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + val inputTypes = Try($inputTypes).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + }""") + } + + (0 to 10).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUDF(f, returnType, Seq($argsInUDF)) + }""") + } + } + */ + /** + * Defines a user-defined function of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { + val inputTypes = Try(Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + /** + * Defines a user-defined function of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + * @since 1.3.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUDF(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ + @deprecated("Use udf", "1.5.0") + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + // scalastyle:on + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def callUDF(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) + } + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.4.0 + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF + */ + @deprecated("Use callUDF", "1.5.0") + def callUdf(udfName: String, cols: Column*): Column = { + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs, isDistinct = false) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala deleted file mode 100644 index 1704be7fcbd3..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import org.apache.spark.sql.types._ - -import java.sql.Types - - -/** - * Encapsulates workarounds for the extensions, quirks, and bugs in various - * databases. Lots of databases define types that aren't explicitly supported - * by the JDBC spec. Some JDBC drivers also report inaccurate - * information---for instance, BIT(n>1) being reported as a BIT type is quite - * common, even though BIT in JDBC is meant for single-bit values. Also, there - * does not appear to be a standard name for an unbounded string or binary - * type; we use BLOB and CLOB by default but override with database-specific - * alternatives when these are absent or do not behave correctly. - * - * Currently, the only thing DriverQuirks does is handle type mapping. - * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` - * is used when writing to a JDBC table. If `getCatalystType` returns `null`, - * the default type handling is used for the given JDBC type. Similarly, - * if `getJDBCType` returns `(null, None)`, the default type handling is used - * for the given Catalyst type. - */ -private[sql] abstract class DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType - def getJDBCType(dt: DataType): (String, Option[Int]) -} - -private[sql] object DriverQuirks { - /** - * Fetch the DriverQuirks class corresponding to a given database url. - */ - def get(url: String): DriverQuirks = { - if (url.substring(0, 10).equals("jdbc:mysql")) { - new MySQLQuirks() - } else if (url.substring(0, 15).equals("jdbc:postgresql")) { - new PostgresQuirks() - } else { - new NoQuirks() - } - } -} - -private[sql] class NoQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = - null - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} - -private[sql] class PostgresQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - BinaryType - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - StringType - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - StringType - } else null - } - - def getJDBCType(dt: DataType): (String, Option[Int]) = dt match { - case StringType => ("TEXT", Some(java.sql.Types.CHAR)) - case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY)) - case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN)) - case _ => (null, None) - } -} - -private[sql] class MySQLQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - LongType - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - BooleanType - } else null - } - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala deleted file mode 100644 index a2f94675fb5a..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ /dev/null @@ -1,417 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, ResultSetMetaData, SQLException} -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.util.NextIterator -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources._ - -private[sql] object JDBCRDD extends Logging { - /** - * Maps a JDBC type to a Catalyst type. This function is called only when - * the DriverQuirks class corresponding to your database driver returns null. - * - * @param sqlType - A field of java.sql.Types - * @return The Catalyst type corresponding to sqlType. - */ - private def getCatalystType(sqlType: Int): DataType = { - val answer = sqlType match { - case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => LongType - case java.sql.Types.BINARY => BinaryType - case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers. - case java.sql.Types.BLOB => BinaryType - case java.sql.Types.BOOLEAN => BooleanType - case java.sql.Types.CHAR => StringType - case java.sql.Types.CLOB => StringType - case java.sql.Types.DATALINK => null - case java.sql.Types.DATE => DateType - case java.sql.Types.DECIMAL => DecimalType.Unlimited - case java.sql.Types.DISTINCT => null - case java.sql.Types.DOUBLE => DoubleType - case java.sql.Types.FLOAT => FloatType - case java.sql.Types.INTEGER => IntegerType - case java.sql.Types.JAVA_OBJECT => null - case java.sql.Types.LONGNVARCHAR => StringType - case java.sql.Types.LONGVARBINARY => BinaryType - case java.sql.Types.LONGVARCHAR => StringType - case java.sql.Types.NCHAR => StringType - case java.sql.Types.NCLOB => StringType - case java.sql.Types.NULL => null - case java.sql.Types.NUMERIC => DecimalType.Unlimited - case java.sql.Types.OTHER => null - case java.sql.Types.REAL => DoubleType - case java.sql.Types.REF => StringType - case java.sql.Types.ROWID => LongType - case java.sql.Types.SMALLINT => IntegerType - case java.sql.Types.SQLXML => StringType - case java.sql.Types.STRUCT => StringType - case java.sql.Types.TIME => TimestampType - case java.sql.Types.TIMESTAMP => TimestampType - case java.sql.Types.TINYINT => IntegerType - case java.sql.Types.VARBINARY => BinaryType - case java.sql.Types.VARCHAR => StringType - case _ => null - } - - if (answer == null) throw new SQLException("Unsupported type " + sqlType) - answer - } - - /** - * Takes a (schema, table) specification and returns the table's Catalyst - * schema. - * - * @param url - The JDBC url to fetch information from. - * @param table - The table name of the desired table. This may also be a - * SQL query wrapped in parentheses. - * - * @return A StructType giving the table's Catalyst schema. - * @throws SQLException if the table specification is garbage. - * @throws SQLException if the table contains an unsupported type. - */ - def resolveTable(url: String, table: String): StructType = { - val quirks = DriverQuirks.get(url) - val conn: Connection = DriverManager.getConnection(url) - try { - val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() - try { - val rsmd = rs.getMetaData - val ncols = rsmd.getColumnCount - var fields = new Array[StructField](ncols); - var i = 0 - while (i < ncols) { - val columnName = rsmd.getColumnName(i + 1) - val dataType = rsmd.getColumnType(i + 1) - val typeName = rsmd.getColumnTypeName(i + 1) - val fieldSize = rsmd.getPrecision(i + 1) - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) - var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata) - if (columnType == null) columnType = getCatalystType(dataType) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) - i = i + 1 - } - return new StructType(fields) - } finally { - rs.close() - } - } finally { - conn.close() - } - - throw new RuntimeException("This line is unreachable.") - } - - /** - * Prune all but the specified columns from the specified Catalyst schema. - * - * @param schema - The Catalyst schema of the master table - * @param columns - The list of desired columns - * - * @return A Catalyst schema corresponding to columns in the given order. - */ - private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { - val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*) - new StructType(columns map { name => fieldMap(name) }) - } - - /** - * Given a driver string and an url, return a function that loads the - * specified driver string then returns a connection to the JDBC url. - * getConnector is run on the driver code, while the function it returns - * is run on the executor. - * - * @param driver - The class name of the JDBC driver for the given url. - * @param url - The JDBC url to connect to. - * - * @return A function that loads the driver and connects to the url. - */ - def getConnector(driver: String, url: String): () => Connection = { - () => { - try { - if (driver != null) Class.forName(driver) - } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } - } - DriverManager.getConnection(url) - } - } - /** - * Build and return JDBCRDD from the given information. - * - * @param sc - Your SparkContext. - * @param schema - The Catalyst schema of the underlying database table. - * @param driver - The class name of the JDBC driver for the given url. - * @param url - The JDBC url to connect to. - * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. - * @param requiredColumns - The names of the columns to SELECT. - * @param filters - The filters to include in all WHERE clauses. - * @param parts - An array of JDBCPartitions specifying partition ids and - * per-partition WHERE clauses. - * - * @return An RDD representing "SELECT requiredColumns FROM fqTable". - */ - def scanTable(sc: SparkContext, - schema: StructType, - driver: String, - url: String, - fqTable: String, - requiredColumns: Array[String], - filters: Array[Filter], - parts: Array[Partition]): RDD[Row] = { - val prunedSchema = pruneSchema(schema, requiredColumns) - - return new JDBCRDD(sc, - getConnector(driver, url), - prunedSchema, - fqTable, - requiredColumns, - filters, - parts) - } -} - -/** - * An RDD representing a table in a database accessed via JDBC. Both the - * driver code and the workers must be able to access the database; the driver - * needs to fetch the schema while the workers need to fetch the data. - */ -private[sql] class JDBCRDD( - sc: SparkContext, - getConnection: () => Connection, - schema: StructType, - fqTable: String, - columns: Array[String], - filters: Array[Filter], - partitions: Array[Partition]) - extends RDD[Row](sc, Nil) { - - /** - * Retrieve the list of partitions corresponding to this RDD. - */ - override def getPartitions: Array[Partition] = partitions - - /** - * `columns`, but as a String suitable for injection into a SQL query. - */ - private val columnList: String = { - val sb = new StringBuilder() - columns.foreach(x => sb.append(",").append(x)) - if (sb.length == 0) "1" else sb.substring(1) - } - - /** - * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. - */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = $value" - case LessThan(attr, value) => s"$attr < $value" - case GreaterThan(attr, value) => s"$attr > $value" - case LessThanOrEqual(attr, value) => s"$attr <= $value" - case GreaterThanOrEqual(attr, value) => s"$attr >= $value" - case _ => null - } - - /** - * `filters`, but as a WHERE clause suitable for injection into a SQL query. - */ - private val filterWhereClause: String = { - val filterStrings = filters map compileFilter filter (_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } - - /** - * A WHERE clause representing both `filters`, if any, and the current partition. - */ - private def getWhereClause(part: JDBCPartition): String = { - if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause - } else if (part.whereClause != null) { - "WHERE " + part.whereClause - } else { - filterWhereClause - } - } - - // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that - // we don't have to potentially poke around in the Metadata once for every - // row. - // Is there a better way to do this? I'd rather be using a type that - // contains only the tags I define. - abstract class JDBCConversion - case object BooleanConversion extends JDBCConversion - case object DateConversion extends JDBCConversion - case object DecimalConversion extends JDBCConversion - case object DoubleConversion extends JDBCConversion - case object FloatConversion extends JDBCConversion - case object IntegerConversion extends JDBCConversion - case object LongConversion extends JDBCConversion - case object BinaryLongConversion extends JDBCConversion - case object StringConversion extends JDBCConversion - case object TimestampConversion extends JDBCConversion - case object BinaryConversion extends JDBCConversion - - /** - * Maps a StructType to a type tag list. - */ - def getConversions(schema: StructType): Array[JDBCConversion] = { - schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => - if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") - }).toArray - } - - - /** - * Runs the SQL query against the JDBC driver. - */ - override def compute(thePart: Partition, context: TaskContext) = new Iterator[Row] { - var closed = false - var finished = false - var gotNext = false - var nextValue: Row = null - - context.addTaskCompletionListener{ context => close() } - val part = thePart.asInstanceOf[JDBCPartition] - val conn = getConnection() - - // H2's JDBC driver does not support the setSchema() method. We pass a - // fully-qualified table name in the SELECT statement. I don't know how to - // talk about a table in a completely portable way. - - val myWhereClause = getWhereClause(part) - - val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" - val stmt = conn.prepareStatement(sqlText, - ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val rs = stmt.executeQuery() - - val conversions = getConversions(schema) - val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) - - def getNext(): Row = { - if (rs.next()) { - var i = 0 - while (i < conversions.length) { - val pos = i + 1 - conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => mutableRow.update(i, rs.getDate(pos)) - case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) - case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => { - val bytes = rs.getBytes(pos) - var ans = 0L - var j = 0 - while (j < bytes.size) { - ans = 256*ans + (255 & bytes(j)) - j = j + 1; - } - mutableRow.setLong(i, ans) - } - } - if (rs.wasNull) mutableRow.setNullAt(i) - i = i + 1 - } - mutableRow - } else { - finished = true - null.asInstanceOf[Row] - } - } - - def close() { - if (closed) return - try { - if (null != rs && ! rs.isClosed()) { - rs.close() - } - } catch { - case e: Exception => logWarning("Exception closing resultset", e) - } - try { - if (null != stmt && ! stmt.isClosed()) { - stmt.close() - } - } catch { - case e: Exception => logWarning("Exception closing statement", e) - } - try { - if (null != conn && ! conn.isClosed()) { - conn.close() - } - logInfo("closed connection") - } catch { - case e: Exception => logWarning("Exception closing connection", e) - } - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - nextValue = getNext() - if (finished) { - close() - } - gotNext = true - } - } - !finished - } - - override def next(): Row = { - if (!hasNext) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala deleted file mode 100644 index e09125e406ba..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import scala.collection.mutable.ArrayBuffer -import java.sql.DriverManager - -import org.apache.spark.Partition -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources._ - -/** - * Data corresponding to one partition of a JDBCRDD. - */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { - override def index: Int = idx -} - -/** - * Instructions on how to partition the table among workers. - */ -private[sql] case class JDBCPartitioningInfo( - column: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int) - -private[sql] object JDBCRelation { - /** - * Given a partitioning schematic (a column of integral type, a number of - * partitions, and upper and lower bounds on the column's value), generate - * WHERE clauses for each partition so that each row in the table appears - * exactly once. The parameters minValue and maxValue are advisory in that - * incorrect values may cause the partitioning to be poor, but no data - * will fail to be represented. - * - * @param column - Column name. Must refer to a column of integral type. - * @param numPartitions - Number of partitions - * @param minValue - Smallest value of column. Advisory. - * @param maxValue - Largest value of column. Advisory. - */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { - if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) - - val numPartitions = partitioning.numPartitions - val column = partitioning.column - if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) - // Overflow and silliness can happen if you subtract then divide. - // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions - - partitioning.lowerBound / numPartitions) - var i: Int = 0 - var currentValue: Long = partitioning.lowerBound - var ans = new ArrayBuffer[Partition]() - while (i < numPartitions) { - val lowerBound = (if (i != 0) s"$column >= $currentValue" else null) - currentValue += stride - val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null) - val whereClause = (if (upperBound == null) lowerBound - else if (lowerBound == null) upperBound - else s"$lowerBound AND $upperBound") - ans += JDBCPartition(whereClause, i) - i = i + 1 - } - ans.toArray - } -} - -private[sql] class DefaultSource extends RelationProvider { - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) Class.forName(driver) - - if ( partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo(partitionColumn, - lowerBound.toLong, upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(url, table, parts)(sqlContext) - } -} - -private[sql] case class JDBCRelation(url: String, - table: String, - parts: Array[Partition])( - @transient val sqlContext: SQLContext) - extends PrunedFilteredScan { - - override val schema = JDBCRDD.resolveTable(url, table) - - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { - val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName - JDBCRDD.scanTable(sqlContext.sparkContext, - schema, - driver, url, - table, - requiredColumns, filters, - parts) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala deleted file mode 100644 index 86bb67ec7425..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import org.apache.spark.sql.DataFrame - -private[jdbc] class JavaJDBCTrampoline { - def createJDBCTable(rdd: DataFrame, url: String, table: String, allowExisting: Boolean) { - rdd.createJDBCTable(url, table, allowExisting); - } - - def insertIntoJDBC(rdd: DataFrame, url: String, table: String, overwrite: Boolean) { - rdd.insertIntoJDBC(url, table, overwrite); - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala new file mode 100644 index 000000000000..8849fc2f1f0e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * A database type definition coupled with the jdbc type needed to send null + * values to the database. + * @param databaseTypeDefinition The database type definition + * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to + * send a null value to the database. + */ +@DeveloperApi +case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) + +/** + * :: DeveloperApi :: + * Encapsulates everything (extensions, workarounds, quirks) to handle the + * SQL dialect of a certain database or jdbc driver. + * Lots of databases define types that aren't explicitly supported + * by the JDBC spec. Some JDBC drivers also report inaccurate + * information---for instance, BIT(n>1) being reported as a BIT type is quite + * common, even though BIT in JDBC is meant for single-bit values. Also, there + * does not appear to be a standard name for an unbounded string or binary + * type; we use BLOB and CLOB by default but override with database-specific + * alternatives when these are absent or do not behave correctly. + * + * Currently, the only thing done by the dialect is type mapping. + * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` + * is used when writing to a JDBC table. If `getCatalystType` returns `null`, + * the default type handling is used for the given JDBC type. Similarly, + * if `getJDBCType` returns `(null, None)`, the default type handling is used + * for the given Catalyst type. + */ +@DeveloperApi +abstract class JdbcDialect { + /** + * Check if this dialect instance can handle a certain jdbc url. + * @param url the jdbc url. + * @return True if the dialect can be applied on the given jdbc url. + * @throws NullPointerException if the url is null. + */ + def canHandle(url : String): Boolean + + /** + * Get the custom datatype mapping for the given jdbc meta information. + * @param sqlType The sql type (see java.sql.Types) + * @param typeName The sql type name (e.g. "BIGINT UNSIGNED") + * @param size The size of the type. + * @param md Result metadata associated with this type. + * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]]) + * or null if the default type mapping should be used. + */ + def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None + + /** + * Retrieve the jdbc / sql type for a given datatype. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The new JdbcType if there is an override for this DataType + */ + def getJDBCType(dt: DataType): Option[JdbcType] = None + + /** + * Quotes the identifier. This is used to put quotes around the identifier in case the column + * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). + */ + def quoteIdentifier(colName: String): String = { + s""""$colName"""" + } +} + +/** + * :: DeveloperApi :: + * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]]. + * + * If multiple matching dialects are registered then all matching ones will be + * tried in reverse order. A user-added dialect will thus be applied first, + * overwriting the defaults. + * + * Note that all new dialects are applied to new jdbc DataFrames only. Make + * sure to register your dialects first. + */ +@DeveloperApi +object JdbcDialects { + + private var dialects = List[JdbcDialect]() + + /** + * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. + * Readding an existing dialect will cause a move-to-front. + * @param dialect The new dialect. + */ + def registerDialect(dialect: JdbcDialect) : Unit = { + dialects = dialect :: dialects.filterNot(_ == dialect) + } + + /** + * Unregister a dialect. Does nothing if the dialect is not registered. + * @param dialect The jdbc dialect. + */ + def unregisterDialect(dialect : JdbcDialect) : Unit = { + dialects = dialects.filterNot(_ == dialect) + } + + registerDialect(MySQLDialect) + registerDialect(PostgresDialect) + + /** + * Fetch the JdbcDialect class corresponding to a given database url. + */ + private[sql] def get(url: String): JdbcDialect = { + val matchingDialects = dialects.filter(_.canHandle(url)) + matchingDialects.length match { + case 0 => NoopDialect + case 1 => matchingDialects.head + case _ => new AggregatedDialect(matchingDialects) + } + } +} + +/** + * :: DeveloperApi :: + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * @param dialects List of dialects. + */ +@DeveloperApi +class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(dialects.nonEmpty) + + override def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } +} + +/** + * :: DeveloperApi :: + * NOOP dialect object, always returning the neutral element. + */ +@DeveloperApi +case object NoopDialect extends JdbcDialect { + override def canHandle(url : String): Boolean = true +} + +/** + * :: DeveloperApi :: + * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. + */ +@DeveloperApi +case object PostgresDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + Some(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + Some(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case _ => None + } +} + +/** + * :: DeveloperApi :: + * Default mysql dialect to read bit/bitsets correctly. + */ +@DeveloperApi +case object MySQLDialect extends JdbcDialect { + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Some(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Some(BooleanType) + } else None + } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 34a83f0a5dad..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, DriverManager, PreparedStatement} -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql._ -import org.apache.spark.sql.sources.LogicalRelation - -import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition} -import org.apache.spark.sql.types._ - -package object jdbc { - object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - private def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - private[jdbc] def savePartition(url: String, table: String, iterator: Iterator[Row], - rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = { - val conn = DriverManager.getConnection(url) - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. - val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case DecimalType.Unlimited => stmt.setBigDecimal(i + 1, - row.getAs[java.math.BigDecimal](i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - } - - /** - * Make it so that you can call createJDBCTable and insertIntoJDBC on a DataFrame. - */ - implicit class JDBCDataFrame(rdd: DataFrame) { - /** - * Compute the schema string for this RDD. - */ - private def schemaString(url: String): String = { - val sb = new StringBuilder() - val quirks = DriverQuirks.get(url) - rdd.schema.fields foreach { field => { - val name = field.name - var typ: String = quirks.getJDBCType(field.dataType)._1 - if (typ == null) typ = field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - } - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. - */ - private def saveTable(url: String, table: String) { - val quirks = DriverQuirks.get(url) - var nullTypes: Array[Int] = rdd.schema.fields.map(field => { - var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 - if (nullType.isEmpty) { - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case DecimalType.Unlimited => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - } - } else nullType.get - }).toArray - - val rddSchema = rdd.schema - rdd.mapPartitions(iterator => JDBCWriteDetails.savePartition( - url, table, iterator, rddSchema, nullTypes)).collect() - } - - /** - * Save this RDD to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - */ - def createJDBCTable(url: String, table: String, allowExisting: Boolean) { - val conn = DriverManager.getConnection(url) - try { - if (allowExisting) { - val sql = s"DROP TABLE IF EXISTS $table" - conn.prepareStatement(sql).executeUpdate() - } - val schema = schemaString(url) - val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - saveTable(url, table) - } - - /** - * Save this RDD to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - */ - def insertIntoJDBC(url: String, table: String, overwrite: Boolean) { - if (overwrite) { - val conn = DriverManager.getConnection(url) - try { - val sql = s"TRUNCATE TABLE $table" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - } - saveTable(url, table) - } - } // implicit class JDBCDataFrame -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala deleted file mode 100644 index 8372decbf8aa..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import java.io.IOException - -import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType - - -private[sql] class DefaultSource - extends RelationProvider with SchemaRelationProvider with CreateableRelationProvider { - - /** Returns a new base relation with the parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - JSONRelation(path, samplingRatio, None)(sqlContext) - } - - /** Returns a new base relation with the given schema and parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - JSONRelation(path, samplingRatio, Some(schema))(sqlContext) - } - - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) - val filesystemPath = new Path(path) - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) { - sys.error(s"path $path already exists.") - } - data.toJSON.saveAsTextFile(path) - - createRelation(sqlContext, parameters, data.schema) - } -} - -private[sql] case class JSONRelation( - path: String, - samplingRatio: Double, - userSpecifiedSchema: Option[StructType])( - @transient val sqlContext: SQLContext) - extends TableScan with InsertableRelation { - - // TODO: Support partitioned JSON relation. - private def baseRDD = sqlContext.sparkContext.textFile(path) - - override val schema = userSpecifiedSchema.getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema( - baseRDD, - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord))) - - override def buildScan() = - JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) - - override def insert(data: DataFrame, overwrite: Boolean) = { - val filesystemPath = new Path(path) - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - - if (overwrite) { - try { - fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to INSERT OVERWRITE a JSON table:\n${e.toString}") - } - data.toJSON.saveAsTextFile(path) - } else { - // TODO: Support INSERT INTO - sys.error("JSON table only support INSERT OVERWRITE for now.") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala deleted file mode 100644 index 9171939f7e8f..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ /dev/null @@ -1,484 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import java.io.StringWriter -import java.sql.{Date, Timestamp} - -import scala.collection.Map -import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} - -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.core.JsonFactory -import com.fasterxml.jackson.databind.ObjectMapper - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types._ -import org.apache.spark.Logging - -private[sql] object JsonRDD extends Logging { - - private[sql] def jsonStringToRow( - json: RDD[String], - schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { - parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) - } - - private[sql] def inferSchema( - json: RDD[String], - samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) - createSchema(allKeys) - } - - private def createSchema(allKeys: Set[(String, DataType)]): StructType = { - // Resolve type conflicts - val resolved = allKeys.groupBy { - case (key, dataType) => key - }.map { - // Now, keys and types are organized in the format of - // key -> Set(type1, type2, ...). - case (key, typeSet) => { - val fieldName = key.substring(1, key.length - 1).split("`.`").toSeq - val dataType = typeSet.map { - case (_, dataType) => dataType - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - (fieldName, dataType) - } - } - - def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { - val (topLevel, structLike) = values.partition(_.size == 1) - - val topLevelFields = topLevel.filter { - name => resolved.get(prefix ++ name).get match { - case ArrayType(elementType, _) => { - def hasInnerStruct(t: DataType): Boolean = t match { - case s: StructType => true - case ArrayType(t1, _) => hasInnerStruct(t1) - case o => false - } - - // Check if this array has inner struct. - !hasInnerStruct(elementType) - } - case struct: StructType => false - case _ => true - } - }.map { - a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) - } - val topLevelFieldNameSet = topLevelFields.map(_.name) - - val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter { - case (name, _) => !topLevelFieldNameSet.contains(name) - }.map { - case (name, fields) => { - val nestedFields = fields.map(_.tail) - val structType = makeStruct(nestedFields, prefix :+ name) - val dataType = resolved.get(prefix :+ name).get - dataType match { - case array: ArrayType => - // The pattern of this array is ArrayType(...(ArrayType(StructType))). - // Since the inner struct of array is a placeholder (StructType(Nil)), - // we need to replace this placeholder with the actual StructType (structType). - def getActualArrayType( - innerStruct: StructType, - currentArray: ArrayType): ArrayType = currentArray match { - case ArrayType(s: StructType, containsNull) => - ArrayType(innerStruct, containsNull) - case ArrayType(a: ArrayType, containsNull) => - ArrayType(getActualArrayType(innerStruct, a), containsNull) - } - Some(StructField(name, getActualArrayType(structType, array), nullable = true)) - case struct: StructType => Some(StructField(name, structType, nullable = true)) - // dataType is StringType means that we have resolved type conflicts involving - // primitive types and complex types. So, the type of name has been relaxed to - // StringType. Also, this field should have already been put in topLevelFields. - case StringType => None - } - } - }.flatMap(field => field).toSeq - - StructType((topLevelFields ++ structFields).sortBy(_.name)) - } - - makeStruct(resolved.keySet.toSeq, Nil) - } - - private[sql] def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - /** - * Returns the most general data type for two given data types. - */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { - case Some(commonType) => commonType - case None => - // t1 or t2 is a StructType, ArrayType, or an unexpected type. - (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other - case (StructType(fields1), StructType(fields2)) => { - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => { - val dataType = fieldTypes.map(field => field.dataType).reduce( - (type1: DataType, type2: DataType) => compatibleType(type1, type2)) - StructField(name, dataType, true) - } - } - StructType(newFields.toSeq.sortBy(_.name)) - } - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) - // TODO: We should use JsonObjectStringType to mark that values of field will be - // strings and every string is a Json object. - case (_, _) => StringType - } - } - } - - private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - ScalaReflection.typeOfObject orElse { - // Since we do not have a data type backed by BigInteger, - // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType.Unlimited - // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType.Unlimited - // Unexpected data type. - case _ => StringType - } - } - - /** - * Returns the element type of an JSON array. We go through all elements of this array - * to detect any possible type conflict. We use [[compatibleType]] to resolve - * type conflicts. - */ - private def typeOfArray(l: Seq[Any]): ArrayType = { - val containsNull = l.exists(v => v == null) - val elements = l.flatMap(v => Option(v)) - if (elements.isEmpty) { - // If this JSON array is empty, we use NullType as a placeholder. - // If this array is not empty in other JSON objects, we can resolve - // the type after we have passed through all JSON objects. - ArrayType(NullType, containsNull) - } else { - val elementType = elements.map { - e => e match { - case map: Map[_, _] => StructType(Nil) - // We have an array of arrays. If those element arrays do not have the same - // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) - case value => typeOfPrimitiveValue(value) - } - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - ArrayType(elementType, containsNull) - } - } - - /** - * Figures out all key names and data types of values from a parsed JSON object - * (in the format of Map[Stirng, Any]). When the value of a key is an JSON object, we - * only use a placeholder (StructType(Nil)) to mark that it should be a struct - * instead of getting all fields of this struct because a field does not appear - * in this JSON object can appear in other JSON objects. - */ - private def allKeysWithValueTypes(m: Map[String, Any]): Set[(String, DataType)] = { - val keyValuePairs = m.map { - // Quote the key with backticks to handle cases which have dots - // in the field name. - case (key, value) => (s"`$key`", value) - }.toSet - keyValuePairs.flatMap { - case (key: String, struct: Map[_, _]) => { - // The value associated with the key is an JSON object. - allKeysWithValueTypes(struct.asInstanceOf[Map[String, Any]]).map { - case (k, dataType) => (s"$key.$k", dataType) - } ++ Set((key, StructType(Nil))) - } - case (key: String, array: Seq[_]) => { - // The value associated with the key is an array. - // Handle inner structs of an array. - def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(e: StructType, containsNull) => { - // The elements of this arrays are structs. - v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { - element => allKeysWithValueTypes(element) - }.map { - case (k, t) => (s"$key.$k", t) - } - } - case ArrayType(t1, containsNull) => - v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { - element => buildKeyPathForInnerStructs(element, t1) - } - case other => Nil - } - val elementType = typeOfArray(array) - buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType) - } - // we couldn't tell what the type is if the value is null or empty string - case (key: String, value) if value == "" || value == null => (key, NullType) :: Nil - case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil - } - } - - /** - * Converts a Java Map/List to a Scala Map/Seq. - * We do not use Jackson's scala module at here because - * DefaultScalaModule in jackson-module-scala will make - * the parsing very slow. - */ - private def scalafy(obj: Any): Any = obj match { - case map: java.util.Map[_, _] => - // .map(identity) is used as a workaround of non-serializable Map - // generated by .mapValues. - // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - JMapWrapper(map).mapValues(scalafy).map(identity) - case list: java.util.List[_] => - JListWrapper(list).map(scalafy) - case atom => atom - } - - private def parseJson( - json: RDD[String], - columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { - // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], - // ObjectMapper will not return BigDecimal when - // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled - // (see NumberDeserializer.deserialize for the logic). - // But, we do not want to enable this feature because it will use BigDecimal - // for every float number, which will be slow. - // So, right now, we will have Infinity for those BigDecimal number. - // TODO: Support BigDecimal. - json.mapPartitions(iter => { - // When there is a key appearing multiple times (a duplicate key), - // the ObjectMapper will take the last value associated with this duplicate key. - // For example: for {"key": 1, "key":2}, we will get "key"->2. - val mapper = new ObjectMapper() - iter.flatMap { record => - try { - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - } - - parsed - } catch { - case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil - } - } - }) - } - - private def toLong(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong - case value: java.lang.Long => value.asInstanceOf[Long] - } - } - - private def toDouble(value: Any): Double = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toDouble - case value: java.lang.Long => value.asInstanceOf[Long].toDouble - case value: java.lang.Double => value.asInstanceOf[Double] - } - } - - private def toDecimal(value: Any): Decimal = { - value match { - case value: java.lang.Integer => Decimal(value) - case value: java.lang.Long => Decimal(value) - case value: java.math.BigInteger => Decimal(new java.math.BigDecimal(value)) - case value: java.lang.Double => Decimal(value) - case value: java.math.BigDecimal => Decimal(value) - } - } - - private def toJsonArrayString(seq: Seq[Any]): String = { - val builder = new StringBuilder - builder.append("[") - var count = 0 - seq.foreach { - element => - if (count > 0) builder.append(",") - count += 1 - builder.append(toString(element)) - } - builder.append("]") - - builder.toString() - } - - private def toJsonObjectString(map: Map[String, Any]): String = { - val builder = new StringBuilder - builder.append("{") - var count = 0 - map.foreach { - case (key, value) => - if (count > 0) builder.append(",") - count += 1 - val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value) - builder.append(s"""\"${key}\":${stringValue}""") - } - builder.append("}") - - builder.toString() - } - - private def toString(value: Any): String = { - value match { - case value: Map[_, _] => toJsonObjectString(value.asInstanceOf[Map[String, Any]]) - case value: Seq[_] => toJsonArrayString(value) - case value => Option(value).map(_.toString).orNull - } - } - - private def toDate(value: Any): Date = { - value match { - // only support string as date - case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) - } - } - - private def toTimestamp(value: Any): Timestamp = { - value match { - case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) - case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) - } - } - - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ - if (value == null) { - null - } else { - desiredType match { - case StringType => toString(value) - case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.JvmType] - case LongType => toLong(value) - case DoubleType => toDouble(value) - case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.JvmType] - case NullType => null - case ArrayType(elementType, _) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) - case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) - case DateType => toDate(value) - case TimestampType => toTimestamp(value) - } - } - } - - private def asRow(json: Map[String,Any], schema: StructType): Row = { - // TODO: Reuse the row instead of creating a new one for every record. - val row = new GenericMutableRow(schema.fields.length) - schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _, _), i) => - row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).orNull) - } - - row - } - - /** Transforms a single Row to JSON using Jackson - * - * @param jsonFactory a JsonFactory object to construct a JsonGenerator - * @param rowSchema the schema object used for conversion - * @param row The row to convert - */ - private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = { - val writer = new StringWriter() - val gen = jsonFactory.createGenerator(writer) - - def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v: String) => gen.writeString(v) - case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) - case (IntegerType, v: Int) => gen.writeNumber(v) - case (ShortType, v: Short) => gen.writeNumber(v) - case (FloatType, v: Float) => gen.writeNumber(v) - case (DoubleType, v: Double) => gen.writeNumber(v) - case (LongType, v: Long) => gen.writeNumber(v) - case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) - case (ByteType, v: Byte) => gen.writeNumber(v.toInt) - case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) - case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v) => gen.writeString(v.toString) - case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) - - case (ArrayType(ty, _), v: Seq[_] ) => - gen.writeStartArray() - v.foreach(valWriter(ty,_)) - gen.writeEndArray() - - case (MapType(kv,vv, _), v: Map[_,_]) => - gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv,p._2) - } - gen.writeEndObject() - - case (StructType(ty), v: Row) => - gen.writeStartObject() - ty.zip(v.toSeq).foreach { - case (_, null) => - case (field, v) => - gen.writeFieldName(field.name) - valWriter(field.dataType, v) - } - gen.writeEndObject() - } - - valWriter(rowSchema, row) - gen.close() - writer.toString - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 02e5b015e8ec..a9c600b139b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -34,14 +34,18 @@ import org.apache.spark.sql.execution.SparkPlan package object sql { /** - * Converts a logical plan into zero or more SparkPlans. + * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting + * with the query planner and is not designed to be stable across spark releases. Developers + * writing libraries should instead consider using the stable APIs provided in + * [[org.apache.spark.sql.sources]] */ @DeveloperApi - protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + * @deprecated As of 1.3.0, replaced by `DataFrame`. */ - @deprecated("1.3.0", "use DataFrame") + @deprecated("use DataFrame", "1.3.0") type SchemaRDD = DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala deleted file mode 100644 index 10df8c331009..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ /dev/null @@ -1,818 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} - -import parquet.column.Dictionary -import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import parquet.schema.MessageType - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.parquet.CatalystConverter.FieldType -import org.apache.spark.sql.types._ - -/** - * Collection of converters of Parquet types (group and primitive types) that - * model arrays and maps. The conversions are partly based on the AvroParquet - * converters that are part of Parquet in order to be able to process these - * types. - * - * There are several types of converters: - *
      - *
    • [[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive - * (numeric, boolean and String) types
    • - *
    • [[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays - * of native JVM element types; note: currently null values are not supported!
    • - *
    • [[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of - * arbitrary element types (including nested element types); note: currently - * null values are not supported!
    • - *
    • [[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs
    • - *
    • [[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note: - * currently null values are not supported!
    • - *
    • [[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows - * of only primitive element types
    • - *
    • [[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested - * records, including the top-level row record
    • - *
    - */ - -private[sql] object CatalystConverter { - // The type internally used for fields - type FieldType = StructField - - // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). - // Note that "array" for the array elements is chosen by ParquetAvro. - // Using a different value will result in Parquet silently dropping columns. - val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" - val ARRAY_ELEMENTS_SCHEMA_NAME = "array" - val MAP_KEY_SCHEMA_NAME = "key" - val MAP_VALUE_SCHEMA_NAME = "value" - val MAP_SCHEMA_NAME = "map" - - // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = Row - type MapScalaType[K, V] = Map[K, V] - - protected[parquet] def createConverter( - field: FieldType, - fieldIndex: Int, - parent: CatalystConverter): Converter = { - val fieldType: DataType = field.dataType - fieldType match { - case udt: UserDefinedType[_] => { - createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) - } - // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType, false) => { - new CatalystNativeArrayConverter(elementType, fieldIndex, parent) - } - // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType, false) => { - new CatalystArrayConverter(elementType, fieldIndex, parent) - } - case ArrayType(elementType: DataType, true) => { - new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent) - } - case StructType(fields: Array[StructField]) => { - new CatalystStructConverter(fields, fieldIndex, parent) - } - case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { - new CatalystMapConverter( - Array( - new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), - fieldIndex, - parent) - } - // Strings, Shorts and Bytes do not have a corresponding type in Parquet - // so we need to treat them separately - case StringType => - new CatalystPrimitiveStringConverter(parent, fieldIndex) - case ShortType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) - } - } - case ByteType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) - } - } - case d: DecimalType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateDecimal(fieldIndex, value, d) - } - } - // All other primitive types use the default converter - case ctype: PrimitiveType => { // note: need the type tag here! - new CatalystPrimitiveConverter(parent, fieldIndex) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") - } - } - - protected[parquet] def createRootConverter( - parquetSchema: MessageType, - attributes: Seq[Attribute]): CatalystConverter = { - // For non-nested types we use the optimized Row converter - if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { - new CatalystPrimitiveRowConverter(attributes.toArray) - } else { - new CatalystGroupConverter(attributes.toArray) - } - } -} - -private[parquet] abstract class CatalystConverter extends GroupConverter { - /** - * The number of fields this group has - */ - protected[parquet] val size: Int - - /** - * The index of this converter in the parent - */ - protected[parquet] val index: Int - - /** - * The parent converter - */ - protected[parquet] val parent: CatalystConverter - - /** - * Called by child converters to update their value in its parent (this). - * Note that if possible the more specific update methods below should be used - * to avoid auto-boxing of native JVM types. - * - * @param fieldIndex - * @param value - */ - protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit - - protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.getBytes) - - protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { - updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - } - - protected[parquet] def isRootConverter: Boolean = parent == null - - protected[parquet] def clearBuffer(): Unit - - /** - * Should only be called in the root (group) converter! - * - * @return - */ - def getCurrentRecord: Row = throw new UnsupportedOperationException - - /** - * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in - * a long (i.e. precision <= 18) - */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { - val precision = ctype.precisionInfo.get.precision - val scale = ctype.precisionInfo.get.scale - val bytes = value.getBytes - require(bytes.length <= 16, "Decimal field too large to read") - var unscaled = 0L - var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xFF) - i += 1 - } - // Make sure unscaled has the right sign, by sign-extending the first bit - val numBits = 8 * bytes.length - unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) - dest.set(unscaled, precision, scale) - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. - */ -private[parquet] class CatalystGroupConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var current: ArrayBuffer[Any], - protected[parquet] var buffer: ArrayBuffer[Row]) - extends CatalystConverter { - - def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = - this( - schema, - index, - parent, - current = null, - buffer = new ArrayBuffer[Row]( - CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - /** - * This constructor is used for the root converter only! - */ - def this(attributes: Array[Attribute]) = - this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override def getCurrentRecord: Row = { - assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") - // TODO: use iterators if possible - // Note: this will ever only be called in the root converter when the record has been - // fully processed. Therefore it will be difficult to use mutable rows instead, since - // any non-root converter never would be sure when it would be safe to re-use the buffer. - new GenericRow(current.toArray) - } - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current.update(fieldIndex, value) - } - - override protected[parquet] def clearBuffer(): Unit = buffer.clear() - - override def start(): Unit = { - current = ArrayBuffer.fill(size)(null) - converters.foreach { - converter => if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer - } - } - } - - override def end(): Unit = { - if (!isRootConverter) { - assert(current != null) // there should be no empty groups - buffer.append(new GenericRow(current.toArray)) - parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) - } - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his - * converter is optimized for rows of primitive types (non-nested records). - */ -private[parquet] class CatalystPrimitiveRowConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] var current: MutableRow) - extends CatalystConverter { - - // This constructor is used for the root converter only - def this(attributes: Array[Attribute]) = - this( - attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), - new SpecificMutableRow(attributes.map(_.dataType))) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override val index = 0 - - override val parent = null - - // Should be only called in root group converter! - override def getCurrentRecord: Row = current - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - throw new UnsupportedOperationException // child converters should use the - // specific update methods below - } - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - var i = 0 - while (i < size) { - current.setNullAt(i) - i = i + 1 - } - } - - override def end(): Unit = {} - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - current.setBoolean(fieldIndex, value) - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - current.setLong(fieldIndex, value) - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - current.setShort(fieldIndex, value) - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - current.setByte(fieldIndex, value) - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - current.setDouble(fieldIndex, value) - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - current.setFloat(fieldIndex, value) - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, value.getBytes) - - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - current.setString(fieldIndex, value) - - override protected[parquet] def updateDecimal( - fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { - var decimal = current(fieldIndex).asInstanceOf[Decimal] - if (decimal == null) { - decimal = new Decimal - current(fieldIndex) = decimal - } - readDecimal(decimal, value, ctype) - } -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystConverter, - fieldIndex: Int) extends PrimitiveConverter { - override def addBinary(value: Binary): Unit = - parent.updateBinary(fieldIndex, value) - - override def addBoolean(value: Boolean): Unit = - parent.updateBoolean(fieldIndex, value) - - override def addDouble(value: Double): Unit = - parent.updateDouble(fieldIndex, value) - - override def addFloat(value: Float): Unit = - parent.updateFloat(fieldIndex, value) - - override def addInt(value: Int): Unit = - parent.updateInt(fieldIndex, value) - - override def addLong(value: Long): Unit = - parent.updateLong(fieldIndex, value) -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. - * Supports dictionaries to reduce Binary to String conversion overhead. - * - * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) - extends CatalystPrimitiveConverter(parent, fieldIndex) { - - private[this] var dict: Array[String] = null - - override def hasDictionarySupport: Boolean = true - - override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} - - - override def addValueFromDictionary(dictionaryId: Int): Unit = - parent.updateString(fieldIndex, dict(dictionaryId)) - - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.toStringUsingUTF8) -} - -private[parquet] object CatalystArrayConverter { - val INITIAL_ARRAY_SIZE = 20 -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex=0, - parent=this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - // fieldIndex is ignored (assumed to be zero but not checked) - if(value == null) { - throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") - } - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = { - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer - } - } - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (native) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param capacity The (initial) capacity of the buffer - */ -private[parquet] class CatalystNativeArrayConverter( - val elementType: NativeType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) - extends CatalystConverter { - - type NativeType = elementType.JvmType - - private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) - - private var elements: Int = 0 - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex=0, - parent=this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { - checkGrowBuffer() - buffer(elements) = value.getBytes.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def clearBuffer(): Unit = { - elements = 0 - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField( - index, - buffer.slice(0, elements).toSeq) - clearBuffer() - } - - private def checkGrowBuffer(): Unit = { - if (elements >= capacity) { - val newCapacity = 2 * capacity - val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) - Array.copy(buffer, 0, tmp, 0, capacity) - buffer = tmp - capacity = newCapacity - } - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array contains null (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayContainsNullConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = new CatalystConverter { - - private var current: Any = null - - val converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - override def end(): Unit = parent.updateField(index, current) - - override def start(): Unit = { - current = null - } - - override protected[parquet] val size: Int = 1 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent = CatalystArrayContainsNullConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current = value - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * This converter is for multi-element groups of primitive or complex types - * that have repetition level optional or required (so struct fields). - * - * @param schema The corresponding Catalyst schema in the form of a list of - * attributes. - * @param index - * @param parent - */ -private[parquet] class CatalystStructConverter( - override protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystGroupConverter(schema, index, parent) { - - override protected[parquet] def clearBuffer(): Unit = {} - - // TODO: think about reusing the buffer - override def end(): Unit = { - assert(!isRootConverter) - // here we need to make sure to use StructScalaType - // Note: we need to actually make a copy of the array since we - // may be in a nested field - parent.updateField(index, new GenericRow(current.toArray)) - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts two-element groups that - * match the characteristics of a map (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.MapType]]. - * - * @param schema - * @param index - * @param parent - */ -private[parquet] class CatalystMapConverter( - protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystConverter { - - private val map = new HashMap[Any, Any]() - - private val keyValueConverter = new CatalystConverter { - private var currentKey: Any = null - private var currentValue: Any = null - val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) - val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) - - override def getConverter(fieldIndex: Int): Converter = { - if (fieldIndex == 0) keyConverter else valueConverter - } - - override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue - - override def start(): Unit = { - currentKey = null - currentValue = null - } - - override protected[parquet] val size: Int = 2 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - fieldIndex match { - case 0 => - currentKey = value - case 1 => - currentValue = value - case _ => - new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") - } - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override protected[parquet] val size: Int = 1 - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - map.clear() - } - - override def end(): Unit = { - // here we need to make sure to use MapScalaType - parent.updateField(index, map.toMap) - } - - override def getConverter(fieldIndex: Int): Converter = keyValueConverter - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala deleted file mode 100644 index 0357dcc4688b..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.nio.ByteBuffer - -import com.google.common.io.BaseEncoding -import org.apache.hadoop.conf.Configuration -import parquet.filter2.compat.FilterCompat -import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterApi._ -import parquet.filter2.predicate.{FilterApi, FilterPredicate} -import parquet.io.api.Binary - -import org.apache.spark.SparkEnv -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -private[sql] object ParquetFilters { - val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - - def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { - filterExpressions.flatMap(createFilter).reduceOption(FilterApi.and).map(FilterCompat.get) - } - - def createFilter(predicate: Expression): Option[FilterPredicate] = { - val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case StringType => - (n: String, v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => - (n: String, v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) - } - - val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => - (n: String, v: Any) => FilterApi.notEq( - binaryColumn(n), - Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) - } - - val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case BinaryType => - (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } - - val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case BinaryType => - (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } - - val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case BinaryType => - (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } - - val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) - case BinaryType => - (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } - - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `SimplifyFilters` rule for details. - predicate match { - case IsNull(NamedExpression(name, dataType)) => - makeEq.lift(dataType).map(_(name, null)) - case IsNotNull(NamedExpression(name, dataType)) => - makeNotEq.lift(dataType).map(_(name, null)) - - case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeEq.lift(dataType).map(_(name, value)) - - case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - - case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGt.lift(dataType).map(_(name, value)) - - case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - - case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLt.lift(dataType).map(_(name, value)) - - case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - - case And(lhs, rhs) => - (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) - - case Or(lhs, rhs) => - for { - lhsFilter <- createFilter(lhs) - rhsFilter <- createFilter(rhs) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case Not(pred) => - createFilter(pred).map(FilterApi.not) - - case _ => None - } - } - - /** - * Note: Inside the Hadoop API we only have access to `Configuration`, not to - * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey - * the actual filter predicate. - */ - def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { - if (filters.nonEmpty) { - val serialized: Array[Byte] = - SparkEnv.get.closureSerializer.newInstance().serialize(filters).array() - val encoded: String = BaseEncoding.base64().encode(serialized) - conf.set(PARQUET_FILTER_DATA, encoded) - } - } - - /** - * Note: Inside the Hadoop API we only have access to `Configuration`, not to - * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey - * the actual filter predicate. - */ - def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = { - val data = conf.get(PARQUET_FILTER_DATA) - if (data != null) { - val decoded: Array[Byte] = BaseEncoding.base64().decode(data) - SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(decoded)) - } else { - Seq() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala deleted file mode 100644 index a54485e719da..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.permission.FsAction -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.metadata.CompressionCodecName -import parquet.schema.MessageType - -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} - -/** - * Relation that consists of data stored in a Parquet columnar format. - * - * Users should interact with parquet files though a [[DataFrame]], created by a [[SQLContext]] - * instead of using this class directly. - * - * {{{ - * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") - * }}} - * - * @param path The path to the Parquet file. - */ -private[sql] case class ParquetRelation( - path: String, - @transient conf: Option[Configuration], - @transient sqlContext: SQLContext, - partitioningAttributes: Seq[Attribute] = Nil) - extends LeafNode with MultiInstanceRelation { - - self: Product => - - /** Schema derived from ParquetFile */ - def parquetSchema: MessageType = - ParquetTypesConverter - .readMetaData(new Path(path), conf) - .getFileMetaData - .getSchema - - /** Attributes */ - override val output = - partitioningAttributes ++ - ParquetTypesConverter.readSchemaFromFile( - new Path(path.split(",").head), - conf, - sqlContext.conf.isParquetBinaryAsString) - - lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - - override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] - - // Equals must also take into account the output attributes so that we can distinguish between - // different instances of the same relation, - override def equals(other: Any) = other match { - case p: ParquetRelation => - p.path == path && p.output == output - case _ => false - } - - // TODO: Use data from the footers. - override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) -} - -private[sql] object ParquetRelation { - - def enableLogForwarding() { - // Note: the parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "parquet". This - // checks first to see if there's any handlers already set - // and if not it creates them. If this method executes prior - // to that class being loaded then: - // 1) there's no handlers installed so there's none to - // remove. But when it IS finally loaded the desired affect - // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHanders(false) - // undoing the attempt to override the logging here. - // - // Therefore we need to force the class to be loaded. - // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName()) - - // Note: Logger.getLogger("parquet") has a default logger - // that appends to Console which needs to be cleared. - val parquetLogger = java.util.logging.Logger.getLogger("parquet") - parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - // TODO(witgo): Need to set the log level ? - // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) - if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) - } - - // The element type for the RDDs that this relation maps to. - type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - // The compression type - type CompressionType = parquet.hadoop.metadata.CompressionCodecName - - // The parquet compression short names - val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, - "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) - - /** - * Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. Note that - * this is used inside [[org.apache.spark.sql.execution.SparkStrategies SparkStrategies]] to - * create a resolved relation as a data sink for writing to a Parquetfile. The relation is empty - * but is initialized with ParquetMetadata and can be inserted into. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param child The child node that will be used for extracting the schema. - * @param conf A configuration to be used. - * @return An empty ParquetRelation with inferred metadata. - */ - def create(pathString: String, - child: LogicalPlan, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - if (!child.resolved) { - throw new UnresolvedException[LogicalPlan]( - child, - "Attempt to create Parquet table from unresolved child (when schema is not available)") - } - createEmpty(pathString, child.output, false, conf, sqlContext) - } - - /** - * Creates an empty ParquetRelation and underlying Parquetfile that only - * consists of the Metadata for the given schema. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param attributes The schema of the relation. - * @param conf A configuration to be used. - * @return An empty ParquetRelation. - */ - def createEmpty(pathString: String, - attributes: Seq[Attribute], - allowExisting: Boolean, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - val path = checkPath(pathString, allowExisting, conf) - conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) - .name()) - ParquetRelation.enableLogForwarding() - ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString, Some(conf), sqlContext) { - override val output = attributes - } - } - - private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { - if (pathStr == null) { - throw new IllegalArgumentException("Unable to create ParquetRelation: path is null") - } - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to create ParquetRelation: incorrectly formatted path $pathStr") - } - val path = origPath.makeQualified(fs) - if (!allowExisting && fs.exists(path)) { - sys.error(s"File $pathStr already exists.") - } - - if (fs.exists(path) && - !fs.getFileStatus(path) - .getPermission - .getUserAction - .implies(FsAction.READ_WRITE)) { - throw new IOException( - s"Unable to create ParquetRelation: path $path not read-writable") - } - path - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala deleted file mode 100644 index 28cd17fde46a..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ /dev/null @@ -1,652 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException -import java.lang.{Long => JLong} -import java.text.SimpleDateFormat -import java.text.NumberFormat -import java.util.concurrent.{Callable, TimeUnit} -import java.util.{ArrayList, Collections, Date, List => JList} - -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.Try - -import com.google.common.cache.CacheBuilder -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import parquet.hadoop._ -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{InitContext, ReadSupport} -import parquet.hadoop.metadata.GlobalMetaData -import parquet.hadoop.util.ContextUtil -import parquet.io.ParquetDecodingException -import parquet.schema.MessageType - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.{Logging, SerializableWritable, TaskContext} - -/** - * :: DeveloperApi :: - * Parquet table scan operator. Imports the file that backs the given - * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. - */ -case class ParquetTableScan( - attributes: Seq[Attribute], - relation: ParquetRelation, - columnPruningPred: Seq[Expression]) - extends LeafNode { - - // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes - // by exprId. note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map(relation.attributeMap) - - // A mapping of ordinals partitionRow -> finalOutput. - val requestedPartitionOrdinals = { - val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex) - - attributes.zipWithIndex.flatMap { - case (attribute, finalOrdinal) => - partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal) - } - }.toArray - - override def execute(): RDD[Row] = { - import parquet.filter2.compat.FilterCompat.FilterPredicateCompat - - val sc = sqlContext.sparkContext - val job = new Job(sc.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - - val conf: Configuration = ContextUtil.getConfiguration(job) - - relation.path.split(",").foreach { curPath => - val qualifiedPath = { - val path = new Path(curPath) - path.getFileSystem(conf).makeQualified(path) - } - NewFileInputFormat.addInputPath(job, qualifiedPath) - } - - // Store both requested and original schema in `Configuration` - conf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(relation.output)) - - // Store record filtering predicate in `Configuration` - // Note 1: the input format ignores all predicates that cannot be expressed - // as simple column predicate filters in Parquet. Here we just record - // the whole pruning predicate. - ParquetFilters - .createRecordFilter(columnPruningPred) - .map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate) - // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.set( - SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) - - val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( - sc, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row], - conf) - - if (requestedPartitionOrdinals.nonEmpty) { - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - // Convert the partitioning attributes into the correct types - val partitionRowValues = - relation.partitioningAttributes - .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) - - new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { - val row = iter.next()._2.asInstanceOf[SpecificMutableRow] - - // Parquet will leave partitioning columns empty, so we fill them in here. - var i = 0 - while (i < requestedPartitionOrdinals.size) { - row(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - row - } - } - } - } else { - baseRDD.map(_._2) - } - } - - /** - * Applies a (candidate) projection. - * - * @param prunedAttributes The list of attributes to be used in the projection. - * @return Pruned TableScan. - */ - def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { - val success = validateProjection(prunedAttributes) - if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred) - } else { - sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - } - } - - /** - * Evaluates a candidate projection by checking whether the candidate is a subtype - * of the original type. - * - * @param projection The candidate projection. - * @return True if the projection is valid, false otherwise. - */ - private def validateProjection(projection: Seq[Attribute]): Boolean = { - val original: MessageType = relation.parquetSchema - val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) - Try(original.checkContains(candidate)).isSuccess - } -} - -/** - * :: DeveloperApi :: - * Operator that acts as a sink for queries on RDDs and can be used to - * store the output inside a directory of Parquet files. This operator - * is similar to Hive's INSERT INTO TABLE operation in the sense that - * one can choose to either overwrite or append to a directory. Note - * that consecutive insertions to the same table must have compatible - * (source) schemas. - * - * WARNING: EXPERIMENTAL! InsertIntoParquetTable with overwrite=false may - * cause data corruption in the case that multiple users try to append to - * the same table simultaneously. Inserting into a table that was - * previously generated by other means (e.g., by creating an HDFS - * directory and importing Parquet files generated by other tools) may - * cause unpredicted behaviour and therefore results in a RuntimeException - * (only detected via filename pattern so will not catch all cases). - */ -@DeveloperApi -case class InsertIntoParquetTable( - relation: ParquetRelation, - child: SparkPlan, - overwrite: Boolean = false) - extends UnaryNode with SparkHadoopMapReduceUtil { - - /** - * Inserts all rows into the Parquet file. - */ - override def execute() = { - // TODO: currently we do not check whether the "schema"s are compatible - // That means if one first creates a table and then INSERTs data with - // and incompatible schema the execution will fail. It would be nice - // to catch this early one, maybe having the planner validate the schema - // before calling execute(). - - val childRdd = child.execute() - assert(childRdd != null) - - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - - val writeSupport = - if (child.output.map(_.dataType).forall(_.isPrimitive)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] - } else { - classOf[org.apache.spark.sql.parquet.RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - - val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(relation.output, conf) - - val fspath = new Path(relation.path) - val fs = fspath.getFileSystem(conf) - - if (overwrite) { - try { - fs.delete(fspath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${fspath.toString} prior" - + s" to InsertIntoParquetTable:\n${e.toString}") - } - } - saveAsHadoopFile(childRdd, relation.path.toString, conf) - - // We return the child RDD to allow chaining (alternatively, one could return nothing). - childRdd - } - - override def output = child.output - - /** - * Stores the given Row RDD as a Hadoop file. - * - * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]] - * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses - * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing - * directory and need to determine which was the largest written file index before starting to - * write. - * - * @param rdd The [[org.apache.spark.rdd.RDD]] to writer - * @param path The directory to write to. - * @param conf A [[org.apache.hadoop.conf.Configuration]]. - */ - private def saveAsHadoopFile( - rdd: RDD[Row], - path: String, - conf: Configuration) { - val job = new Job(conf) - val keyType = classOf[Void] - job.setOutputKeyClass(keyType) - job.setOutputValueClass(classOf[Row]) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableWritable(job.getConfiguration) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = sqlContext.sparkContext.newRddId() - - val taskIdOffset = - if (overwrite) { - 1 - } else { - FileSystemHelper - .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 - } - - def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - context.attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = new AppendingParquetOutputFormat(taskIdOffset) - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext) - try { - while (iter.hasNext) { - val row = iter.next() - writer.write(null, row) - } - } finally { - writer.close(hadoopContext) - } - committer.commitTask(hadoopContext) - 1 - } - val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - sqlContext.sparkContext.runJob(rdd, writeShard _) - jobCommitter.commitJob(jobTaskContext) - } -} - -/** - * TODO: this will be able to append to directories it created itself, not necessarily - * to imported ones. - */ -private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends parquet.hadoop.ParquetOutputFormat[Row] { - // override to accept existing directories as valid output directory - override def checkOutputSpecs(job: JobContext): Unit = {} - - // override to choose output filename so not overwrite existing ones - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val taskId: TaskID = getTaskAttemptID(context).getTaskID - val partition: Int = taskId.getId - val filename = "part-r-" + numfmt.format(partition + offset) + ".parquet" - val committer: FileOutputCommitter = - getOutputCommitter(context).asInstanceOf[FileOutputCommitter] - new Path(committer.getWorkPath, filename) - } - - // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. - // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions - // are the same, so the method calls are source-compatible but NOT binary-compatible because - // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. - private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { - context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] - } -} - -/** - * We extend ParquetInputFormat in order to have more control over which - * RecordFilter we want to use. - */ -private[parquet] class FilteringParquetRowInputFormat - extends parquet.hadoop.ParquetInputFormat[Row] with Logging { - - private var footers: JList[Footer] = _ - - private var fileStatuses = Map.empty[Path, FileStatus] - - override def createRecordReader( - inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { - - import parquet.filter2.compat.FilterCompat.NoOpFilter - - val readSupport: ReadSupport[Row] = new RowReadSupport() - - val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) - if (!filter.isInstanceOf[NoOpFilter]) { - new ParquetRecordReader[Row]( - readSupport, - filter) - } else { - new ParquetRecordReader[Row](readSupport) - } - } - - override def getFooters(jobContext: JobContext): JList[Footer] = { - import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache - - if (footers eq null) { - val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) - val statuses = listStatus(jobContext) - fileStatuses = statuses.map(file => file.getPath -> file).toMap - if (statuses.isEmpty) { - footers = Collections.emptyList[Footer] - } else if (!cacheMetadata) { - // Read the footers from HDFS - footers = getFooters(conf, statuses) - } else { - // Read only the footers that are not in the footerCache - val foundFooters = footerCache.getAllPresent(statuses) - val toFetch = new ArrayList[FileStatus] - for (s <- statuses) { - if (!foundFooters.containsKey(s)) { - toFetch.add(s) - } - } - val newFooters = new mutable.HashMap[FileStatus, Footer] - if (toFetch.size > 0) { - val startFetch = System.currentTimeMillis - val fetched = getFooters(conf, toFetch) - logInfo(s"Fetched $toFetch footers in ${System.currentTimeMillis - startFetch} ms") - for ((status, i) <- toFetch.zipWithIndex) { - newFooters(status) = fetched.get(i) - } - footerCache.putAll(newFooters) - } - footers = new ArrayList[Footer](statuses.size) - for (status <- statuses) { - footers.add(newFooters.getOrElse(status, foundFooters.get(status))) - } - } - } - - footers - } - - // TODO Remove this method and related code once PARQUET-16 is fixed - // This method together with the `getFooters` method and the `fileStatuses` field are just used - // to mimic this PR: https://github.com/apache/incubator-parquet-mr/pull/17 - override def getSplits( - configuration: Configuration, - footers: JList[Footer]): JList[ParquetInputSplit] = { - - // Use task side strategy by default - val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) - val minSplitSize: JLong = - Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L)) - if (maxSplitSize < 0 || minSplitSize < 0) { - throw new ParquetDecodingException( - s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + - s" minSplitSize = $minSplitSize") - } - - // Uses strict type checking by default - val getGlobalMetaData = - classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) - getGlobalMetaData.setAccessible(true) - val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] - - if (globalMetaData == null) { - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - return splits - } - - val readContext = getReadSupport(configuration).init( - new InitContext(configuration, - globalMetaData.getKeyValueMetaData, - globalMetaData.getSchema)) - - if (taskSideMetaData){ - logInfo("Using Task Side Metadata Split Strategy") - getTaskSideSplits(configuration, - footers, - maxSplitSize, - minSplitSize, - readContext) - } else { - logInfo("Using Client Side Metadata Split Strategy") - getClientSideSplits(configuration, - footers, - maxSplitSize, - minSplitSize, - readContext) - } - - } - - def getClientSideSplits( - configuration: Configuration, - footers: JList[Footer], - maxSplitSize: JLong, - minSplitSize: JLong, - readContext: ReadContext): JList[ParquetInputSplit] = { - - import parquet.filter2.compat.FilterCompat.Filter - import parquet.filter2.compat.RowGroupFilter - import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache - - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) - - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - val filter: Filter = ParquetInputFormat.getFilter(configuration) - var rowGroupsDropped: Long = 0 - var totalRowGroups: Long = 0 - - // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 - // is resolved - val generateSplits = - Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") - .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( - sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) - generateSplits.setAccessible(true) - - for (footer <- footers) { - val fs = footer.getFile.getFileSystem(configuration) - val file = footer.getFile - val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) - val parquetMetaData = footer.getParquetMetadata - val blocks = parquetMetaData.getBlocks - totalRowGroups = totalRowGroups + blocks.size - val filteredBlocks = RowGroupFilter.filterRowGroups( - filter, - blocks, - parquetMetaData.getFileMetaData.getSchema) - rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) - - if (!filteredBlocks.isEmpty){ - var blockLocations: Array[BlockLocation] = null - if (!cacheMetadata) { - blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - } else { - blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { - def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) - }) - } - splits.addAll( - generateSplits.invoke( - null, - filteredBlocks, - blockLocations, - status, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) - } - } - - if (rowGroupsDropped > 0 && totalRowGroups > 0){ - val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt - logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate " - + s"($percentDropped %) !") - } - else { - logInfo("There were no row groups that could be dropped due to filter predicates") - } - splits - - } - - def getTaskSideSplits( - configuration: Configuration, - footers: JList[Footer], - maxSplitSize: JLong, - minSplitSize: JLong, - readContext: ReadContext): JList[ParquetInputSplit] = { - - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - - // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 - // is resolved - val generateSplits = - Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") - .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( - sys.error( - s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) - generateSplits.setAccessible(true) - - for (footer <- footers) { - val file = footer.getFile - val fs = file.getFileSystem(configuration) - val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) - val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - splits.addAll( - generateSplits.invoke( - null, - blockLocations, - status, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) - } - - splits - } - -} - -private[parquet] object FilteringParquetRowInputFormat { - private val footerCache = CacheBuilder.newBuilder() - .maximumSize(20000) - .build[FileStatus, Footer]() - - private val blockLocationCache = CacheBuilder.newBuilder() - .maximumSize(20000) - .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move - .build[FileStatus, Array[BlockLocation]]() -} - -private[parquet] object FileSystemHelper { - def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"ParquetTableOperations: Path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (!fs.exists(path) || !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"ParquetTableOperations: path $path does not exist or is not a directory") - } - fs.globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } - .map(_.getPath) - } - - /** - * Finds the maximum taskid in the output file names at the given path. - */ - def findMaxTaskId(pathStr: String, conf: Configuration): Int = { - val files = FileSystemHelper.listFiles(pathStr, conf) - // filename pattern is part-r-.parquet - val nameP = new scala.util.matching.Regex("""part-r-(\d{1,}).parquet""", "taskid") - val hiddenFileP = new scala.util.matching.Regex("_.*") - files.map(_.getName).map { - case nameP(taskid) => taskid.toInt - case hiddenFileP() => 0 - case other: String => - sys.error("ERROR: attempting to append to set of Parquet files and found file" + - s"that does not match name pattern: $other") - case _ => 0 - }.reduceLeft((a, b) => if (a < b) b else a) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala deleted file mode 100644 index fd63ad814406..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.util.{HashMap => JHashMap} - -import org.apache.hadoop.conf.Configuration -import parquet.column.ParquetProperties -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{ReadSupport, WriteSupport} -import parquet.io.api._ -import parquet.schema.MessageType - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} -import org.apache.spark.sql.types._ - -/** - * A `parquet.io.api.RecordMaterializer` for Rows. - * - *@param root The root group converter for the record. - */ -private[parquet] class RowRecordMaterializer(root: CatalystConverter) - extends RecordMaterializer[Row] { - - def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = - this(CatalystConverter.createRootConverter(parquetSchema, attributes)) - - override def getCurrentRecord: Row = root.getCurrentRecord - - override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] -} - -/** - * A `parquet.hadoop.api.ReadSupport` for Row objects. - */ -private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { - - override def prepareForRead( - conf: Configuration, - stringMap: java.util.Map[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[Row] = { - log.debug(s"preparing for read with Parquet file schema $fileSchema") - // Note: this very much imitates AvroParquet - val parquetSchema = readContext.getRequestedSchema - var schema: Seq[Attribute] = null - - if (readContext.getReadSupportMetadata != null) { - // first try to find the read schema inside the metadata (can result from projections) - if ( - readContext - .getReadSupportMetadata - .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) != null) { - schema = ParquetTypesConverter.convertFromString( - readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - } else { - // if unavailable, try the schema that was read originally from the file or provided - // during the creation of the Parquet relation - if (readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - schema = ParquetTypesConverter.convertFromString( - readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } - } - } - // if both unavailable, fall back to deducing the schema from the given Parquet schema - // TODO: Why it can be null? - if (schema == null) { - log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) - } - log.debug(s"list of attributes that will be read: $schema") - new RowRecordMaterializer(parquetSchema, schema) - } - - override def init( - configuration: Configuration, - keyValueMetaData: java.util.Map[String, String], - fileSchema: MessageType): ReadContext = { - var parquetSchema = fileSchema - val metadata = new JHashMap[String, String]() - val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) - - if (requestedAttributes != null) { - parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) - metadata.put( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedAttributes)) - } - - val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - if (origAttributesStr != null) { - metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) - } - - new ReadSupport.ReadContext(parquetSchema, metadata) - } -} - -private[parquet] object RowReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" - - private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = { - val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) - if (schemaString == null) null else ParquetTypesConverter.convertFromString(schemaString) - } -} - -/** - * A `parquet.hadoop.api.WriteSupport` for Row ojects. - */ -private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { - - private[parquet] var writer: RecordConsumer = null - private[parquet] var attributes: Array[Attribute] = null - - override def init(configuration: Configuration): WriteSupport.WriteContext = { - val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - val metadata = new JHashMap[String, String]() - metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) - - if (attributes == null) { - attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray - } - - log.debug(s"write support initialized for requested schema $attributes") - ParquetRelation.enableLogForwarding() - new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) - } - - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - writer = recordConsumer - log.debug(s"preparing for write with schema $attributes") - } - - override def write(record: Row): Unit = { - val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") - } - - var index = 0 - writer.startMessage() - while(index < attributesSize) { - // null values indicate optional fields but we do not check currently - if (record(index) != null) { - writer.startField(attributes(index).name, index) - writeValue(attributes(index).dataType, record(index)) - writer.endField(attributes(index).name, index) - } - index = index + 1 - } - writer.endMessage() - } - - private[parquet] def writeValue(schema: DataType, value: Any): Unit = { - if (value != null) { - schema match { - case t: UserDefinedType[_] => writeValue(t.sqlType, value) - case t @ ArrayType(_, _) => writeArray( - t, - value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) - case t @ MapType(_, _, _) => writeMap( - t, - value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) - case t @ StructType(_) => writeStruct( - t, - value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) - } - } - } - - private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { - if (value != null) { - schema match { - case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") - ) - ) - case BinaryType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(value.asInstanceOf[Int]) - case ShortType => writer.addInteger(value.asInstanceOf[Short]) - case LongType => writer.addLong(value.asInstanceOf[Long]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) - case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) - case _ => sys.error(s"Do not know how to writer $schema to consumer") - } - } - } - - private[parquet] def writeStruct( - schema: StructType, - struct: CatalystConverter.StructScalaType[_]): Unit = { - if (struct != null) { - val fields = schema.fields.toArray - writer.startGroup() - var i = 0 - while(i < fields.size) { - if (struct(i) != null) { - writer.startField(fields(i).name, i) - writeValue(fields(i).dataType, struct(i)) - writer.endField(fields(i).name, i) - } - i = i + 1 - } - writer.endGroup() - } - } - - private[parquet] def writeArray( - schema: ArrayType, - array: CatalystConverter.ArrayScalaType[_]): Unit = { - val elementType = schema.elementType - writer.startGroup() - if (array.size > 0) { - if (schema.containsNull) { - writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) - var i = 0 - while (i < array.size) { - writer.startGroup() - if (array(i) != null) { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array(i)) - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - } - writer.endGroup() - i = i + 1 - } - writer.endField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) - } else { - writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - var i = 0 - while (i < array.size) { - writeValue(elementType, array(i)) - i = i + 1 - } - writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - } - } - writer.endGroup() - } - - private[parquet] def writeMap( - schema: MapType, - map: CatalystConverter.MapScalaType[_, _]): Unit = { - writer.startGroup() - if (map.size > 0) { - writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) - for ((key, value) <- map) { - writer.startGroup() - writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - writeValue(schema.keyType, key) - writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) - if (value != null) { - writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - writeValue(schema.valueType, value) - writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) - } - writer.endGroup() - } - writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) - } - writer.endGroup() - } - - // Scratch array used to write decimals as fixed-length binary - private val scratchBytes = new Array[Byte](8) - - private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { - val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) - val unscaledLong = decimal.toUnscaledLong - var i = 0 - var shift = 8 * (numBytes - 1) - while (i < numBytes) { - scratchBytes(i) = (unscaledLong >> shift).toByte - i += 1 - shift -= 8 - } - writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) - } - -} - -// Optimized for non-nested rows -private[parquet] class MutableRowWriteSupport extends RowWriteSupport { - override def write(record: Row): Unit = { - val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") - } - - var index = 0 - writer.startMessage() - while(index < attributesSize) { - // null values indicate optional fields but we do not check currently - if (record(index) != null && record(index) != Nil) { - writer.startField(attributes(index).name, index) - consumeType(attributes(index).dataType, record, index) - writer.endField(attributes(index).name, index) - } - index = index + 1 - } - writer.endMessage() - } - - private def consumeType( - ctype: DataType, - record: Row, - index: Int): Unit = { - ctype match { - case StringType => writer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) - case BinaryType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(record.getInt(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case LongType => writer.addLong(record.getLong(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) - case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") - } - } -} - -private[parquet] object RowWriteSupport { - val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" - - def getSchema(configuration: Configuration): Seq[Attribute] = { - val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - if (schemaString == null) { - throw new RuntimeException("Missing schema!") - } - ParquetTypesConverter.convertFromString(schemaString) - } - - def setSchema(schema: Seq[Attribute], configuration: Configuration) { - val encoded = ParquetTypesConverter.convertToString(schema) - configuration.set(SPARK_ROW_SCHEMA, encoded) - configuration.set( - ParquetOutputFormat.WRITER_VERSION, - ParquetProperties.WriterVersion.PARQUET_1_0.toString) - } -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala deleted file mode 100644 index 9d6c529574da..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag -import scala.util.Try - -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.util -import org.apache.spark.util.Utils - -/** - * A helper trait that provides convenient facilities for Parquet testing. - * - * NOTE: Considering classes `Tuple1` ... `Tuple22` all extend `Product`, it would be more - * convenient to use tuples rather than special case classes when writing test cases/suites. - * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. - */ -trait ParquetTest { - val sqlContext: SQLContext - - import sqlContext._ - - protected def configuration = sparkContext.hadoopConfiguration - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(getConf(key)).toOption) - (keys, values).zipped.foreach(setConf) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => setConf(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } - - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val file = util.getTempFilePath("parquetTest").getCanonicalFile - try f(file) finally if (file.exists()) Utils.deleteRecursively(file) - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } - - /** - * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` - * returns. - */ - protected def withParquetFile[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) - (f: String => Unit): Unit = { - withTempPath { file => - sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath) - f(file.getCanonicalPath) - } - } - - /** - * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], - * which is then passed to `f`. The Parquet file will be deleted after `f` returns. - */ - protected def withParquetRDD[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) - (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(parquetFile(path))) - } - - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally dropTempTable(tableName) - } - - /** - * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a - * temporary table named `tableName`, then call `f`. The temporary table together with the - * Parquet file will be dropped/deleted after `f` returns. - */ - protected def withParquetTable[T <: Product: ClassTag: TypeTag] - (data: Seq[T], tableName: String) - (f: => Unit): Unit = { - withParquetRDD(data) { rdd => - sqlContext.registerRDDAsTable(rdd, tableName) - withTempTable(tableName)(f) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala deleted file mode 100644 index d5993656e022..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ /dev/null @@ -1,462 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.apache.spark.sql.test.TestSQLContext - -import parquet.example.data.{GroupWriter, Group} -import parquet.example.data.simple.SimpleGroup -import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.example.GroupReadSupport -import parquet.hadoop.util.ContextUtil -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} - -import org.apache.spark.util.Utils - -// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport -// with an empty configuration (it is after all not intended to be used in this way?) -// and members are private so we need to make our own in order to pass the schema -// to the writer. -private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { - var groupWriter: GroupWriter = null - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - groupWriter = new GroupWriter(recordConsumer, schema) - } - override def init(configuration: Configuration): WriteContext = { - new WriteContext(schema, new java.util.HashMap[String, String]()) - } - override def write(record: Group) { - groupWriter.write(record) - } -} - -private[sql] object ParquetTestData { - - val testSchema = - """message myrecord { - optional boolean myboolean; - optional int32 myint; - optional binary mystring (UTF8); - optional int64 mylong; - optional float myfloat; - optional double mydouble; - }""" - - // field names for test assertion error messages - val testSchemaFieldNames = Seq( - "myboolean:Boolean", - "myint:Int", - "mystring:String", - "mylong:Long", - "myfloat:Float", - "mydouble:Double" - ) - - val subTestSchema = - """ - message myrecord { - optional boolean myboolean; - optional int64 mylong; - } - """ - - val testFilterSchema = - """ - message myrecord { - required boolean myboolean; - required int32 myint; - required binary mystring (UTF8); - required int64 mylong; - required float myfloat; - required double mydouble; - optional boolean myoptboolean; - optional int32 myoptint; - optional binary myoptstring (UTF8); - optional int64 myoptlong; - optional float myoptfloat; - optional double myoptdouble; - } - """ - - // field names for test assertion error messages - val subTestSchemaFieldNames = Seq( - "myboolean:Boolean", - "mylong:Long" - ) - - val testDir = Utils.createTempDir() - val testFilterDir = Utils.createTempDir() - - lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) - - val testNestedSchema1 = - // based on blogpost example, source: - // https://blog.twitter.com/2013/dremel-made-simple-with-parquet - // note: instead of string we have to use binary (?) otherwise - // Parquet gives us: - // IllegalArgumentException: expected one of [INT64, INT32, BOOLEAN, - // BINARY, FLOAT, DOUBLE, INT96, FIXED_LEN_BYTE_ARRAY] - // Also repeated primitives seem tricky to convert (AvroParquet - // only uses them in arrays?) so only use at most one in each group - // and nothing else in that group (-> is mapped to array)! - // The "values" inside ownerPhoneNumbers is a keyword currently - // so that array types can be translated correctly. - """ - message AddressBook { - required binary owner (UTF8); - optional group ownerPhoneNumbers { - repeated binary array (UTF8); - } - optional group contacts { - repeated group array { - required binary name (UTF8); - optional binary phoneNumber (UTF8); - } - } - } - """ - - - val testNestedSchema2 = - """ - message TestNested2 { - required int32 firstInt; - optional int32 secondInt; - optional group longs { - repeated int64 array; - } - required group entries { - repeated group array { - required double value; - optional boolean truth; - } - } - optional group outerouter { - repeated group array { - repeated group array { - repeated int32 array; - } - } - } - } - """ - - val testNestedSchema3 = - """ - message TestNested3 { - required int32 x; - optional group booleanNumberPairs { - repeated group array { - required int32 key; - optional group value { - repeated group array { - required double nestedValue; - optional boolean truth; - } - } - } - } - } - """ - - val testNestedSchema4 = - """ - message TestNested4 { - required int32 x; - optional group data1 { - repeated group map { - required binary key (UTF8); - required int32 value; - } - } - required group data2 { - repeated group map { - required binary key (UTF8); - required group value { - required int64 payload1; - optional binary payload2 (UTF8); - } - } - } - } - """ - - val testNestedDir1 = Utils.createTempDir() - val testNestedDir2 = Utils.createTempDir() - val testNestedDir3 = Utils.createTempDir() - val testNestedDir4 = Utils.createTempDir() - - lazy val testNestedData1 = - new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) - lazy val testNestedData2 = - new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) - - def writeFile() = { - testDir.delete() - val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) - val job = new Job() - val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - for(i <- 0 until 15) { - val record = new SimpleGroup(schema) - if (i % 3 == 0) { - record.add(0, true) - } else { - record.add(0, false) - } - if (i % 5 == 0) { - record.add(1, 5) - } - record.add(2, "abc") - record.add(3, i.toLong << 33) - record.add(4, 2.5F) - record.add(5, 4.5D) - writer.write(record) - } - writer.close() - } - - def writeFilterFile(records: Int = 200) = { - // for microbenchmark use: records = 300000000 - testFilterDir.delete - val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema) - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - for(i <- 0 to records) { - val record = new SimpleGroup(schema) - if (i % 4 == 0) { - record.add(0, true) - } else { - record.add(0, false) - } - record.add(1, i) - record.add(2, i.toString) - record.add(3, i.toLong) - record.add(4, i.toFloat + 0.5f) - record.add(5, i.toDouble + 0.5d) - if (i % 2 == 0) { - if (i % 3 == 0) { - record.add(6, true) - } else { - record.add(6, false) - } - record.add(7, i) - record.add(8, i.toString) - record.add(9, i.toLong) - record.add(10, i.toFloat + 0.5f) - record.add(11, i.toDouble + 0.5d) - } - - writer.write(record) - } - writer.close() - } - - def writeNestedFile1() { - // example data from https://blog.twitter.com/2013/dremel-made-simple-with-parquet - testNestedDir1.delete() - val path: Path = new Path(new Path(testNestedDir1.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema1) - - val r1 = new SimpleGroup(schema) - r1.add(0, "Julien Le Dem") - r1.addGroup(1) - .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 123 4567") - .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 666 1337") - .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "XXX XXX XXXX") - val contacts = r1.addGroup(2) - contacts.addGroup(0) - .append("name", "Dmitriy Ryaboy") - .append("phoneNumber", "555 987 6543") - contacts.addGroup(0) - .append("name", "Chris Aniszczyk") - - val r2 = new SimpleGroup(schema) - r2.add(0, "A. Nonymous") - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.write(r2) - writer.close() - } - - def writeNestedFile2() { - testNestedDir2.delete() - val path: Path = new Path(new Path(testNestedDir2.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema2) - - val r1 = new SimpleGroup(schema) - r1.add(0, 1) - r1.add(1, 7) - val longs = r1.addGroup(2) - longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME , 1.toLong << 32) - longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 33) - longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 34) - val booleanNumberPair = r1.addGroup(3).addGroup(0) - booleanNumberPair.add("value", 2.5) - booleanNumberPair.add("truth", false) - val top_level = r1.addGroup(4) - val second_level_a = top_level.addGroup(0) - val second_level_b = top_level.addGroup(0) - val third_level_aa = second_level_a.addGroup(0) - val third_level_ab = second_level_a.addGroup(0) - val third_level_c = second_level_b.addGroup(0) - third_level_aa.add( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - 7) - third_level_ab.add( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - 8) - third_level_c.add( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - 9) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.close() - } - - def writeNestedFile3() { - testNestedDir3.delete() - val path: Path = new Path(new Path(testNestedDir3.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema3) - - val r1 = new SimpleGroup(schema) - r1.add(0, 1) - val booleanNumberPairs = r1.addGroup(1) - val g1 = booleanNumberPairs.addGroup(0) - g1.add(0, 1) - val nested1 = g1.addGroup(1) - val ng1 = nested1.addGroup(0) - ng1.add(0, 1.5) - ng1.add(1, false) - val ng2 = nested1.addGroup(0) - ng2.add(0, 2.5) - ng2.add(1, true) - val g2 = booleanNumberPairs.addGroup(0) - g2.add(0, 2) - val ng3 = g2.addGroup(1) - .addGroup(0) - ng3.add(0, 3.5) - ng3.add(1, false) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.close() - } - - def writeNestedFile4() { - testNestedDir4.delete() - val path: Path = new Path(new Path(testNestedDir4.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema4) - - val r1 = new SimpleGroup(schema) - r1.add(0, 7) - val map1 = r1.addGroup(1) - val keyValue1 = map1.addGroup(0) - keyValue1.add(0, "key1") - keyValue1.add(1, 1) - val keyValue2 = map1.addGroup(0) - keyValue2.add(0, "key2") - keyValue2.add(1, 2) - val map2 = r1.addGroup(2) - val keyValue3 = map2.addGroup(0) - // TODO: currently only string key type supported - keyValue3.add(0, "seven") - val valueGroup1 = keyValue3.addGroup(1) - valueGroup1.add(0, 42.toLong) - valueGroup1.add(1, "the answer") - val keyValue4 = map2.addGroup(0) - // TODO: currently only string key type supported - keyValue4.add(0, "eight") - val valueGroup2 = keyValue4.addGroup(1) - valueGroup2.add(0, 49.toLong) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.close() - } - - // TODO: this is not actually used anywhere but useful for debugging - /* def readNestedFile(file: File, schemaString: String): Unit = { - val configuration = new Configuration() - val path = new Path(new Path(file.toURI), new Path("part-r-0.parquet")) - val fs: FileSystem = path.getFileSystem(configuration) - val schema: MessageType = MessageTypeParser.parseMessageType(schemaString) - assert(schema != null) - val outputStatus: FileStatus = fs.getFileStatus(new Path(path.toString)) - val footers = ParquetFileReader.readFooter(configuration, outputStatus) - assert(footers != null) - val reader = new ParquetReader(new Path(path.toString), new GroupReadSupport()) - val first = reader.read() - assert(first != null) - } */ - - // to test golb pattern (wild card pattern matching for parquetFile input - val testGlobDir = Utils.createTempDir() - val testGlobSubDir1 = Utils.createTempDir(testGlobDir.getPath) - val testGlobSubDir2 = Utils.createTempDir(testGlobDir.getPath) - val testGlobSubDir3 = Utils.createTempDir(testGlobDir.getPath) - - def writeGlobFiles() = { - val subDirs = Array(testGlobSubDir1, testGlobSubDir2, testGlobSubDir3) - - subDirs.foreach { dir => - val path: Path = new Path(new Path(dir.toURI), new Path("part-r-0.parquet")) - val job = new Job() - val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - for(i <- 0 until 15) { - val record = new SimpleGroup(schema) - if(i % 3 == 0) { - record.add(0, true) - } else { - record.add(0, false) - } - if(i % 5 == 0) { - record.add(1, 5) - } - record.add(2, "abc") - record.add(3, i.toLong << 33) - record.add(4, 2.5F) - record.add(5, 4.5D) - writer.write(record) - } - writer.close() - } - } -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala deleted file mode 100644 index 6d8c682ccced..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ /dev/null @@ -1,493 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException - -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job - -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} -import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} -import org.apache.spark.sql.types._ - -// Implicits -import scala.collection.JavaConversions._ - -/** A class representing Parquet info fields we care about, for passing back to Parquet */ -private[parquet] case class ParquetTypeInfo( - primitiveType: ParquetPrimitiveTypeName, - originalType: Option[ParquetOriginalType] = None, - decimalMetadata: Option[DecimalMetadata] = None, - length: Option[Int] = None) - -private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = - classOf[PrimitiveType] isAssignableFrom ctype.getClass - - def toPrimitiveDataType( - parquetType: ParquetPrimitiveType, - binaryAsString: Boolean): DataType = { - val originalType = parquetType.getOriginalType - val decimalInfo = parquetType.getDecimalMetadata - parquetType.getPrimitiveTypeName match { - case ParquetPrimitiveTypeName.BINARY - if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType - case ParquetPrimitiveTypeName.BINARY => BinaryType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - sys.error("Potential loss of precision: cannot convert INT96") - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => - // TODO: for now, our reader only supports decimals that fit in a Long - DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } - } - - /** - * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.types.DataType]]. - * - * We apply the following conversion rules: - *
      - *
    • Primitive types are converter to the corresponding primitive type.
    • - *
    • Group types that have a single field that is itself a group, which has repetition - * level `REPEATED`, are treated as follows:
        - *
      • If the nested group has name `values`, the surrounding group is converted - * into an [[ArrayType]] with the corresponding field type (primitive or - * complex) as element type.
      • - *
      • If the nested group has name `map` and two fields (named `key` and `value`), - * the surrounding group is converted into a [[MapType]] - * with the corresponding key and value (value possibly complex) types. - * Note that we currently assume map values are not nullable.
      • - *
      • Other group types are converted into a [[StructType]] with the corresponding - * field types.
    • - *
    - * Note that fields are determined to be `nullable` if and only if their Parquet repetition - * level is not `REQUIRED`. - * - * @param parquetType The type to convert. - * @return The corresponding Catalyst type. - */ - def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { - def correspondsToMap(groupType: ParquetGroupType): Boolean = { - if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { - false - } else { - // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - keyValueGroup.getRepetition == Repetition.REPEATED && - keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && - keyValueGroup.getFieldCount == 2 && - keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && - keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME - } - } - - def correspondsToArray(groupType: ParquetGroupType): Boolean = { - groupType.getFieldCount == 1 && - groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && - groupType.getFields.apply(0).getRepetition == Repetition.REPEATED - } - - if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) - } else { - val groupType = parquetType.asGroupType() - parquetType.getOriginalType match { - // if the schema was constructed programmatically there may be hints how to convert - // it inside the metadata via the OriginalType field - case ParquetOriginalType.LIST => { // TODO: check enums! - assert(groupType.getFieldCount == 1) - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) - } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) - } - } - case ParquetOriginalType.MAP => { - assert( - !groupType.getFields.apply(0).isPrimitive, - "Parquet Map type malformatted: expected nested group for map!") - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert( - keyValueGroup.getFieldCount == 2, - "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } - case _ => { - // Note: the order of these checks is important! - if (correspondsToMap(groupType)) { // MapType - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } else if (correspondsToArray(groupType)) { // ArrayType - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) - } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) - } - } else { // everything else: StructType - val fields = groupType - .getFields - .map(ptype => new StructField( - ptype.getName, - toDataType(ptype, isBinaryAsString), - ptype.getRepetition != Repetition.REQUIRED)) - StructType(fields) - } - } - } - } - } - - /** - * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return - * the name of the corresponding Parquet primitive type or None if the given type - * is not primitive. - * - * @param ctype The type to convert - * @return The name of the corresponding Parquet type properties - */ - def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { - case StringType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) - case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) - case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) - case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) - case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) - case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) - case DecimalType.Fixed(precision, scale) if precision <= 18 => - // TODO: for now, our writer only supports decimals that fit in a Long - Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, - Some(ParquetOriginalType.DECIMAL), - Some(new DecimalMetadata(precision, scale)), - Some(BYTES_FOR_PRECISION(precision)))) - case _ => None - } - - /** - * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. - */ - private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => - var length = 1 - while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { - length += 1 - } - length - } - - /** - * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into - * the corresponding Parquet `Type`. - * - * The conversion follows the rules below: - *
      - *
    • Primitive types are converted into Parquet's primitive types.
    • - *
    • [[org.apache.spark.sql.types.StructType]]s are converted - * into Parquet's `GroupType` with the corresponding field types.
    • - *
    • [[org.apache.spark.sql.types.ArrayType]]s are converted - * into a 2-level nested group, where the outer group has the inner - * group as sole field. The inner group has name `values` and - * repetition level `REPEATED` and has the element type of - * the array as schema. We use Parquet's `ConversionPatterns` for this - * purpose.
    • - *
    • [[org.apache.spark.sql.types.MapType]]s are converted - * into a nested (2-level) Parquet `GroupType` with two fields: a key - * type and a value type. The nested group has repetition level - * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns` - * for this purpose
    • - *
    - * Parquet's repetition level is generally set according to the following rule: - *
      - *
    • If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or - * `MapType`, then the repetition level is set to `REPEATED`.
    • - *
    • Otherwise, if the attribute whose type is converted is `nullable`, the Parquet - * type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
    • - *
    - * - *@param ctype The type to convert - * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] - * whose type is converted - * @param nullable When true indicates that the attribute is nullable - * @param inArray When true indicates that this is a nested attribute inside an array. - * @return The corresponding Parquet type. - */ - def fromDataType( - ctype: DataType, - name: String, - nullable: Boolean = true, - inArray: Boolean = false): ParquetType = { - val repetition = - if (inArray) { - Repetition.REPEATED - } else { - if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED - } - val typeInfo = fromPrimitiveDataType(ctype) - typeInfo.map { - case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => - val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) - for (len <- length) { - builder.length(len) - } - for (metadata <- decimalMetadata) { - builder.precision(metadata.getPrecision).scale(metadata.getScale) - } - builder.named(name) - }.getOrElse { - ctype match { - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray) - } - case ArrayType(elementType, false) => { - val parquetElementType = fromDataType( - elementType, - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - nullable = false, - inArray = true) - ConversionPatterns.listType(repetition, name, parquetElementType) - } - case ArrayType(elementType, true) => { - val parquetElementType = fromDataType( - elementType, - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - nullable = true, - inArray = false) - ConversionPatterns.listType( - repetition, - name, - new ParquetGroupType( - Repetition.REPEATED, - CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, - parquetElementType)) - } - case StructType(structFields) => { - val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) - } - new ParquetGroupType(repetition, name, fields.toSeq) - } - case MapType(keyType, valueType, valueContainsNull) => { - val parquetKeyType = - fromDataType( - keyType, - CatalystConverter.MAP_KEY_SCHEMA_NAME, - nullable = false, - inArray = false) - val parquetValueType = - fromDataType( - valueType, - CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = valueContainsNull, - inArray = false) - ConversionPatterns.mapType( - repetition, - name, - parquetKeyType, - parquetValueType) - } - case _ => sys.error(s"Unsupported datatype $ctype") - } - } - } - - def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { - parquetSchema - .asGroupType() - .getFields - .map( - field => - new AttributeReference( - field.getName, - toDataType(field, isBinaryAsString), - field.getRepetition != Repetition.REQUIRED)()) - } - - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val fields = attributes.map( - attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable)) - new MessageType("root", fields) - } - - def convertFromString(string: String): Seq[Attribute] = { - Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { - case s: StructType => s.toAttributes - case other => sys.error(s"Can convert $string to row") - } - } - - def convertToString(schema: Seq[Attribute]): String = { - StructType.fromAttributes(schema).json - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put( - RowReadSupport.SPARK_METADATA_KEY, - ParquetTypesConverter.convertToString(attributes)) - // TODO: add extra data, e.g., table name, date, etc.? - - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetRelation.enableLogForwarding() - ParquetFileWriter.writeMetadataFile( - conf, - path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @param configuration The Hadoop configuration to use. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - - val children = - fs - .globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } - .filterNot { status => - val name = status.getPath.getName - (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE - } - - ParquetRelation.enableLogForwarding() - - // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row - // groups. Since Parquet schema is replicated among all row groups, we only need to touch a - // single row group to read schema related metadata. Notice that we are making assumptions that - // all data in a single Parquet file have the same schema, which is normally true. - children - // Try any non-"_metadata" file first... - .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) - // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is - // empty, thus normally the "_metadata" file is expected to be fairly small). - .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) - .map(ParquetFileReader.readFooter(conf, _, ParquetMetadataConverter.NO_FILTER)) - .getOrElse( - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) - } - - /** - * Reads in Parquet Metadata from the given path and tries to extract the schema - * (Catalyst attributes) from the application-specific key-value map. If this - * is empty it falls back to converting from the Parquet file schema which - * may lead to an upcast of types (e.g., {byte, short} to int). - * - * @param origPath The path at which we expect one (or more) Parquet files. - * @param conf The Hadoop configuration to use. - * @return A list of attributes that make up the schema. - */ - def readSchemaFromFile( - origPath: Path, - conf: Option[Configuration], - isBinaryAsString: Boolean): Seq[Attribute] = { - val keyValueMetadata: java.util.Map[String, String] = - readMetaData(origPath, conf) - .getFileMetaData - .getKeyValueMetaData - if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } else { - val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) - log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") - attributes - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala deleted file mode 100644 index 1e794cad7393..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.parquet - -import java.util.{List => JList} - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} -import parquet.filter2.predicate.FilterApi -import parquet.hadoop.ParquetInputFormat -import parquet.hadoop.util.ContextUtil - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{NewHadoopPartition, RDD} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.sql.{Row, SQLConf, SQLContext} -import org.apache.spark.{Logging, Partition => SparkPartition} - - -/** - * Allows creation of parquet based tables using the syntax - * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option - * required is `path`, which should be the location of a collection of, optionally partitioned, - * parquet files. - */ -class DefaultSource extends RelationProvider { - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val path = - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) - - ParquetRelation2(path)(sqlContext) - } -} - -private[parquet] case class Partition(partitionValues: Map[String, Any], files: Seq[FileStatus]) - -/** - * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * currently not intended as a full replacement of the parquet support in Spark SQL though it is - * likely that it will eventually subsume the existing physical plan implementation. - * - * Compared with the current implementation, this class has the following notable differences: - * - * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/` - * located at `path`. Currently only a single partitioning column is supported and it must - * be an integer. This class supports both fully self-describing data, which contains the partition - * key, and data where the partition key is only present in the folder structure. The presence - * of the partitioning key in the data is also auto-detected. The `null` partition is not yet - * supported. - * - * Metadata: The metadata is automatically discovered by reading the first parquet file present. - * There is currently no support for working with files that have different schema. Additionally, - * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached - * to improve the speed of interactive querying. When data is added to a table it must be dropped - * and recreated to pick up any changes. - * - * Statistics: Statistics for the size of the table are automatically populated during metadata - * discovery. - */ -@DeveloperApi -case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) - extends CatalystScan with Logging { - - def sparkContext = sqlContext.sparkContext - - // Minor Hack: scala doesnt seem to respect @transient for vals declared via extraction - @transient - private var partitionKeys: Seq[String] = _ - @transient - private var partitions: Seq[Partition] = _ - discoverPartitions() - - // TODO: Only finds the first partition, assumes the key is of type Integer... - private def discoverPartitions() = { - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val partValue = "([^=]+)=([^=]+)".r - - val childrenOfPath = fs.listStatus(new Path(path)).filterNot(_.getPath.getName.startsWith("_")) - val childDirs = childrenOfPath.filter(s => s.isDir) - - if (childDirs.size > 0) { - val partitionPairs = childDirs.map(_.getPath.getName).map { - case partValue(key, value) => (key, value) - } - - val foundKeys = partitionPairs.map(_._1).distinct - if (foundKeys.size > 1) { - sys.error(s"Too many distinct partition keys: $foundKeys") - } - - // Do a parallel lookup of partition metadata. - val partitionFiles = - childDirs.par.map { d => - fs.listStatus(d.getPath) - // TODO: Is there a standard hadoop function for this? - .filterNot(_.getPath.getName.startsWith("_")) - .filterNot(_.getPath.getName.startsWith(".")) - }.seq - - partitionKeys = foundKeys.toSeq - partitions = partitionFiles.zip(partitionPairs).map { case (files, (key, value)) => - Partition(Map(key -> value.toInt), files) - }.toSeq - } else { - partitionKeys = Nil - partitions = Partition(Map.empty, childrenOfPath) :: Nil - } - } - - override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum - - val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. - ParquetTypesConverter.readSchemaFromFile( - partitions.head.files.head.getPath, - Some(sparkContext.hadoopConfiguration), - sqlContext.conf.isParquetBinaryAsString)) - - val dataIncludesKey = - partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) - - override val schema = - if (dataIncludesKey) { - dataSchema - } else { - StructType(dataSchema.fields :+ StructField(partitionKeys.head, IntegerType)) - } - - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - // This is mostly a hack so that we can use the existing parquet filter code. - val requiredColumns = output.map(_.name) - - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) - - val requestedSchema = StructType(requiredColumns.map(schema(_))) - - val partitionKeySet = partitionKeys.toSet - val rawPredicate = - predicates - .filter(_.references.map(_.name).toSet.subsetOf(partitionKeySet)) - .reduceOption(And) - .getOrElse(Literal(true)) - - // Translate the predicate so that it reads from the information derived from the - // folder structure - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = partitionKeys.indexWhere(a.name == _) - BoundReference(idx, IntegerType, nullable = true) - } - - val inputData = new GenericMutableRow(partitionKeys.size) - val pruningCondition = InterpretedPredicate(castedPredicate) - - val selectedPartitions = - if (partitionKeys.nonEmpty && predicates.nonEmpty) { - partitions.filter { part => - inputData(0) = part.partitionValues.values.head - pruningCondition(inputData) - } - } else { - partitions - } - - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath)) - // FileInputFormat cannot handle empty lists. - if (selectedFiles.nonEmpty) { - org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles: _*) - } - - // Push down filters when possible. Notice that not all filters can be converted to Parquet - // filter predicate. Here we try to convert each individual predicate and only collect those - // convertible ones. - predicates - .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) - - def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - logInfo(s"Reading $percentRead% of $path partitions") - - // Store both requested and original schema in `Configuration` - jobConf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedSchema.toAttributes)) - jobConf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(schema.toAttributes)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean - jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) - - val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( - sparkContext, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row], - jobConf) { - val cacheMetadata = useCache - - @transient - val cachedStatus = selectedPartitions.flatMap(_.files) - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = - if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - } - } else { - new FilteringParquetRowInputFormat - } - - inputFormat match { - case configurable: Configurable => - configurable.setConf(getConf) - case _ => - } - val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - } - - // The ordinal for the partition key in the result row, if requested. - val partitionKeyLocation = - partitionKeys - .headOption - .map(requiredColumns.indexOf(_)) - .getOrElse(-1) - - // When the data does not include the key and the key is requested then we must fill it in - // based on information from the input split. - if (!dataIncludesKey && partitionKeyLocation != -1) { - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - val currentValue = partValues.values.head.toInt - iter.map { pair => - val res = pair._2.asInstanceOf[SpecificMutableRow] - res.setInt(partitionKeyLocation, currentValue) - res - } - } - } else { - baseRDD.map(_._2) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala deleted file mode 100644 index 386ff2452f1a..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, Strategy} -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, InsertIntoTable => LogicalInsertIntoTable} -import org.apache.spark.sql.execution - -/** - * A Strategy for planning scans over data sources defined using the sources API. - */ -private[sql] object DataSourceStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => - pruneFilterProjectRaw( - l, - projectList, - filters, - (a, f) => t.buildScan(a, f)) :: Nil - - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => - pruneFilterProject( - l, - projectList, - filters, - (a, f) => t.buildScan(a, f)) :: Nil - - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => - pruneFilterProject( - l, - projectList, - filters, - (a, _) => t.buildScan(a)) :: Nil - - case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil - - case i @ LogicalInsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => - if (partition.nonEmpty) { - sys.error(s"Insert into a partition is not allowed because $l is not partitioned.") - } - execution.ExecutedCommand(InsertIntoRelation(t, query, overwrite)) :: Nil - - case _ => Nil - } - - // Based on Public API. - protected def pruneFilterProject( - relation: LogicalRelation, - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { - pruneFilterProjectRaw( - relation, - projectList, - filterPredicates, - (requestedColumns, pushedFilters) => { - scanBuilder(requestedColumns.map(_.name).toArray, selectFilters(pushedFilters).toArray) - }) - } - - // Based on Catalyst expressions. - protected def pruneFilterProjectRaw( - relation: LogicalRelation, - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[Row]) = { - - val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(And) - - val pushedFilters = filterPredicates.map { _ transform { - case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. - }} - - if (projectList.map(_.toAttribute) == projectList && - projectSet.size == projectList.size && - filterSet.subsetOf(projectSet)) { - // When it is possible to just use column pruning to get the right projection and - // when the columns of this projection are enough to evaluate all filter conditions, - // just do a scan followed by a filter, with no extra project. - val requestedColumns = - projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. - .map(relation.attributeMap) // Match original case of attributes. - - val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters)) - filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) - } else { - val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) - execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) - } - } - - /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */ - protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - - case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - - case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - - case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => - GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => - LessThanOrEqual(a.name, v) - - case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => - LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => - GreaterThanOrEqual(a.name, v) - - case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala deleted file mode 100644 index 12b59ba20bb1..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.sources - -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} - -/** - * Used to link a [[BaseRelation]] in to a logical query plan. - */ -private[sql] case class LogicalRelation(relation: BaseRelation) - extends LeafNode - with MultiInstanceRelation { - - override val output: Seq[AttributeReference] = relation.schema.toAttributes - - // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any) = other match { - case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output - case _ => false - } - - override def sameResult(otherPlan: LogicalPlan) = otherPlan match { - case LogicalRelation(otherRelation) => relation == otherRelation - case _ => false - } - - @transient override lazy val statistics = Statistics( - sizeInBytes = BigInt(relation.sizeInBytes) - ) - - /** Used to lookup original attribute capitalization */ - val attributeMap = AttributeMap(output.map(o => (o, o))) - - def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] - - override def simpleString = s"Relation[${output.mkString(",")}] $relation" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala deleted file mode 100644 index d7942dc30934..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.sources - -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.RunnableCommand - -private[sql] case class InsertIntoRelation( - relation: InsertableRelation, - query: LogicalPlan, - overwrite: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext) = { - relation.insert(DataFrame(sqlContext, query), overwrite) - - Seq.empty[Row] - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala deleted file mode 100644 index ead827728cf4..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ /dev/null @@ -1,353 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import scala.language.implicitConversions - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. - */ -private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { - - def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { - try { - Some(apply(input)) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => None - case x: Throwable => throw x - } - } - - def parseType(input: String): DataType = { - lexical.initialize(reservedWords) - phrase(dataType)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => throw new DDLException(s"Unsupported dataType: $x") - } - } - - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - - // Data types. - protected val STRING = Keyword("STRING") - protected val BINARY = Keyword("BINARY") - protected val BOOLEAN = Keyword("BOOLEAN") - protected val TINYINT = Keyword("TINYINT") - protected val SMALLINT = Keyword("SMALLINT") - protected val INT = Keyword("INT") - protected val BIGINT = Keyword("BIGINT") - protected val FLOAT = Keyword("FLOAT") - protected val DOUBLE = Keyword("DOUBLE") - protected val DECIMAL = Keyword("DECIMAL") - protected val DATE = Keyword("DATE") - protected val TIMESTAMP = Keyword("TIMESTAMP") - protected val VARCHAR = Keyword("VARCHAR") - protected val ARRAY = Keyword("ARRAY") - protected val MAP = Keyword("MAP") - protected val STRUCT = Keyword("STRUCT") - - protected lazy val ddl: Parser[LogicalPlan] = createTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = - ( - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident - ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - opts, - allowExisting.isDefined, - query.get) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - opts, - allowExisting.isDefined) - } - } - ) - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase(), comment).build() - case None => Metadata.empty - } - StructField(columnName, typ, true, meta) - } - - protected lazy val primitiveType: Parser[DataType] = - STRING ^^^ StringType | - BINARY ^^^ BinaryType | - BOOLEAN ^^^ BooleanType | - TINYINT ^^^ ByteType | - SMALLINT ^^^ ShortType | - INT ^^^ IntegerType | - BIGINT ^^^ LongType | - FLOAT ^^^ FloatType | - DOUBLE ^^^ DoubleType | - fixedDecimalType | // decimal with precision/scale - DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale - DATE ^^^ DateType | - TIMESTAMP ^^^ TimestampType | - VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - ARRAY ~> "<" ~> dataType <~ ">" ^^ { - case tpe => ArrayType(tpe) - } - - protected lazy val mapType: Parser[DataType] = - MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) - } - - protected lazy val structField: Parser[StructField] = - ident ~ ":" ~ dataType ^^ { - case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true) - } - - protected lazy val structType: Parser[DataType] = - (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { - case fields => StructType(fields) - }) | - (STRUCT ~> "<>" ^^ { - case fields => StructType(Nil) - }) - - private[sql] lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType -} - -object ResolvedDataSource { - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val loader = Utils.getContextOrSparkClassLoader - val clazz: Class[_] = try loader.loadClass(provider) catch { - case cnf: java.lang.ClassNotFoundException => - try loader.loadClass(provider + ".DefaultSource") catch { - case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") - } - } - - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => { - clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") - } - } - case None => { - clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.RelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") - } - } - } - - new ResolvedDataSource(clazz, relation) - } - - def apply( - sqlContext: SQLContext, - provider: String, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - val loader = Utils.getContextOrSparkClassLoader - val clazz: Class[_] = try loader.loadClass(provider) catch { - case cnf: java.lang.ClassNotFoundException => - try loader.loadClass(provider + ".DefaultSource") catch { - case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") - } - } - - val relation = clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.CreateableRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.CreateableRelationProvider] - .createRelation(sqlContext, options, data) - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) - -private[sql] case class CreateTableUsing( - tableName: String, - userSpecifiedSchema: Option[StructType], - provider: String, - temporary: Boolean, - options: Map[String, String], - allowExisting: Boolean) extends Command - -private[sql] case class CreateTableUsingAsSelect( - tableName: String, - provider: String, - temporary: Boolean, - options: Map[String, String], - allowExisting: Boolean, - query: String) extends Command - -private[sql] case class CreateTableUsingAsLogicalPlan( - tableName: String, - provider: String, - temporary: Boolean, - options: Map[String, String], - allowExisting: Boolean, - query: LogicalPlan) extends Command - -private [sql] case class CreateTempTableUsing( - tableName: String, - userSpecifiedSchema: Option[StructType], - provider: String, - options: Map[String, String]) extends RunnableCommand { - - def run(sqlContext: SQLContext) = { - val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - sqlContext.registerRDDAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty - } -} - -private [sql] case class CreateTempTableUsingAsSelect( - tableName: String, - provider: String, - options: Map[String, String], - query: LogicalPlan) extends RunnableCommand { - - def run(sqlContext: SQLContext) = { - val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, options, df) - sqlContext.registerRDDAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - - Seq.empty - } -} - -/** - * Builds a map in which keys are case insensitive - */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] - with Serializable { - - val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) - - override def get(k: String): Option[String] = baseMap.get(k.toLowerCase) - - override def + [B1 >: String](kv: (String, B1)): Map[String, B1] = - baseMap + kv.copy(_1 = kv._1.toLowerCase) - - override def iterator: Iterator[(String, String)] = baseMap.iterator - - override def -(key: String): Map[String, String] = baseMap - key.toLowerCase() -} - -/** - * The exception thrown from the DDL parser. - * @param message - */ -protected[sql] class DDLException(message: String) extends Exception(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4a9fefc12b9a..3780cbbcc963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,11 +17,128 @@ package org.apache.spark.sql.sources +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the filters that we can push down to the data sources. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * A filter predicate for data sources. + * + * @since 1.3.0 + */ abstract class Filter +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * equal to `value`. + * + * @since 1.3.0 + */ case class EqualTo(attribute: String, value: Any) extends Filter + +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 1.5.0 + */ +case class EqualNullSafe(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than `value`. + * + * @since 1.3.0 + */ case class GreaterThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than or equal to `value`. + * + * @since 1.3.0 + */ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than `value`. + * + * @since 1.3.0 + */ case class LessThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than or equal to `value`. + * + * @since 1.3.0 + */ case class LessThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. + * + * @since 1.3.0 + */ case class In(attribute: String, values: Array[Any]) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to null. + * + * @since 1.3.0 + */ +case class IsNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. + * + * @since 1.3.0 + */ +case class IsNotNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff both `left` or `right` evaluate to `true`. + * + * @since 1.3.0 + */ +case class And(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`. + * + * @since 1.3.0 + */ +case class Or(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff `child` is evaluated to `false`. + * + * @since 1.3.0 + */ +case class Not(child: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + * + * @since 1.3.1 + */ +case class StringStartsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + * + * @since 1.3.1 + */ +case class StringEndsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that contains the string `value`. + * + * @since 1.3.1 + */ +case class StringContains(attribute: String, value: String) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ad0a35b91ebc..b3b326fe612c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -14,13 +14,54 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources -import org.apache.spark.annotation.{Experimental, DeveloperApi} +import scala.collection.mutable +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.execution.{FileRelation, RDDConversions} +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql._ +import org.apache.spark.util.SerializableConfiguration + +/** + * ::DeveloperApi:: + * Data sources should implement this trait so that they can register an alias to their data source. + * This allows users to give the data source alias as the format type over the fully qualified + * class name. + * + * A new instance of this class with be instantiated each time a DDL call is made. + * + * @since 1.5.0 + */ +@DeveloperApi +trait DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 + */ + def shortName(): String +} /** * ::DeveloperApi:: @@ -34,6 +75,8 @@ import org.apache.spark.sql.types.StructType * data source 'org.apache.spark.sql.json.DefaultSource' * * A new instance of this class with be instantiated each time a DDL call is made. + * + * @since 1.3.0 */ @DeveloperApi trait RelationProvider { @@ -60,9 +103,11 @@ trait RelationProvider { * A new instance of this class with be instantiated each time a DDL call is made. * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that - * users need to provide a schema when using a SchemaRelationProvider. + * users need to provide a schema when using a [[SchemaRelationProvider]]. * A relation provider can inherits both [[RelationProvider]] and [[SchemaRelationProvider]] * if it can support both schema inference and user-specified schemas. + * + * @since 1.3.0 */ @DeveloperApi trait SchemaRelationProvider { @@ -77,24 +122,84 @@ trait SchemaRelationProvider { schema: StructType): BaseRelation } +/** + * ::Experimental:: + * Implemented by objects that produce relations for a specific kind of data source + * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a + * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined + * schema, and an optional list of partition columns, this interface is used to pass in the + * parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + * + * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is + * that users need to provide a schema and a (possibly empty) list of partition columns when + * using a [[HadoopFsRelationProvider]]. A relation provider can inherits both [[RelationProvider]], + * and [[HadoopFsRelationProvider]] if it can support schema inference, user-specified + * schemas, and accessing partitioned relations. + * + * @since 1.4.0 + */ +@Experimental +trait HadoopFsRelationProvider { + /** + * Returns a new base relation with the given parameters, a user defined schema, and a list of + * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity + * is enforced by the Map that is passed to the function. + * + * @param dataSchema Schema of data columns (i.e., columns that are not partition columns). + */ + def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation +} + +/** + * @since 1.3.0 + */ @DeveloperApi -trait CreateableRelationProvider { +trait CreatableRelationProvider { + /** + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + * + * @since 1.3.0 + */ def createRelation( sqlContext: SQLContext, + mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation } /** * ::DeveloperApi:: - * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must - * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]]. Concrete * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * * BaseRelations must also define a equality function that only returns true when the two - * instances will return the same data. This equality function is used when determining when + * instances will return the same data. This equality function is used when determining when * it is safe to substitute cached results for a given relation. + * + * @since 1.3.0 */ @DeveloperApi abstract class BaseRelation { @@ -102,21 +207,42 @@ abstract class BaseRelation { def schema: StructType /** - * Returns an estimated size of this relation in bytes. This information is used by the planner + * Returns an estimated size of this relation in bytes. This information is used by the planner * to decided when it is safe to broadcast a relation and can be overridden by sources that * know the size ahead of time. By default, the system will assume that tables are too - * large to broadcast. This method will be called multiple times during query planning + * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. + * + * Note that it is always better to overestimate size than underestimate, because underestimation + * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). + * + * @since 1.3.0 */ - def sizeInBytes = sqlContext.conf.defaultSizeInBytes + def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether does it need to convert the objects in Row to internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + * + * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]] + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. + * + * @since 1.4.0 + */ + def needConversion: Boolean = true } /** * ::DeveloperApi:: * A BaseRelation that can produce all of its tuples as an RDD of Row objects. + * + * @since 1.3.0 */ @DeveloperApi -trait TableScan extends BaseRelation { +trait TableScan { def buildScan(): RDD[Row] } @@ -124,9 +250,11 @@ trait TableScan extends BaseRelation { * ::DeveloperApi:: * A BaseRelation that can eliminate unneeded columns before producing an RDD * containing all of its tuples as Row objects. + * + * @since 1.3.0 */ @DeveloperApi -trait PrunedScan extends BaseRelation { +trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -135,29 +263,530 @@ trait PrunedScan extends BaseRelation { * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * + * The actual filter should be the conjunction of all `filters`, + * i.e. they should be "and" together. + * * The pushed down filters are currently purely an optimization as they will all be evaluated * again. This means it is safe to use them with methods that produce false positives such * as filtering partitions based on a bloom filter. + * + * @since 1.3.0 */ @DeveloperApi -trait PrunedFilteredScan extends BaseRelation { +trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } +/** + * ::DeveloperApi:: + * A BaseRelation that can be used to insert data into it through the insert method. + * If overwrite in insert method is true, the old data in the relation should be overwritten with + * the new data. If overwrite in insert method is false, the new data should be appended. + * + * InsertableRelation has the following three assumptions. + * 1. It assumes that the data (Rows in the DataFrame) provided to the insert method + * exactly matches the ordinal of fields in the schema of the BaseRelation. + * 2. It assumes that the schema of this relation will not be changed. + * Even if the insert method updates the schema (e.g. a relation of JSON or Parquet data may have a + * schema update after an insert operation), the new schema will not be used. + * 3. It assumes that fields of the data provided in the insert method are nullable. + * If a data source needs to check the actual nullability of a field, it needs to do it in the + * insert method. + * + * @since 1.3.0 + */ +@DeveloperApi +trait InsertableRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit +} + /** * ::Experimental:: * An interface for experimenting with a more direct connection to the query planner. Compared to * [[PrunedFilteredScan]], this operator receives the raw expressions from the * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. Unlike the other APIs this - * interface is not designed to be binary compatible across releases and thus should only be used + * interface is NOT designed to be binary compatible across releases and thus should only be used * for experimentation. + * + * @since 1.3.0 */ @Experimental -trait CatalystScan extends BaseRelation { +trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } -@DeveloperApi -trait InsertableRelation extends BaseRelation { - def insert(data: DataFrame, overwrite: Boolean): Unit +/** + * ::Experimental:: + * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver + * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized + * to executor side to create actual [[OutputWriter]]s on the fly. + * + * @since 1.4.0 + */ +@Experimental +abstract class OutputWriterFactory extends Serializable { + /** + * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side + * to instantiate new [[OutputWriter]]s. + * + * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that + * this may not point to the final output file. For example, `FileOutputFormat` writes to + * temporary directories and then merge written files back to the final destination. In + * this case, `path` points to a temporary output file under the temporary directory. + * @param dataSchema Schema of the rows to be written. Partition columns are not included in the + * schema if the relation being written is partitioned. + * @param context The Hadoop MapReduce task context. + * + * @since 1.4.0 + */ + def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter +} + +/** + * ::Experimental:: + * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the + * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. + * An [[OutputWriter]] instance is created and initialized when a new output file is opened on + * executor side. This instance is used to persist rows to this single output file. + * + * @since 1.4.0 + */ +@Experimental +abstract class OutputWriter { + /** + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * tables, dynamic partition columns are not included in rows to be written. + * + * @since 1.4.0 + */ + def write(row: Row): Unit + + /** + * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before + * the task output is committed. + * + * @since 1.4.0 + */ + def close(): Unit + + private var converter: InternalRow => Row = _ + + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } + + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } +} + +/** + * ::Experimental:: + * A [[BaseRelation]] that provides much of the common code required for formats that store their + * data to an HDFS compatible filesystem. + * + * For the read path, similar to [[PrunedFilteredScan]], it can eliminate unneeded columns and + * filter using selected predicates before producing an RDD containing all matching tuples as + * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file + * systems, it's able to discover partitioning information from the paths of input directories, and + * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] + * must override one of the three `buildScan` methods to implement the read path. + * + * For the write path, it provides the ability to write to both non-partitioned and partitioned + * tables. Directory layout of the partitioned tables is compatible with Hive. + * + * @constructor This constructor is for internal uses only. The [[PartitionSpec]] argument is for + * implementing metastore table conversion. + * + * @param maybePartitionSpec An [[HadoopFsRelation]] can be created with an optional + * [[PartitionSpec]], so that partition discovery can be skipped. + * + * @since 1.4.0 + */ +@Experimental +abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) + extends BaseRelation with FileRelation with Logging { + + override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") + + def this() = this(None) + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + private val codegenEnabled = sqlContext.conf.codegenEnabled + + private var _partitionSpec: PartitionSpec = _ + + private class FileStatusCache { + var leafFiles = mutable.Map.empty[Path, FileStatus] + + var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] + + private def listLeafFiles(paths: Array[String]): Set[FileStatus] = { + if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) + } else { + val statuses = paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + logInfo(s"Listing $qualified on driver") + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + }.filterNot { status => + val name = status.getPath.getName + name.toLowerCase == "_temporary" || name.startsWith(".") + } + + val (dirs, files) = statuses.partition(_.isDir) + + if (dirs.isEmpty) { + files.toSet + } else { + files.toSet ++ listLeafFiles(dirs.map(_.getPath.toString)) + } + } + } + + def refresh(): Unit = { + val files = listLeafFiles(paths) + + leafFiles.clear() + leafDirToChildrenFiles.clear() + + leafFiles ++= files.map(f => f.getPath -> f).toMap + leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) + } + } + + private lazy val fileStatusCache = { + val cache = new FileStatusCache + cache.refresh() + cache + } + + protected def cachedLeafStatuses(): Set[FileStatus] = { + fileStatusCache.leafFiles.values.toSet + } + + final private[sql] def partitionSpec: PartitionSpec = { + if (_partitionSpec == null) { + _partitionSpec = maybePartitionSpec + .flatMap { + case spec if spec.partitions.nonEmpty => + Some(spec.copy(partitionColumns = spec.partitionColumns.asNullable)) + case _ => + None + } + .orElse { + // We only know the partition columns and their data types. We need to discover + // partition values. + userDefinedPartitionColumns.map { partitionSchema => + val spec = discoverPartitions() + val partitionColumnTypes = spec.partitionColumns.map(_.dataType) + val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) + } + val castedValues = partitionSchema.zip(literals).map { case (field, literal) => + Cast(literal, field.dataType).eval() + } + p.copy(values = InternalRow.fromSeq(castedValues)) + } + PartitionSpec(partitionSchema, castedPartitions) + } + } + .getOrElse { + if (sqlContext.conf.partitionDiscoveryEnabled()) { + discoverPartitions() + } else { + PartitionSpec(StructType(Nil), Array.empty[Partition]) + } + } + } + _partitionSpec + } + + /** + * Base paths of this relation. For partitioned relations, it should be either root directories + * of all partition directories. + * + * @since 1.4.0 + */ + def paths: Array[String] + + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + + /** + * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically + * discovered. Note that they should always be nullable. + * + * @since 1.4.0 + */ + final def partitionColumns: StructType = + userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) + + /** + * Optional user defined partition columns. + * + * @since 1.4.0 + */ + def userDefinedPartitionColumns: Option[StructType] = None + + private[sql] def refresh(): Unit = { + fileStatusCache.refresh() + if (sqlContext.conf.partitionDiscoveryEnabled()) { + _partitionSpec = discoverPartitions() + } + } + + private def discoverPartitions(): PartitionSpec = { + val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() + // We use leaf dirs containing data files to discover the schema. + val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference) + } + + /** + * Schema of this relation. It consists of columns appearing in [[dataSchema]] and all partition + * columns not appearing in [[dataSchema]]. + * + * @since 1.4.0 + */ + override lazy val schema: StructType = { + val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet + StructType(dataSchema ++ partitionColumns.filterNot { column => + dataSchemaColumnNames.contains(column.name.toLowerCase) + }) + } + + private[sql] def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + val inputStatuses = inputPaths.flatMap { input => + val path = new Path(input) + + // First assumes `input` is a directory path, and tries to get all files contained in it. + fileStatusCache.leafDirToChildrenFiles.getOrElse( + path, + // Otherwise, `input` might be a file path + fileStatusCache.leafFiles.get(path).toArray + ).filter { status => + val name = status.getPath.getName + !name.startsWith("_") && !name.startsWith(".") + } + } + + buildScan(requiredColumns, filters, inputStatuses, broadcastedConf) + } + + /** + * Specifies schema of actual data files. For partitioned relations, if one or more partitioned + * columns are contained in the data files, they should also appear in `dataSchema`. + * + * @since 1.4.0 + */ + def dataSchema: StructType + + /** + * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within + * this relation. For partitioned relations, this method is called for each selected partition, + * and builds an `RDD[Row]` containing all rows within that single partition. + * + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * + * @since 1.4.0 + */ + def buildScan(inputFiles: Array[FileStatus]): RDD[Row] = { + throw new UnsupportedOperationException( + "At least one buildScan() method should be overridden to read the relation.") + } + + /** + * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within + * this relation. For partitioned relations, this method is called for each selected partition, + * and builds an `RDD[Row]` containing all rows within that single partition. + * + * @param requiredColumns Required columns. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * + * @since 1.4.0 + */ + // TODO Tries to eliminate the extra Catalyst-to-Scala conversion when `needConversion` is true + // + // PR #7626 separated `Row` and `InternalRow` completely. One of the consequences is that we can + // no longer treat an `InternalRow` containing Catalyst values as a `Row`. Thus we have to + // introduce another row value conversion for data sources whose `needConversion` is true. + def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { + // Yeah, to workaround serialization... + val dataSchema = this.dataSchema + val codegenEnabled = this.codegenEnabled + val needConversion = this.needConversion + + val requiredOutput = requiredColumns.map { col => + val field = dataSchema(col) + BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) + }.toSeq + + val rdd: RDD[Row] = buildScan(inputFiles) + val converted: RDD[InternalRow] = + if (needConversion) { + RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) + } else { + rdd.asInstanceOf[RDD[InternalRow]] + } + + converted.mapPartitions { rows => + val buildProjection = if (codegenEnabled) { + GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) + } else { + () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) + } + + val projectedRows = { + val mutableProjection = buildProjection() + rows.map(r => mutableProjection(r)) + } + + if (needConversion) { + val requiredSchema = StructType(requiredColumns.map(dataSchema(_))) + val toScala = CatalystTypeConverters.createToScalaConverter(requiredSchema) + projectedRows.map(toScala(_).asInstanceOf[Row]) + } else { + projectedRows + } + }.asInstanceOf[RDD[Row]] + } + + /** + * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within + * this relation. For partitioned relations, this method is called for each selected partition, + * and builds an `RDD[Row]` containing all rows within that single partition. + * + * @param requiredColumns Required columns. + * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction + * of all `filters`. The pushed down filters are currently purely an optimization as they + * will all be evaluated again. This means it is safe to use them with methods that produce + * false positives such as filtering partitions based on a bloom filter. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * + * @since 1.4.0 + */ + def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus]): RDD[Row] = { + buildScan(requiredColumns, inputFiles) + } + + /** + * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within + * this relation. For partitioned relations, this method is called for each selected partition, + * and builds an `RDD[Row]` containing all rows within that single partition. + * + * Note: This interface is subject to change in future. + * + * @param requiredColumns Required columns. + * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction + * of all `filters`. The pushed down filters are currently purely an optimization as they + * will all be evaluated again. This means it is safe to use them with methods that produce + * false positives such as filtering partitions based on a bloom filter. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the + * overhead of broadcasting the Configuration for every Hadoop RDD. + * + * @since 1.4.0 + */ + private[sql] def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + buildScan(requiredColumns, filters, inputFiles) + } + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + * + * Note that the only side effect expected here is mutating `job` via its setters. Especially, + * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states + * may cause unexpected behaviors. + * + * @since 1.4.0 + */ + def prepareJobForWrite(job: Job): OutputWriterFactory +} + +private[sql] object HadoopFsRelation extends Logging { + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name + // start with "." are also ignored. + def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { + logInfo(s"Listing ${status.getPath}") + val name = status.getPath.getName.toLowerCase + if (name == "_temporary" || name.startsWith(".")) { + Array.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } + } + + // `FileStatus` is Writable but not serializable. What make it worse, somehow it doesn't play + // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. + // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it from + // executor side and reconstruct it on driver side. + case class FakeFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long) + + def listLeafFilesInParallel( + paths: Array[String], + hadoopConf: Configuration, + sparkContext: SparkContext): Set[FileStatus] = { + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val fakeStatuses = sparkContext.parallelize(paths).flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(serializableConfiguration.value) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + Try(listLeafFiles(fs, fs.getFileStatus(qualified))).getOrElse(Array.empty) + }.map { status => + FakeFileStatus( + status.getPath.toString, + status.getLen, + status.isDir, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime) + }.collect() + + fakeStatuses.map { f => + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) + }.toSet + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 006b16fbe07b..2fdd798b44bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. @@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" override def serialize(obj: Any): Seq[Double] = { obj match { @@ -59,4 +59,6 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { } override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/README.md b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md new file mode 100644 index 000000000000..d867f181b972 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md @@ -0,0 +1,7 @@ +README +====== + +Please do not add any class in this place unless it is used by `sql/console` or Python tests. +If you need to create any classes or traits that will be used by tests from both `sql/core` and +`sql/hive`, you can add them in the `src/test` of `sql/core` (tests of `sql/hive` +depend on the test jar of `sql/core`). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala deleted file mode 100644 index 4e1ec38bd015..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.language.implicitConversions - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -object TestSQLContext - extends SQLContext( - new SparkContext( - "local[2]", - "TestSQLContext", - new SparkConf().set("spark.sql.testkey", "true"))) { - - /** Fewer partitions to speed up testing. */ - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt - } - - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) - } - -} diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md new file mode 100644 index 000000000000..421c2ea4f7ae --- /dev/null +++ b/sql/core/src/test/README.md @@ -0,0 +1,29 @@ +# Notes for Parquet compatibility tests + +The following directories and files are used for Parquet compatibility tests: + +``` +. +├── README.md # This file +├── avro +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) +├── gen-java # !! NO TOUCH !! Generated Java code +├── scripts +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift +└── thrift + └── *.thrift # Testing Thrift schema(s) +``` + +To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. + +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. + +## Prerequisites + +Please ensure `avro-tools` and `thrift` are installed. You may install these two on Mac OS X via: + +```bash +$ brew install thrift avro-tools +``` diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl new file mode 100644 index 000000000000..c5eb5b5164cf --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This is a test protocol for testing parquet-avro compatibility. +@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") +protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + + record Nested { + array nested_ints_column; + string nested_string_column; + } + + record AvroPrimitives { + boolean bool_column; + int int_column; + long long_column; + float float_column; + double double_column; + bytes binary_column; + string string_column; + } + + record AvroOptionalPrimitives { + union { null, boolean } maybe_bool_column; + union { null, int } maybe_int_column; + union { null, long } maybe_long_column; + union { null, float } maybe_float_column; + union { null, double } maybe_double_column; + union { null, bytes } maybe_binary_column; + union { null, string } maybe_string_column; + } + + record AvroNonNullableArrays { + array strings_column; + union { null, array } maybe_ints_column; + } + + record AvroArrayOfArray { + array> int_arrays_column; + } + + record AvroMapOfArray { + map> string_to_ints_column; + } + + record ParquetAvroCompat { + array strings_column; + map string_to_int_column; + map> complex_column; + } +} diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr new file mode 100644 index 000000000000..9ad315b74fb4 --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -0,0 +1,147 @@ +{ + "protocol" : "CompatibilityTest", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", + "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { + "type" : "record", + "name" : "Nested", + "fields" : [ { + "name" : "nested_ints_column", + "type" : { + "type" : "array", + "items" : "int" + } + }, { + "name" : "nested_string_column", + "type" : "string" + } ] + }, { + "type" : "record", + "name" : "AvroPrimitives", + "fields" : [ { + "name" : "bool_column", + "type" : "boolean" + }, { + "name" : "int_column", + "type" : "int" + }, { + "name" : "long_column", + "type" : "long" + }, { + "name" : "float_column", + "type" : "float" + }, { + "name" : "double_column", + "type" : "double" + }, { + "name" : "binary_column", + "type" : "bytes" + }, { + "name" : "string_column", + "type" : "string" + } ] + }, { + "type" : "record", + "name" : "AvroOptionalPrimitives", + "fields" : [ { + "name" : "maybe_bool_column", + "type" : [ "null", "boolean" ] + }, { + "name" : "maybe_int_column", + "type" : [ "null", "int" ] + }, { + "name" : "maybe_long_column", + "type" : [ "null", "long" ] + }, { + "name" : "maybe_float_column", + "type" : [ "null", "float" ] + }, { + "name" : "maybe_double_column", + "type" : [ "null", "double" ] + }, { + "name" : "maybe_binary_column", + "type" : [ "null", "bytes" ] + }, { + "name" : "maybe_string_column", + "type" : [ "null", "string" ] + } ] + }, { + "type" : "record", + "name" : "AvroNonNullableArrays", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } + }, { + "name" : "maybe_ints_column", + "type" : [ "null", { + "type" : "array", + "items" : "int" + } ] + } ] + }, { + "type" : "record", + "name" : "AvroArrayOfArray", + "fields" : [ { + "name" : "int_arrays_column", + "type" : { + "type" : "array", + "items" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "AvroMapOfArray", + "fields" : [ { + "name" : "string_to_ints_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } + }, { + "name" : "string_to_int_column", + "type" : { + "type" : "map", + "values" : "int" + } + }, { + "name" : "complex_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "Nested" + } + } + } ] + } ], + "messages" : { } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java new file mode 100644 index 000000000000..ee327827903e --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroArrayOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List> int_arrays_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroArrayOfArray() {} + + /** + * All-args constructor. + */ + public AvroArrayOfArray(java.util.List> int_arrays_column) { + this.int_arrays_column = int_arrays_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return int_arrays_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: int_arrays_column = (java.util.List>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'int_arrays_column' field. + */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** + * Sets the value of the 'int_arrays_column' field. + * @param value the value to set. + */ + public void setIntArraysColumn(java.util.List> value) { + this.int_arrays_column = value; + } + + /** Creates a new AvroArrayOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing AvroArrayOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroArrayOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List> int_arrays_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroArrayOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'int_arrays_column' field */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** Sets the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder setIntArraysColumn(java.util.List> value) { + validate(fields()[0], value); + this.int_arrays_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'int_arrays_column' field has been set */ + public boolean hasIntArraysColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder clearIntArraysColumn() { + int_arrays_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroArrayOfArray build() { + try { + AvroArrayOfArray record = new AvroArrayOfArray(); + record.int_arrays_column = fieldSetFlags()[0] ? this.int_arrays_column : (java.util.List>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java new file mode 100644 index 000000000000..727f6a7bf733 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroMapOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.Map> string_to_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroMapOfArray() {} + + /** + * All-args constructor. + */ + public AvroMapOfArray(java.util.Map> string_to_ints_column) { + this.string_to_ints_column = string_to_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return string_to_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: string_to_ints_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'string_to_ints_column' field. + */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** + * Sets the value of the 'string_to_ints_column' field. + * @param value the value to set. + */ + public void setStringToIntsColumn(java.util.Map> value) { + this.string_to_ints_column = value; + } + + /** Creates a new AvroMapOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing AvroMapOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroMapOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.Map> string_to_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroMapOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'string_to_ints_column' field */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** Sets the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder setStringToIntsColumn(java.util.Map> value) { + validate(fields()[0], value); + this.string_to_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'string_to_ints_column' field has been set */ + public boolean hasStringToIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder clearStringToIntsColumn() { + string_to_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroMapOfArray build() { + try { + AvroMapOfArray record = new AvroMapOfArray(); + record.string_to_ints_column = fieldSetFlags()[0] ? this.string_to_ints_column : (java.util.Map>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java new file mode 100644 index 000000000000..934793f42f9c --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroNonNullableArrays extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.List maybe_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroNonNullableArrays() {} + + /** + * All-args constructor. + */ + public AvroNonNullableArrays(java.util.List strings_column, java.util.List maybe_ints_column) { + this.strings_column = strings_column; + this.maybe_ints_column = maybe_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return maybe_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: maybe_ints_column = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'maybe_ints_column' field. + */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** + * Sets the value of the 'maybe_ints_column' field. + * @param value the value to set. + */ + public void setMaybeIntsColumn(java.util.List value) { + this.maybe_ints_column = value; + } + + /** Creates a new AvroNonNullableArrays RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing AvroNonNullableArrays instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** + * RecordBuilder for AvroNonNullableArrays instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.List maybe_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing AvroNonNullableArrays instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_ints_column' field */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** Sets the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setMaybeIntsColumn(java.util.List value) { + validate(fields()[1], value); + this.maybe_ints_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_ints_column' field has been set */ + public boolean hasMaybeIntsColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearMaybeIntsColumn() { + maybe_ints_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public AvroNonNullableArrays build() { + try { + AvroNonNullableArrays record = new AvroNonNullableArrays(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.maybe_ints_column = fieldSetFlags()[1] ? this.maybe_ints_column : (java.util.List) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java new file mode 100644 index 000000000000..e4d1ead8dd15 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java @@ -0,0 +1,466 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroOptionalPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String maybe_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroOptionalPrimitives() {} + + /** + * All-args constructor. + */ + public AvroOptionalPrimitives(java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column) { + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return maybe_bool_column; + case 1: return maybe_int_column; + case 2: return maybe_long_column; + case 3: return maybe_float_column; + case 4: return maybe_double_column; + case 5: return maybe_binary_column; + case 6: return maybe_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: maybe_bool_column = (java.lang.Boolean)value$; break; + case 1: maybe_int_column = (java.lang.Integer)value$; break; + case 2: maybe_long_column = (java.lang.Long)value$; break; + case 3: maybe_float_column = (java.lang.Float)value$; break; + case 4: maybe_double_column = (java.lang.Double)value$; break; + case 5: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 6: maybe_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'maybe_bool_column' field. + */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. + */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing AvroOptionalPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroOptionalPrimitives instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroOptionalPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[0], value); + this.maybe_bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[1], value); + this.maybe_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[2], value); + this.maybe_long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[3], value); + this.maybe_float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[4], value); + this.maybe_double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.maybe_binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.maybe_string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroOptionalPrimitives build() { + try { + AvroOptionalPrimitives record = new AvroOptionalPrimitives(); + record.maybe_bool_column = fieldSetFlags()[0] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.maybe_int_column = fieldSetFlags()[1] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.maybe_long_column = fieldSetFlags()[2] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[2]); + record.maybe_float_column = fieldSetFlags()[3] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[3]); + record.maybe_double_column = fieldSetFlags()[4] ? this.maybe_double_column : (java.lang.Double) defaultValue(fields()[4]); + record.maybe_binary_column = fieldSetFlags()[5] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.maybe_string_column = fieldSetFlags()[6] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java new file mode 100644 index 000000000000..1c2afed16781 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java @@ -0,0 +1,461 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroPrimitives() {} + + /** + * All-args constructor. + */ + public AvroPrimitives(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. + */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** Creates a new AvroPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing AvroPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroPrimitives instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroPrimitives build() { + try { + AvroPrimitives record = new AvroPrimitives(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java new file mode 100644 index 000000000000..28fdc1dfb911 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -0,0 +1,17 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; + +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public interface CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]},{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]},{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]},{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + + @SuppressWarnings("all") + public interface Callback extends CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; + } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java new file mode 100644 index 000000000000..a7bf4841919c --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List nested_ints_column; + @Deprecated public java.lang.String nested_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public Nested() {} + + /** + * All-args constructor. + */ + public Nested(java.util.List nested_ints_column, java.lang.String nested_string_column) { + this.nested_ints_column = nested_ints_column; + this.nested_string_column = nested_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return nested_ints_column; + case 1: return nested_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: nested_ints_column = (java.util.List)value$; break; + case 1: nested_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'nested_ints_column' field. + */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** + * Sets the value of the 'nested_ints_column' field. + * @param value the value to set. + */ + public void setNestedIntsColumn(java.util.List value) { + this.nested_ints_column = value; + } + + /** + * Gets the value of the 'nested_string_column' field. + */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** + * Sets the value of the 'nested_string_column' field. + * @param value the value to set. + */ + public void setNestedStringColumn(java.lang.String value) { + this.nested_string_column = value; + } + + /** Creates a new Nested RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); + } + + /** Creates a new Nested RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); + } + + /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); + } + + /** + * RecordBuilder for Nested instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List nested_ints_column; + private java.lang.String nested_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + super(other); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing Nested instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'nested_ints_column' field */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** Sets the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + validate(fields()[0], value); + this.nested_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'nested_ints_column' field has been set */ + public boolean hasNestedIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + nested_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'nested_string_column' field */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** Sets the value of the 'nested_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + validate(fields()[1], value); + this.nested_string_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'nested_string_column' field has been set */ + public boolean hasNestedStringColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'nested_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + nested_string_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public Nested build() { + try { + Nested record = new Nested(); + record.nested_ints_column = fieldSetFlags()[0] ? this.nested_ints_column : (java.util.List) defaultValue(fields()[0]); + record.nested_string_column = fieldSetFlags()[1] ? this.nested_string_column : (java.lang.String) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java new file mode 100644 index 000000000000..ef12d193f916 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -0,0 +1,250 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.Map string_to_int_column; + @Deprecated public java.util.Map> complex_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetAvroCompat() {} + + /** + * All-args constructor. + */ + public ParquetAvroCompat(java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + this.strings_column = strings_column; + this.string_to_int_column = string_to_int_column; + this.complex_column = complex_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return string_to_int_column; + case 2: return complex_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: string_to_int_column = (java.util.Map)value$; break; + case 2: complex_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'string_to_int_column' field. + */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** + * Sets the value of the 'string_to_int_column' field. + * @param value the value to set. + */ + public void setStringToIntColumn(java.util.Map value) { + this.string_to_int_column = value; + } + + /** + * Gets the value of the 'complex_column' field. + */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** + * Sets the value of the 'complex_column' field. + * @param value the value to set. + */ + public void setComplexColumn(java.util.Map> value) { + this.complex_column = value; + } + + /** Creates a new ParquetAvroCompat RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** + * RecordBuilder for ParquetAvroCompat instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.Map string_to_int_column; + private java.util.Map> complex_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Creates a Builder by copying an existing ParquetAvroCompat instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'string_to_int_column' field */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** Sets the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + validate(fields()[1], value); + this.string_to_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'string_to_int_column' field has been set */ + public boolean hasStringToIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + string_to_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'complex_column' field */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** Sets the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + validate(fields()[2], value); + this.complex_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'complex_column' field has been set */ + public boolean hasComplexColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + complex_column = null; + fieldSetFlags()[2] = false; + return this; + } + + @Override + public ParquetAvroCompat build() { + try { + ParquetAvroCompat record = new ParquetAvroCompat(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.string_to_int_column = fieldSetFlags()[1] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[1]); + record.complex_column = fieldSetFlags()[2] ? this.complex_column : (java.util.Map>) defaultValue(fields()[2]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 000000000000..05fefe4cee75 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. + */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 000000000000..00711a0c2a26 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java deleted file mode 100644 index e5588938ea16..000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.io.Serializable; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaAPISuite"); - sqlContext = new SQLContext(sc); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - - @SuppressWarnings("unchecked") - @Test - public void udf1Test() { - // With Java 8 lambdas: - // sqlContext.registerFunction( - // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - - sqlContext.udf().register("stringLengthTest", new UDF1() { - @Override - public Integer call(String str) throws Exception { - return str.length(); - } - }, DataTypes.IntegerType); - - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); - } - - @SuppressWarnings("unchecked") - @Test - public void udf2Test() { - // With Java 8 lambdas: - // sqlContext.registerFunction( - // "stringLengthTest", - // (String str1, String str2) -> str1.length() + str2.length, - // DataType.IntegerType); - - sqlContext.udf().register("stringLengthTest", new UDF2() { - @Override - public Integer call(String str1, String str2) throws Exception { - return str1.length() + str2.length(); - } - }, DataTypes.IntegerType); - - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java deleted file mode 100644 index badd00d34b9b..000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.*; -import org.apache.spark.sql.types.*; - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. -public class JavaApplySchemaSuite implements Serializable { - private transient JavaSparkContext javaCtx; - private transient SQLContext javaSqlCtx; - - @Before - public void setUp() { - javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); - javaSqlCtx = new SQLContext(javaCtx); - } - - @After - public void tearDown() { - javaCtx.stop(); - javaCtx = null; - javaSqlCtx = null; - } - - public static class Person implements Serializable { - private String name; - private int age; - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public int getAge() { - return age; - } - - public void setAge(int age) { - this.age = age; - } - } - - @Test - public void applySchema() { - List personList = new ArrayList(2); - Person person1 = new Person(); - person1.setName("Michael"); - person1.setAge(29); - personList.add(person1); - Person person2 = new Person(); - person2.setName("Yin"); - person2.setAge(28); - personList.add(person2); - - JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); - fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); - StructType schema = DataTypes.createStructType(fields); - - DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema); - df.registerTempTable("people"); - Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); - - List expected = new ArrayList(2); - expected.add(RowFactory.create("Michael", 29)); - expected.add(RowFactory.create("Yin", 28)); - - Assert.assertEquals(expected, Arrays.asList(actual)); - } - - @Test - public void applySchemaToJSON() { - JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( - "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + - "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + - "\"boolean\":true, \"null\":null}", - "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + - "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + - "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); - fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); - fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); - fields.add(DataTypes.createStructField("integer", DataTypes.IntegerType, true)); - fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); - fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); - fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); - StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); - expectedResult.add( - RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.")); - expectedResult.add( - RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), - false, - 1.7976931348623157E305, - 11, - 21474836469L, - null, - "this is another simple string.")); - - DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); - StructType actualSchema1 = df1.schema(); - Assert.assertEquals(expectedSchema, actualSchema1); - df1.registerTempTable("jsonTable1"); - List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); - Assert.assertEquals(expectedResult, actual1); - - DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); - StructType actualSchema2 = df2.schema(); - Assert.assertEquals(expectedSchema, actualSchema2); - df2.registerTempTable("jsonTable2"); - List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); - Assert.assertEquals(expectedResult, actual2); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java deleted file mode 100644 index 639436368c4a..000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import com.google.common.collect.ImmutableMap; - -import org.apache.spark.sql.Column; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.types.DataTypes; - -import static org.apache.spark.sql.Dsl.*; - -/** - * This test doesn't actually run anything. It is here to check the API compatibility for Java. - */ -public class JavaDsl { - - public static void testDataFrame(final DataFrame df) { - DataFrame df1 = df.select("colA"); - df1 = df.select("colA", "colB"); - - df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1)); - - df1 = df.filter(col("colA")); - - java.util.Map aggExprs = ImmutableMap.builder() - .put("colA", "sum") - .put("colB", "avg") - .build(); - - df1 = df.agg(aggExprs); - - df1 = df.groupBy("groupCol").agg(aggExprs); - - df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer"); - - df.orderBy("colA"); - df.orderBy("colA", "colB", "colC"); - df.orderBy(col("colA").desc()); - df.orderBy(col("colA").desc(), col("colB").asc()); - - df.sort("colA"); - df.sort("colA", "colB", "colC"); - df.sort(col("colA").desc()); - df.sort(col("colA").desc(), col("colB").asc()); - - df.as("b"); - - df.limit(5); - - df.unionAll(df1); - df.intersect(df1); - df.except(df1); - - df.sample(true, 0.1, 234); - - df.head(); - df.head(5); - df.first(); - df.count(); - } - - public static void testColumn(final Column c) { - c.asc(); - c.desc(); - - c.endsWith("abcd"); - c.startsWith("afgasdf"); - - c.like("asdf%"); - c.rlike("wef%asdf"); - - c.as("newcol"); - - c.cast("int"); - c.cast(DataTypes.IntegerType); - } - - public static void testDsl() { - // Creating a column. - Column c = col("abcd"); - Column c1 = column("abcd"); - - // Literals - Column l1 = lit(1); - Column l2 = lit(1.0); - Column l3 = lit("abcd"); - - // Functions - Column a = upper(c); - a = lower(c); - a = sqrt(c); - a = abs(c); - - // Aggregates - a = min(c); - a = max(c); - a = sum(c); - a = sumDistinct(c); - a = countDistinct(c, a); - a = avg(c); - a = first(c); - a = last(c); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java deleted file mode 100644 index 80bd74f5b552..000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc; - -import org.junit.*; -import static org.junit.Assert.*; -import java.sql.Connection; -import java.sql.DriverManager; - -import org.apache.spark.SparkEnv; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.api.java.*; -import org.apache.spark.sql.test.TestSQLContext$; - -public class JavaJDBCTest { - static String url = "jdbc:h2:mem:testdb1"; - - static Connection conn = null; - - // This variable will always be null if TestSQLContext is intact when running - // these tests. Some Java tests do not play nicely with others, however; - // they create a SparkContext of their own at startup and stop it at exit. - // This renders TestSQLContext inoperable, meaning we have to do the same - // thing. If this variable is nonnull, that means we allocated a - // SparkContext of our own and that we need to stop it at teardown. - static JavaSparkContext localSparkContext = null; - - static SQLContext sql = TestSQLContext$.MODULE$; - - @Before - public void beforeTest() throws Exception { - if (SparkEnv.get() == null) { // A previous test destroyed TestSQLContext. - localSparkContext = new JavaSparkContext("local", "JavaAPISuite"); - sql = new SQLContext(localSparkContext); - } - Class.forName("org.h2.Driver"); - conn = DriverManager.getConnection(url); - conn.prepareStatement("create schema test").executeUpdate(); - conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate(); - conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate(); - conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate(); - conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate(); - conn.commit(); - } - - @After - public void afterTest() throws Exception { - if (localSparkContext != null) { - localSparkContext.stop(); - localSparkContext = null; - } - try { - conn.close(); - } finally { - conn = null; - } - } - - @Test - public void basicTest() throws Exception { - DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE"); - Row[] rows = rdd.collect(); - assertEquals(rows.length, 3); - } - - @Test - public void partitioningTest() throws Exception { - String[] parts = new String[2]; - parts[0] = "THEID < 2"; - parts[1] = "THEID = 2"; // Deliberately forget about one of them. - DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE", parts); - Row[] rows = rdd.collect(); - assertEquals(rows.length, 2); - } - - @Test - public void writeTest() throws Exception { - DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE"); - JDBCUtils.createJDBCTable(rdd, url, "TEST.PEOPLECOPY", false); - DataFrame rdd2 = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLECOPY"); - Row[] rows = rdd2.collect(); - assertEquals(rows.length, 3); - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java new file mode 100644 index 000000000000..bf693c7c393f --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaApplySchemaSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); + } + + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + javaCtx = null; + } + + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + + @Test + public void applySchema() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); + + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); + + List expected = new ArrayList(2); + expected.add(RowFactory.create("Michael", 29)); + expected.add(RowFactory.create("Yin", 28)); + + Assert.assertEquals(expected, Arrays.asList(actual)); + } + + @Test + public void dataFrameRDDOperations() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); + + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { + + public String call(Row row) { + return row.getString(0) + "_" + row.get(1).toString(); + } + }).collect(); + + List expected = new ArrayList(2); + expected.add("Michael_29"); + expected.add("Yin_28"); + + Assert.assertEquals(expected, actual); + } + + @Test + public void applySchemaToJSON() { + JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + + "\"boolean\":true, \"null\":null}", + "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + + "\"boolean\":false, \"null\":null}")); + List fields = new ArrayList(7); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), + true)); + fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); + fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); + fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); + fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); + StructType expectedSchema = DataTypes.createStructType(fields); + List expectedResult = new ArrayList(2); + expectedResult.add( + RowFactory.create( + new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")); + expectedResult.add( + RowFactory.create( + new java.math.BigDecimal("92233720368547758069"), + false, + 1.7976931348623157E305, + 11, + 21474836469L, + null, + "this is another simple string.")); + + DataFrame df1 = sqlContext.read().json(jsonRDD); + StructType actualSchema1 = df1.schema(); + Assert.assertEquals(expectedSchema, actualSchema1); + df1.registerTempTable("jsonTable1"); + List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); + Assert.assertEquals(expectedResult, actual1); + + DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + StructType actualSchema2 = df2.schema(); + Assert.assertEquals(expectedSchema, actualSchema2); + df2.registerTempTable("jsonTable2"); + List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); + Assert.assertEquals(expectedResult, actual2); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java new file mode 100644 index 000000000000..4867cebf5328 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConverters; +import scala.collection.Seq; + +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; +import org.junit.*; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; +import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.types.*; + +public class JavaDataFrameSuite { + private transient JavaSparkContext jsc; + private transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + @Test + public void testExecution() { + DataFrame df = context.table("testData").filter("key = 1"); + Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + } + + /** + * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. + */ + @Test + public void testVarargMethods() { + DataFrame df = context.table("testData"); + + df.toDF("key1", "value1"); + + df.select("key", "value"); + df.select(col("key"), col("value")); + df.selectExpr("key", "value + 1"); + + df.sort("key", "value"); + df.sort(col("key"), col("value")); + df.orderBy("key", "value"); + df.orderBy(col("key"), col("value")); + + df.groupBy("key", "value").agg(col("key"), col("value"), sum("value")); + df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value")); + df.agg(first("key"), sum("value")); + + df.groupBy().avg("key"); + df.groupBy().mean("key"); + df.groupBy().max("key"); + df.groupBy().min("key"); + df.groupBy().sum("key"); + + // Varargs in column expressions + df.groupBy().agg(countDistinct("key", "value")); + df.groupBy().agg(countDistinct(col("key"), col("value"))); + df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); + + df2.select(rand(), acos("b")); + df2.select(col("*"), randn(5L)); + } + + @Ignore + public void testShow() { + // This test case is intended ignored, but to make sure it compiles correctly + DataFrame df = context.table("testData"); + df.show(); + df.show(1000); + } + + public static class Bean implements Serializable { + private double a = 0.0; + private Integer[] b = new Integer[]{0, 1}; + private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); + private List d = Arrays.asList("floppy", "disk"); + + public double getA() { + return a; + } + + public Integer[] getB() { + return b; + } + + public Map getC() { + return c; + } + + public List getD() { + return d; + } + } + + @Test + public void testCreateDataFrameFromJavaBeans() { + Bean bean = new Bean(); + JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); + DataFrame df = context.createDataFrame(rdd, Bean.class); + StructType schema = df.schema(); + Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), + schema.apply("a")); + Assert.assertEquals( + new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), + schema.apply("b")); + ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); + MapType mapType = new MapType(DataTypes.StringType, valueType, true); + Assert.assertEquals( + new StructField("c", mapType, true, Metadata.empty()), + schema.apply("c")); + Assert.assertEquals( + new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), + schema.apply("d")); + Row first = df.select("a", "b", "c", "d").first(); + Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); + // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // verify that it has the expected length, and contains expected elements. + Seq result = first.getAs(1); + Assert.assertEquals(bean.getB().length, result.length()); + for (int i = 0; i < result.length(); i++) { + Assert.assertEquals(bean.getB()[i], result.apply(i)); + } + @SuppressWarnings("unchecked") + Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); + Assert.assertArrayEquals( + bean.getC().get("hello"), + Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava())); + Seq d = first.getAs(3); + Assert.assertEquals(bean.getD().size(), d.length()); + for (int i = 0; i < d.length(); i++) { + Assert.assertEquals(bean.getD().get(i), d.apply(i)); + } + } + + private static Comparator CrosstabRowComparator = new Comparator() { + public int compare(Row row1, Row row2) { + String item1 = row1.getString(0); + String item2 = row2.getString(0); + return item1.compareTo(item2); + } + }; + + @Test + public void testCrosstab() { + DataFrame df = context.table("testData2"); + DataFrame crosstab = df.stat().crosstab("a", "b"); + String[] columnNames = crosstab.schema().fieldNames(); + Assert.assertEquals(columnNames[0], "a_b"); + Assert.assertEquals(columnNames[1], "1"); + Assert.assertEquals(columnNames[2], "2"); + Row[] rows = crosstab.collect(); + Arrays.sort(rows, CrosstabRowComparator); + Integer count = 1; + for (Row row : rows) { + Assert.assertEquals(row.get(0).toString(), count.toString()); + Assert.assertEquals(row.getLong(1), 1L); + Assert.assertEquals(row.getLong(2), 1L); + count++; + } + } + + @Test + public void testFrequentItems() { + DataFrame df = context.table("testData2"); + String[] cols = new String[]{"a"}; + DataFrame results = df.stat().freqItems(cols, 0.2); + Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); + } + + @Test + public void testCorrelation() { + DataFrame df = context.table("testData2"); + Double pearsonCorr = df.stat().corr("a", "b", "pearson"); + Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + } + + @Test + public void testCovariance() { + DataFrame df = context.table("testData2"); + Double result = df.stat().cov("a", "b"); + Assert.assertTrue(Math.abs(result) < 1e-6); + } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java similarity index 99% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index fbfcd3f59d91..4ce1d1dddb26 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.math.BigDecimal; import java.sql.Date; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java new file mode 100644 index 000000000000..bb02b58cca9b --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkContext; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.api.java.UDF2; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.types.DataTypes; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaUDFSuite implements Serializable { + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); + } + + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); + + sqlContext.udf().register("stringLengthTest", new UDF1() { + @Override + public Integer call(String str) throws Exception { + return str.length(); + } + }, DataTypes.IntegerType); + + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + assert(result.getInt(0) == 4); + } + + @SuppressWarnings("unchecked") + @Test + public void udf2Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", + // (String str1, String str2) -> str1.length() + str2.length, + // DataType.IntegerType); + + sqlContext.udf().register("stringLengthTest", new UDF2() { + @Override + public Integer call(String str1, String str2) throws Exception { + return str1.length() + str2.length(); + } + }, DataTypes.IntegerType); + + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + assert(result.getInt(0) == 9); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java new file mode 100644 index 000000000000..6f9e7f68dc39 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaSaveLoadSuite { + + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + String originalDefaultSource; + File path; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.read().json(rdd); + df.registerTempTable("jsonTable"); + } + + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + + @Test + public void saveAndLoad() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); + DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); + checkAnswer(loadedDF, df.collectAsList()); + } + + @Test + public void saveAndLoadWithSchema() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + + checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + } +} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..cfd7889b4ac2 --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index fbed0a782dd3..12fb128149d3 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,8 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. -log4j.additivity.parquet.hadoop.ParquetRecordReader=false -log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF @@ -49,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.parquet.hadoop=WARN +log4j.logger.org.apache.parquet.hadoop=WARN log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 000000000000..41a43fa35d39 Binary files /dev/null and b/sql/core/src/test/resources/nested-array-struct.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 000000000000..520922f73ebb Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-int.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/old-repeated-message.parquet new file mode 100644 index 000000000000..548db9916277 Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-message.parquet differ diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 000000000000..213f1a90291b Binary files /dev/null and b/sql/core/src/test/resources/old-repeated.parquet differ diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet new file mode 100644 index 000000000000..837e4876eea6 Binary files /dev/null and b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet differ diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 000000000000..8a7eea601d01 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-string.parquet differ diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 000000000000..c29eee35c350 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-struct.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/proto-struct-with-array-many.parquet new file mode 100644 index 000000000000..ff9809675fc0 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array-many.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/proto-struct-with-array.parquet new file mode 100644 index 000000000000..325a8370ad20 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array.parquet differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c9221f8f934a..af7590c3d3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,19 +17,24 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) +private case class BigData(s: String) -class CachedTableSuite extends QueryTest { - TestData // Load test tables. +class CachedTableSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -39,150 +44,170 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + } + + test("withColumn doesn't invalidate cached dataframe") { + var evalCount = 0 + val myUDF = udf((x: String) => { evalCount += 1; "result" }) + val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) + df.cache() + + df.collect() + assert(evalCount === 1) + + df.collect() + assert(evalCount === 1) + + val df2 = df.withColumn("newColumn", lit(1)) + df2.collect() + + // We should not reevaluate the cached dataframe + assert(evalCount === 1) } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - cacheTable("tempTable") + ctx.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(true) - assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = true) + assert(None == ctx.cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = false) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) - testData.unpersist(true) - assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None != ctx.cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = true) + assert(None == ctx.cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = false) + assert(None == ctx.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - cacheTable("tempTable1") + ctx.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - uncacheTable("tempTable2") + ctx.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) } test("too big for memory") { - val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") - table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) - table("bigData").unpersist(blocking = true) + val data = "*" * 1000 + ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + .registerTempTable("bigData") + ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(ctx.table("bigData").count() === 200000L) + ctx.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - table("testData").cache() - assertCached(table("testData")) - table("testData").unpersist(blocking = true) + ctx.table("testData").cache() + assertCached(ctx.table("testData")) + ctx.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - table("testData").cache() - table("testData").count() - table("testData").unpersist(blocking = true) - assertCached(table("testData"), 0) + ctx.table("testData").cache() + ctx.table("testData").count() + ctx.table("testData").unpersist(blocking = true) + assertCached(ctx.table("testData"), 0) } test("isCached") { - cacheTable("testData") + ctx.cacheTable("testData") - assertCached(table("testData")) - assert(table("testData").queryExecution.withCachedData match { + assertCached(ctx.table("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - uncacheTable("testData") - assert(!isCached("testData")) - assert(table("testData").queryExecution.withCachedData match { + ctx.uncacheTable("testData") + assert(!ctx.isCached("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - cacheTable("testData") - assertCached(table("testData")) + ctx.cacheTable("testData") + assertCached(ctx.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - cacheTable("testData") + ctx.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - uncacheTable("testData") + ctx.uncacheTable("testData") } test("read from cached table and uncache") { - cacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData")) + ctx.cacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData")) - uncacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData"), 0) + ctx.uncacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - uncacheTable("testData") + ctx.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - cacheTable("selectStar") + ctx.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - uncacheTable("selectStar") + ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - cacheTable("testData") + ctx.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - uncacheTable("testData") + ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -190,39 +215,46 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!isCached("testData"), "Table 'testData' should not be cached") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + ctx.uncacheTable("testCacheTable") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + ctx.uncacheTable("testCacheTable") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -234,13 +266,15 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("testData") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + ctx.uncacheTable("testData") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -249,21 +283,57 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - table("t1") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + ctx.table("t1") + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - cacheTable("t1") + ctx.cacheTable("t1") + + assert(ctx.isCached("t1")) + assert(ctx.isCached("t2")) + + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!ctx.isCached("t2")) + } - assert(isCached("t1")) - assert(isCached("t2")) + test("Clear all cache") { + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") + ctx.clearCache() + assert(ctx.cacheManager.isEmpty) + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") + sql("Clear CACHE") + assert(ctx.cacheManager.isEmpty) + } + + test("Clear accumulators when uncacheTable to prevent memory leaking") { + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + + ctx.cacheTable("t1") + ctx.cacheTable("t2") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) - assert(!isCached("t2")) + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + ctx.uncacheTable("t1") + ctx.uncacheTable("t2") + assert((accsSize - 2) == Accumulators.originals.size) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa4cdecbcb34..37738ec5b3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,53 +17,172 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.{Project, TungstenProject} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ -class ColumnExpressionSuite extends QueryTest { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - // TODO: Add test cases for bitwise operations. + private lazy val booleanData = { + ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } - test("computability check") { - def shouldBeComputable(c: Column): Unit = assert(c.isComputable === true) + test("column names with space") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot") - def shouldNotBeComputable(c: Column): Unit = { - assert(c.isComputable === false) - intercept[UnsupportedOperationException] { c.head() } - } + checkAnswer( + df.select(df("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select($"name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(col("name with space")), + Row(1) :: Nil) - shouldBeComputable(testData2("a")) - shouldBeComputable(testData2("b")) + checkAnswer( + df.select("name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(expr("`name with space`")), + Row(1) :: Nil) + } + + test("column names with dot") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a") - shouldBeComputable(testData2("a") + testData2("b")) - shouldBeComputable(testData2("a") + testData2("b") + 1) + checkAnswer( + df.select(df("`name.with.dot`")), + Row("a") :: Nil) - shouldBeComputable(-testData2("a")) - shouldBeComputable(!testData2("a")) + checkAnswer( + df.select($"`name.with.dot`"), + Row("a") :: Nil) - shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b")) - shouldBeComputable( - testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d")) - shouldBeComputable( - testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b")) + checkAnswer( + df.select(col("`name.with.dot`")), + Row("a") :: Nil) - // Literals and unresolved columns should not be computable. - shouldNotBeComputable(col("1")) - shouldNotBeComputable(col("1") + 2) - shouldNotBeComputable(lit(100)) - shouldNotBeComputable(lit(100) + 10) - shouldNotBeComputable(-col("1")) - shouldNotBeComputable(!col("1")) + checkAnswer( + df.select("`name.with.dot`"), + Row("a") :: Nil) - // Getting data from different frames should not be computable. - shouldNotBeComputable(testData2("a") + testData("key")) - shouldNotBeComputable(testData2("a") + 1 + testData("key")) + checkAnswer( + df.select(expr("`name.with.dot`")), + Row("a") :: Nil) - // Aggregate functions alone should not be computable. - shouldNotBeComputable(sum(testData2("a"))) + checkAnswer( + df.select(df("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("a.`name.with.dot`")), + Row("a") :: Nil) + } + + test("alias") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + assert(df.select(df("a").as("b")).columns.head === "b") + assert(df.select(df("a").alias("b")).columns.head === "b") + } + + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + + test("single explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + checkAnswer( + df.select(explode('intList)), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + + test("explode and other columns") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + + checkAnswer( + df.select($"a", explode('intList)), + Row(1, 1) :: + Row(1, 2) :: + Row(1, 3) :: Nil) + + checkAnswer( + df.select($"*", explode('intList)), + Row(1, Seq(1, 2, 3), 1) :: + Row(1, Seq(1, 2, 3), 2) :: + Row(1, Seq(1, 2, 3), 3) :: Nil) + } + + test("aliased explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + + checkAnswer( + df.select(explode('intList).as('int)).select('int), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + df.select(explode('intList).as('int)).select(sum('int)), + Row(6) :: Nil) + } + + test("explode on map") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map)), + Row("a", "b")) + } + + test("explode on map with aliases") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), + Row("a", "b")) + } + + test("self join explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + val exploded = df.select(explode('intList).as('i)) + + checkAnswer( + exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), + Row(3) :: Nil) + } + + test("collect on column produced by a binary operator") { + val df = Seq((1, 2, 3)).toDF("a", "b", "c") + checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) + checkAnswer(df.select(df("a") + df("b").as("c")), Seq(Row(3))) } test("star") { @@ -71,8 +190,7 @@ class ColumnExpressionSuite extends QueryTest { } test("star qualified by data frame object") { - // This is not yet supported. - val df = testData.toDataFrame + val df = testData.toDF val goldAnswer = df.collect().toSeq checkAnswer(df.select(df("*")), goldAnswer) @@ -149,14 +267,66 @@ class ColumnExpressionSuite extends QueryTest { test("isNull") { checkAnswer( - nullStrings.toDataFrame.where($"s".isNull), + nullStrings.toDF.where($"s".isNull), nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + + checkAnswer( + sql("select isnull(null), isnull(1)"), + Row(true, false)) } test("isNotNull") { checkAnswer( - nullStrings.toDataFrame.where($"s".isNotNull), + nullStrings.toDF.where($"s".isNotNull), nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + + checkAnswer( + sql("select isnotnull(null), isnotnull('a')"), + Row(false, true)) + } + + test("isNaN") { + val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(Double.NaN, Float.NaN) :: + Row(math.log(-1), math.log(-3).toFloat) :: + Row(null, null) :: + Row(Double.MaxValue, Float.MinValue):: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", FloatType)))) + + checkAnswer( + testData.select($"a".isNaN, $"b".isNaN), + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) + + checkAnswer( + testData.select(isNaN($"a"), isNaN($"b")), + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) + + checkAnswer( + sql("select isnan(15), isnan('invalid')"), + Row(false, false)) + } + + test("nanvl") { + val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), + StructField("c", DoubleType), StructField("d", DoubleType), + StructField("e", FloatType), StructField("f", IntegerType)))) + + checkAnswer( + testData.select( + nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"), + nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)), + nanvl($"b", $"e"), nanvl($"e", $"f")), + Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) + ) + testData.registerTempTable("t") + checkAnswer( + sql( + "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + + " nanvl(b, e), nanvl(e, f) from t"), + Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) + ) } test("===") { @@ -180,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -240,12 +410,41 @@ class ColumnExpressionSuite extends QueryTest { testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) } - val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + test("between") { + val testData = ctx.sparkContext.parallelize( + (0, 1, 2) :: + (1, 2, 3) :: + (2, 1, 0) :: + (2, 2, 4) :: + (3, 1, 6) :: + (3, 2, 0) :: Nil).toDF("a", "b", "c") + val expectAnswer = testData.collect().toSeq. + filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2)) + + checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) + } + + test("in") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".isin(1, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, 1)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"b".isin("y", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".isin("z", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".isin("z", "y")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + intercept[AnalysisException] { + df2.filter($"a".isin($"b")) + } + } test("&&") { checkAnswer( @@ -275,36 +474,40 @@ class ColumnExpressionSuite extends QueryTest { Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) } - test("sqrt") { - checkAnswer( - testData.select(sqrt('key)).orderBy('key.asc), - (1 to 100).map(n => Row(math.sqrt(n))) - ) + test("SPARK-7321 when conditional statements") { + val testData = (1 to 3).map(i => (i, i.toString)).toDF("key", "value") checkAnswer( - testData.select(sqrt('value), 'key).orderBy('key.asc, 'value.asc), - (1 to 100).map(n => Row(math.sqrt(n), n)) + testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)), + Seq(Row(-1), Row(-2), Row(0)) ) + // Without the ending otherwise, return null for unmatched conditions. + // Also test putting a non-literal value in the expression. checkAnswer( - testData.select(sqrt(lit(null))), - (1 to 100).map(_ => Row(null)) + testData.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2)), + Seq(Row(-1), Row(-2), Row(null)) ) + + // Test error handling for invalid expressions. + intercept[IllegalArgumentException] { $"key".when($"key" === 1, -1) } + intercept[IllegalArgumentException] { $"key".otherwise(-1) } + intercept[IllegalArgumentException] { when($"key" === 1, -1).otherwise(-1).otherwise(-1) } } - test("abs") { + test("sqrt") { checkAnswer( - testData.select(abs('key)).orderBy('key.asc), - (1 to 100).map(n => Row(n)) + testData.select(sqrt('key)).orderBy('key.asc), + (1 to 100).map(n => Row(math.sqrt(n))) ) checkAnswer( - negativeData.select(abs('key)).orderBy('key.desc), - (1 to 100).map(n => Row(n)) + testData.select(sqrt('value), 'key).orderBy('key.asc, 'value.asc), + (1 to 100).map(n => Row(math.sqrt(n), n)) ) checkAnswer( - testData.select(abs(lit(null))), + testData.select(sqrt(lit(null))), (1 to 100).map(_ => Row(null)) ) } @@ -324,6 +527,10 @@ class ColumnExpressionSuite extends QueryTest { testData.select(upper(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + sql("SELECT upper('aB'), ucase('cDe')"), + Row("AB", "CDE")) } test("lower") { @@ -341,5 +548,155 @@ class ColumnExpressionSuite extends QueryTest { testData.select(lower(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + sql("SELECT lower('aB'), lcase('cDe')"), + Row("ab", "cde")) } + + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. + val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + + test("sparkPartitionId") { + // Make sure we have 2 partitions, each with 2 records. + val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(sparkPartitionId()), + Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil + ) + } + + test("InputFileName") { + withTempPath { dir => + val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + + test("lift alias out of cast") { + compareExpressions( + col("1234").as("name").cast("int").expr, + col("1234").cast("int").as("name").expr) + } + + test("columns can be compared") { + assert('key.desc == 'key.desc) + assert('key.desc != 'key.asc) + } + + test("alias with metadata") { + val metadata = new MetadataBuilder() + .putString("originName", "value") + .build() + val schema = testData + .select($"*", col("value").as("abc", metadata)) + .schema + assert(schema("value").metadata === Metadata.empty) + assert(schema("abc").metadata === metadata) + } + + test("rand") { + val randCol = testData.select($"key", rand(5L).as("rand")) + randCol.columns.length should be (2) + val rows = randCol.collect() + rows.foreach { row => + assert(row.getDouble(1) <= 1.0) + assert(row.getDouble(1) >= 0.0) + } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + case tungstenProject: TungstenProject => tungstenProject + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... + // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } + } + + test("randn") { + val randCol = testData.select('key, randn(5L).as("rand")) + randCol.columns.length should be (2) + val rows = randCol.collect() + rows.foreach { row => + assert(row.getDouble(1) <= 4.0) + assert(row.getDouble(1) >= -4.0) + } + } + + test("bitwiseAND") { + checkAnswer( + testData2.select($"a".bitwiseAND(75)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) & 75))) + + checkAnswer( + testData2.select($"a".bitwiseAND($"b").bitwiseAND(22)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) & r.getInt(1) & 22))) + } + + test("bitwiseOR") { + checkAnswer( + testData2.select($"a".bitwiseOR(170)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) | 170))) + + checkAnswer( + testData2.select($"a".bitwiseOR($"b").bitwiseOR(42)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) | r.getInt(1) | 42))) + } + + test("bitwiseXOR") { + checkAnswer( + testData2.select($"a".bitwiseXOR(112)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ 112))) + + checkAnswer( + testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)), + testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala new file mode 100644 index 000000000000..72cf7aab0b97 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DecimalType + + +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("groupBy") { + checkAnswer( + testData2.groupBy("a").agg(sum($"b")), + Seq(Row(1, 3), Row(2, 3), Row(3, 3)) + ) + checkAnswer( + testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)), + Row(9) + ) + checkAnswer( + testData2.groupBy("a").agg(count("*")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("*" -> "count")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("b" -> "sum")), + Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil + ) + + val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) + .toDF("key", "value1", "value2", "rest") + + checkAnswer( + df1.groupBy("key").min(), + df1.groupBy("key").min("value1", "value2").collect() + ) + checkAnswer( + df1.groupBy("key").min("value2"), + Seq(Row("a", 0), Row("b", 4)) + ) + } + + test("spark.sql.retainGroupColumns config") { + checkAnswer( + testData2.groupBy("a").agg(sum($"b")), + Seq(Row(1, 3), Row(2, 3), Row(3, 3)) + ) + + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + checkAnswer( + testData2.groupBy("a").agg(sum($"b")), + Seq(Row(3), Row(3), Row(3)) + ) + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + } + + test("agg without groups") { + checkAnswer( + testData2.agg(sum('b)), + Row(9) + ) + } + + test("average") { + checkAnswer( + testData2.agg(avg('a)), + Row(2.0)) + + // Also check mean + checkAnswer( + testData2.agg(mean('a)), + Row(2.0)) + + checkAnswer( + testData2.agg(avg('a), sumDistinct('a)), // non-partial + Row(2.0, 6.0) :: Nil) + + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + checkAnswer( + decimalData.agg(avg('a), sumDistinct('a)), // non-partial + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + + checkAnswer( + decimalData.agg(avg('a cast DecimalType(10, 2))), + Row(new java.math.BigDecimal(2.0))) + // non-partial + checkAnswer( + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + } + + test("null average") { + checkAnswer( + testData3.agg(avg('b)), + Row(2.0)) + + checkAnswer( + testData3.agg(avg('b), countDistinct('b)), + Row(2.0, 1)) + + checkAnswer( + testData3.agg(avg('b), sumDistinct('b)), // non-partial + Row(2.0, 2.0)) + } + + test("zero average") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + emptyTableData.agg(avg('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial + Row(null, null)) + } + + test("count") { + assert(testData2.count() === testData2.map(_ => 1).count()) + + checkAnswer( + testData2.agg(count('a), sumDistinct('a)), // non-partial + Row(6, 6.0)) + } + + test("null count") { + checkAnswer( + testData3.groupBy('a).agg(count('b)), + Seq(Row(1, 0), Row(2, 1)) + ) + + checkAnswer( + testData3.groupBy('a).agg(count('a + 'b)), + Seq(Row(1, 0), Row(2, 1)) + ) + + checkAnswer( + testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)), + Row(2, 1, 2, 2, 1) + ) + + checkAnswer( + testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial + Row(1, 1, 2) + ) + } + + test("zero count") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() === 0) + + checkAnswer( + emptyTableData.agg(count('a), sumDistinct('a)), // non-partial + Row(0, null)) + } + + test("zero sum") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + emptyTableData.agg(sum('a)), + Row(null)) + } + + test("zero sum distinct") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + emptyTableData.agg(sumDistinct('a)), + Row(null)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala new file mode 100644 index 000000000000..3c359dd840ab --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). + */ +class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("UDF on struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(struct($"a").as("s")).select(f($"s.a")).collect() + } + + test("UDF on named_struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() + } + + test("UDF on array") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala new file mode 100644 index 000000000000..9d965258e389 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +/** + * Test suite for functions in [[org.apache.spark.sql.functions]]. + */ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("array with column name") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array("a", "b")).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 1)) + } + + test("array with column expression") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array(col("a"), col("b") + col("b"))).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 2)) + } + + // Turn this on once we add a rule to the analyzer to throw a friendly exception + ignore("array: throw exception if putting columns of different types into an array") { + val df = Seq((0, "str")).toDF("a", "b") + intercept[AnalysisException] { + df.select(array("a", "b")) + } + } + + test("struct with column name") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct("a", "b")).first() + + val expectedType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(1, "str")) + } + + test("struct with column expression") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first() + + val expectedType = StructType(Seq( + StructField("c", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(2, "str")) + } + + test("struct with column expression to be automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), col("b"))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct(lit("v"), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", StringType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) + } + + test("constant functions") { + checkAnswer( + sql("SELECT E()"), + Row(scala.math.E) + ) + checkAnswer( + sql("SELECT PI()"), + Row(scala.math.Pi) + ) + } + + test("bitwiseNOT") { + checkAnswer( + testData2.select(bitwiseNOT($"a")), + testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) + } + + test("bin") { + val df = Seq[(Integer, Integer)]((12, null)).toDF("a", "b") + checkAnswer( + df.select(bin("a"), bin("b")), + Row("1100", null)) + checkAnswer( + df.selectExpr("bin(a)", "bin(b)"), + Row("1100", null)) + } + + test("if function") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"), + Row("one", "not_one")) + } + + test("nvl function") { + checkAnswer( + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + Row("x", "y", null)) + } + + test("misc md5 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(md5($"a"), md5($"b")), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + + checkAnswer( + df.selectExpr("md5(a)", "md5(b)"), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + } + + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1($"b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + + test("misc sha2 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(sha2($"a", 256), sha2($"b", 256)), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + checkAnswer( + df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + intercept[IllegalArgumentException] { + df.select(sha2($"a", 1024)) + } + } + + test("misc crc32 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(crc32($"a"), crc32($"b")), + Row(2743272264L, 2180413220L)) + + checkAnswer( + df.selectExpr("crc32(a)", "crc32(b)"), + Row(2743272264L, 2180413220L)) + } + + test("string function find_in_set") { + val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") + + checkAnswer( + df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"), + Row(3, 0)) + } + + test("conditional function: least") { + checkAnswer( + testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), + Row(-1) + ) + checkAnswer( + sql("SELECT least(a, 2) as l from testData2 order by l"), + Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) + ) + } + + test("conditional function: greatest") { + checkAnswer( + testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1), + Row(3) + ) + checkAnswer( + sql("SELECT greatest(a, 2) as g from testData2 order by g"), + Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) + ) + } + + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") + checkAnswer( + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") + checkAnswer( + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // same as hive + ) + checkAnswer( + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) + } + + test("sort_array function") { + val df = Seq( + (Array[Int](2, 1, 3), Array("b", "c", "a")), + (Array[Int](), Array[String]()), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(sort_array($"a"), sort_array($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + checkAnswer( + df.select(sort_array($"a", false), sort_array($"b", false)), + Seq( + Row(Seq(3, 2, 1), Seq("c", "b", "a")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("sort_array(a)", "sort_array(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), + Seq( + Row(Seq(1, 2, 3), Seq("c", "b", "a")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + + val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") + assert(intercept[AnalysisException] { + df2.selectExpr("sort_array(a)").collect() + }.getMessage().contains("does not support sorting array of type array")) + + val df3 = Seq(("xxx", "x")).toDF("a", "b") + assert(intercept[AnalysisException] { + df3.selectExpr("sort_array(a)").collect() + }.getMessage().contains("only supports array input")) + } + + test("array size function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), + (Seq[Int](1, 2, 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size($"a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } + + test("map size function") { + val df = Seq( + (Map[Int, Int](1 -> 1, 2 -> 2), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size($"a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } + + test("array contains function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df.select(array_contains(df("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, 1)"), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.select(array_contains(array(lit(2), lit(null)), 1)), + Seq(Row(false), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df.select(array_contains(df("a"), null)) + } + intercept[AnalysisException] { + df.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df.selectExpr("array_contains(null, 1)") + } + + // In hive, if either argument has a matching type has a null value, return false, even if + // the first argument array contains a null and the second argument is null + checkAnswer( + df.selectExpr("array_contains(array(array(1), null)[1], 1)"), + Seq(Row(false), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), + Seq(Row(false), Row(false)) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala new file mode 100644 index 000000000000..e5d7d63441a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -0,0 +1,54 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("RDD of tuples") { + checkAnswer( + ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + (1 to 10).map(i => Row(i, i.toString))) + } + + test("Seq of tuples") { + checkAnswer( + (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + (1 to 10).map(i => Row(i, i.toString))) + } + + test("RDD[Int]") { + checkAnswer( + ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + (1 to 10).map(i => Row(i))) + } + + test("RDD[Long]") { + checkAnswer( + ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + (1L to 10L).map(i => Row(i))) + } + + test("RDD[String]") { + checkAnswer( + ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + (1 to 10).map(i => Row(i.toString))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala new file mode 100644 index 000000000000..e2716d7841d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("join - join using") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") + + checkAnswer( + df.join(df2, "int"), + Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) + } + + test("join - join using multiple columns") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "int2")), + Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) + } + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("[SPARK-6231] join - self join auto resolve ambiguity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("key")), + Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df, df("key") === df("key") && df("value") === 1), + Row(1, "1", 1, "1") :: Nil) + + val left = df.groupBy("key").agg(count("*")) + val right = df.groupBy("key").agg(sum("key")) + checkAnswer( + left.join(right, left("key") === right("key")), + Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) + } + + test("broadcast join hint") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + + // equijoin - should be converted into broadcast join + val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) + + // no join key -- should not be a broadcast join + val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) + + // planner should not crash without a join + broadcast(df1).queryExecution.executedPlan + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala new file mode 100644 index 000000000000..329ffb66083b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.test.SharedSQLContext + + +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + def createDF(): DataFrame = { + Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Bob", 16, 176.5), + ("Alice", null, 164.3), + ("David", 60, null), + ("Nina", 25, Double.NaN), + ("Amy", null, null), + (null, null, null) + ).toDF("name", "age", "height") + } + + test("drop") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("name" :: Nil).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) + + checkAnswer( + input.na.drop("age" :: Nil).select("name"), + Row("Bob") :: Row("David") :: Row("Nina") :: Nil) + + checkAnswer( + input.na.drop("age" :: "height" :: Nil), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(), + rows(0)) + + // dropna on an a dataframe with no column should return an empty data frame. + val empty = input.sqlContext.emptyDataFrame.select() + assert(empty.na.drop().count() === 0L) + + // Make sure the columns are properly named. + assert(input.na.drop().columns.toSeq === input.columns.toSeq) + } + + test("drop with how") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("all").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) + + checkAnswer( + input.na.drop("any"), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("any", Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("all", Seq("age", "height")).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Nil) + } + + test("drop with threshold") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop(2, Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(3, Seq("name", "age", "height")), + rows(0)) + + // Make sure the columns are properly named. + assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq) + } + + test("fill") { + val input = createDF() + + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Nina", 25, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. + assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: + Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil).select("name", "age"), + Row("Bob", 16) :: + Row("Alice", 50) :: + Row("David", 60) :: + Row("Nina", 25) :: + Row("Amy", 50) :: + Row(null, 50) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + } + + test("fill with map") { + val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( + (null, null, null, null)).toDF("a", "b", "c", "d") + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + )), + Row("test", null, 1, 2.2)) + + // Test Java version + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + ).asJava), + Row("test", null, 1, 2.2)) + } + + test("replace") { + val input = createDF() + + // Replace two numeric columns: age and height + val out = input.na.replace(Seq("age", "height"), Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )).collect() + + assert(out(0) === Row("Bob", 61, 176.5)) + assert(out(1) === Row("Alice", null, 461.3)) + assert(out(2) === Row("David", 6, null)) + assert(out(3).get(2).asInstanceOf[Double].isNaN) + assert(out(4) === Row("Amy", null, null)) + assert(out(5) === Row(null, null, null)) + + // Replace only the age column + val out1 = input.na.replace("age", Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )).collect() + + assert(out1(0) === Row("Bob", 61, 176.5)) + assert(out1(1) === Row("Alice", null, 164.3)) + assert(out1(2) === Row("David", 6, null)) + assert(out1(3).get(2).asInstanceOf[Double].isNaN) + assert(out1(4) === Row("Amy", null, null)) + assert(out1(5) === Row(null, null, null)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala new file mode 100644 index 000000000000..28bdd6f83b68 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.Random + +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private def toLetter(i: Int): String = (i + 97).toChar.toString + + test("sample with replacement") { + val n = 100 + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + Seq(5, 10, 52, 73).map(Row(_)) + ) + } + + test("sample without replacement") { + val n = 100 + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + Seq(16, 23, 88, 100).map(Row(_)) + ) + } + + test("randomSplit") { + val n = 600 + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.stat.corr("a", "b", "pearson") + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.stat.corr("a", "c", "pearson") + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.stat.corr("a", "b", "pearson") + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) + } + + test("covariance") { + val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters") + + val results = df.stat.cov("singles", "doubles") + assert(math.abs(results - 55.0 / 3) < 1e-12) + intercept[IllegalArgumentException] { + df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes + } + val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b") + val decimalRes = decimalData.stat.cov("a", "b") + assert(math.abs(decimalRes) < 1e-12) + } + + test("crosstab") { + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") + val crosstab = df.stat.crosstab("a", "b") + val columnNames = crosstab.schema.fieldNames + assert(columnNames(0) === "a_b") + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 until columnNames.length) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } + } + + test("special crosstab elements (., '', null, ``)") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) + } + + test("Frequent Items") { + val rows = Seq.tabulate(1000) { i => + if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) + } + val df = rows.toDF("numbers", "letters", "negDoubles") + + val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) + val items = results.collect().head + assert(items.getSeq[Int](0).contains(1)) + assert(items.getSeq[String](1).contains(toLetter(1))) + + val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) + val items2 = singleColResults.collect().head + assert(items2.getSeq[Double](0).contains(-1.0)) + } + + test("Frequent Items 2") { + val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) + // this is a regression test, where when merging partitions, we omitted values with higher + // counts than those that existed in the map when the map was full. This test should also fail + // if anything like SPARK-9614 is observed once again + val df = rows.mapPartitionsWithIndex { (idx, iter) => + if (idx == 3) { // must come from one of the later merges, therefore higher partition index + Iterator("3", "3", "3", "3", "3") + } else { + Iterator("0", "1", "2", "3", "4") + } + }.toDF("a") + val results = df.stat.freqItems(Array("a"), 0.25) + val items = results.collect().head.getSeq[String](0) + assert(items.contains("3")) + assert(items.length === 1) + } + + test("sampleBy") { + val df = ctx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f6b65a81ce05..284fff184085 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,55 +17,176 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.types._ - -/* Implicits */ -import org.apache.spark.sql.test.TestSQLContext._ +import java.io.File import scala.language.postfixOps +import scala.util.Random + +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest { - import org.apache.spark.sql.TestData._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + // Eager analysis. + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + + // No more eager analysis once the flag is turned off + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { + testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } + + test("dataframe toString") { + assert(testData.toString === "[key: int, value: string]") + assert(testData("key").toString === "key") + assert($"test".toString === "test") + } + + test("rename nested groupby") { + val df = Seq((1, (1, 1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + + test("invalid plan toString, debug mode") { + // Turn on debug mode so we can see invalid query plans. + import org.apache.spark.sql.execution.debug._ + + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + sqlContext.debug() + + val badPlan = testData.select('badColumn) + + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) } } + test("access complex data") { + assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) + assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) + assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) + } + test("table scan") { checkAnswer( testData, testData.collect().toSeq) } - test("repartition") { - checkAnswer( - testData.select('key).repartition(10).select('key), - testData.select('key).collect().toSeq) + test("empty data frame") { + assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(sqlContext.emptyDataFrame.count() === 0) + } + + test("head and take") { + assert(testData.take(2) === testData.collect().take(2)) + assert(testData.head(2) === testData.collect().take(2)) + assert(testData.head(2).head.schema === testData.schema) } - test("agg") { + test("simple explode") { + val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") + checkAnswer( - testData2.groupBy("a").agg($"a", sum($"b")), - Seq(Row(1,3), Row(2,3), Row(3,3)) + df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), + Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil ) + } + + test("explode") { + val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") + val df2 = + df.explode('letters) { + case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq + } + checkAnswer( - testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), - Row(9) + df2 + .select('_1 as 'letter, 'number) + .groupBy('letter) + .agg(countDistinct('number)), + Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil ) + } + + test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains( + "Cannot explode *, explode can only be applied on a specific column.")) + + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + + test("explode alias and star") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + checkAnswer( - testData2.agg(sum('b)), - Row(9) - ) + df.select(explode($"a").as("a"), $"*"), + Row("a", Seq("a"), 1) :: Nil) + } + + test("selectExpr") { + checkAnswer( + testData.selectExpr("abs(key)", "value"), + testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq) + } + + test("selectExpr with alias") { + checkAnswer( + testData.selectExpr("key as k").select("k"), + testData.select("key").collect().toSeq) + } + + test("filterExpr") { + checkAnswer( + testData.filter("key > 90"), + testData.collect().filter(_.getInt(0) > 90).toSeq) + } + + test("filterExpr using where") { + checkAnswer( + testData.where("key > 50"), + testData.collect().filter(_.getInt(0) > 50).toSeq) + } + + test("repartition") { + checkAnswer( + testData.select('key).repartition(10).select('key), + testData.select('key).collect().toSeq) + } + + test("coalesce") { + assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) + + checkAnswer( + testData.select('key).coalesce(1).select('key), + testData.select('key).collect().toSeq) } test("convert $\"attribute name\" into unresolved attribute") { @@ -115,35 +236,39 @@ class DataFrameSuite extends QueryTest { test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) + + checkAnswer( + testData2.orderBy(asc("a"), desc("b")), + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( - arrayData.orderBy('data.getItem(0).asc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDF().orderBy('data.getItem(0).asc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( - arrayData.orderBy('data.getItem(0).desc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDF().orderBy('data.getItem(0).desc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - arrayData.orderBy('data.getItem(1).asc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDF().orderBy('data.getItem(1).asc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - arrayData.orderBy('data.getItem(1).desc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) + arrayData.toDF().orderBy('data.getItem(1).desc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -152,146 +277,614 @@ class DataFrameSuite extends QueryTest { testData.take(10).toSeq) checkAnswer( - arrayData.limit(1), + arrayData.toDF().limit(1), arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) checkAnswer( - mapData.limit(1), + mapData.toDF().limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } - test("average") { + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + } + + test("intersect") { checkAnswer( - testData2.agg(avg('a)), - Row(2.0)) + lowerCaseData.intersect(lowerCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + } + + test("udf") { + val foo = udf((a: Int, b: String) => a.toString + b) checkAnswer( - testData2.agg(avg('a), sumDistinct('a)), // non-partial - Row(2.0, 6.0) :: Nil) + // SELECT *, foo(key, value) FROM testData + testData.select($"*", foo('key, 'value)).limit(3), + Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil + ) + } + test("deprecated callUdf in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUdf", (v: Int) => v * v) checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) + df.select($"id", callUdf("simpleUdf", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + + test("callUDF in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUDF", (v: Int) => v * v) checkAnswer( - decimalData.agg(avg('a), sumDistinct('a)), // non-partial - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + df.select($"id", callUDF("simpleUDF", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { + val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( - decimalData.agg(avg('a cast DecimalType(10, 2))), - Row(new java.math.BigDecimal(2.0))) + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) + } + + test("replace column using withColumn") { + val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( - decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + df3.select("x"), + Row(2) :: Row(3) :: Row(4) :: Nil) } - test("null average") { + test("drop column using drop") { + val df = testData.drop("key") checkAnswer( - testData3.agg(avg('b)), - Row(2.0)) + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) + } + test("drop unknown column (no-op)") { + val df = testData.drop("random") checkAnswer( - testData3.agg(avg('b), countDistinct('b)), - Row(2.0, 1)) + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + test("drop column using drop with column reference") { + val col = testData("key") + val df = testData.drop(col) checkAnswer( - testData3.agg(avg('b), sumDistinct('b)), // non-partial - Row(2.0, 2.0)) + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) } - test("zero average") { + test("drop unknown column (no-op) with column reference") { + val col = Column("random") + val df = testData.drop(col) checkAnswer( - emptyTableData.agg(avg('a)), - Row(null)) + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + test("drop unknown column with same name (no-op) with column reference") { + val col = Column("key") + val df = testData.drop(col) checkAnswer( - emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial - Row(null, null)) + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) } - test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + test("drop column after join with duplicate columns using column reference") { + val newSalary = salary.withColumnRenamed("personId", "id") + val col = newSalary("id") + // this join will result in duplicate "id" columns + val joinedDf = person.join(newSalary, + person("id") === newSalary("id"), "inner") + // remove only the "id" column that was associated with newSalary + val df = joinedDf.drop(col) + checkAnswer( + df, + joinedDf.collect().map { + case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) => + Row(id, name, age, salary) + }.toSeq) + assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary")) + assert(df("id") == person("id")) + } + test("withColumnRenamed") { + val df = testData.toDF().withColumn("newCol", col("key") + 1) + .withColumnRenamed("value", "valueRenamed") checkAnswer( - testData2.agg(count('a), sumDistinct('a)), // non-partial - Row(6, 6.0)) + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) + } + + test("describe") { + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", "4", "4"), + Row("mean", "33.0", "178.0"), + Row("stddev", "16.583123951777", "10.0"), + Row("min", "16", "164"), + Row("max", "60", "192")) + + val emptyDescribeResult = Seq( + Row("count", "0", "0"), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + // All aggregate value should have been cast to string + describeTwoCols.collect().foreach { row => + assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass) + assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) + } + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) } - test("null count") { + test("apply on query results (SPARK-5462)") { + val df = testData.sqlContext.sql("select key from testData") + checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) + } + + test("inputFiles") { + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") + + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(parquetDF.inputFiles.nonEmpty) + + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + assert(unioned === allFiles) + } + } + + ignore("show") { + // This test case is intended ignored, but to make sure it compiles correctly + testData.select($"*").show() + testData.select($"*").show(1000) + } + + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + + test("showString(negative)") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |only showing top 0 rows + |""".stripMargin + assert(testData.select($"*").showString(-1) === expectedAnswer) + } + + test("showString(0)") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |only showing top 0 rows + |""".stripMargin + assert(testData.select($"*").showString(0) === expectedAnswer) + } + + test("showString: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = """+---------+---------+ + || _1| _2| + |+---------+---------+ + ||[1, 2, 3]|[1, 2, 3]| + ||[2, 3, 4]|[2, 3, 4]| + |+---------+---------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + + test("showString: minimum column width") { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + val expectedAnswer = """+---+---+ + || _1| _2| + |+---+---+ + || 1| 1| + || 2| 2| + |+---+---+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + + test("SPARK-7319 showString") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""".stripMargin + assert(testData.select($"*").showString(1) === expectedAnswer) + } + + test("SPARK-7327 show with empty dataFrame") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |""".stripMargin + assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer) + } + + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = sqlContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } + + test("SPARK-6899: type should match when using codegen") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + } + } + + test("SPARK-7133: Implement struct, array, and map field accessor") { + assert(complexData.filter(complexData("a")(0) === 2).count() == 1) + assert(complexData.filter(complexData("m")("1") === 1).count() == 1) + assert(complexData.filter(complexData("s")("key") === 1).count() == 1) + assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) + assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1) + } + + test("SPARK-7551: support backticks for DataFrame attribute resolution") { + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( - testData3.groupBy('a).agg('a, count('b)), - Seq(Row(1,0), Row(2, 1)) + df.select(df("`a.b`.c.`d..e`.`f`")), + Row(1) ) + val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( - testData3.groupBy('a).agg('a, count('a + 'b)), - Seq(Row(1,0), Row(2, 1)) + df2.select(df2("`a b`.c.d e.f")), + Row(1) ) + def checkError(testFun: => Unit): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + testFun + } + assert(e.getMessage.contains("syntax error in attribute name:")) + } + checkError(df("`abc.`c`")) + checkError(df("`abc`..d")) + checkError(df("`a`.b.")) + checkError(df("`a.b`.c.`d")) + } + + test("SPARK-7324 dropDuplicates") { + val testData = sqlContext.sparkContext.parallelize( + (2, 1, 2) :: (1, 1, 1) :: + (1, 2, 1) :: (2, 1, 2) :: + (2, 2, 2) :: (2, 2, 1) :: + (2, 1, 1) :: (1, 1, 2) :: + (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") + checkAnswer( - testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)), - Row(2, 1, 2, 2, 1) - ) + testData.dropDuplicates(), + Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1), + Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1), + Row(1, 1, 2), Row(1, 2, 2))) checkAnswer( - testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial - Row(1, 1, 2) - ) - } + testData.dropDuplicates(Seq("key", "value1")), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) - test("zero count") { - assert(emptyTableData.count() === 0) + checkAnswer( + testData.dropDuplicates(Seq("value1", "value2")), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) checkAnswer( - emptyTableData.agg(count('a), sumDistinct('a)), // non-partial - Row(0, null)) - } + testData.dropDuplicates(Seq("key")), + Seq(Row(2, 1, 2), Row(1, 1, 1))) - test("zero sum") { checkAnswer( - emptyTableData.agg(sum('a)), - Row(null)) - } + testData.dropDuplicates(Seq("value1")), + Seq(Row(2, 1, 2), Row(1, 2, 1))) - test("zero sum distinct") { checkAnswer( - emptyTableData.agg(sumDistinct('a)), - Row(null)) + testData.dropDuplicates(Seq("value2")), + Seq(Row(2, 1, 2), Row(1, 1, 1))) } - test("except") { - checkAnswer( - lowerCaseData.except(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.except(lowerCaseData), Nil) - checkAnswer(upperCaseData.except(upperCaseData), Nil) + test("SPARK-7150 range api") { + // numSlice is greater than length + val res1 = sqlContext.range(0, 10, 1, 15).select("id") + assert(res1.count == 10) + assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res2 = sqlContext.range(3, 15, 3, 2).select("id") + assert(res2.count == 4) + assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) + + val res3 = sqlContext.range(1, -2).select("id") + assert(res3.count == 0) + + // start is positive, end is negative, step is negative + val res4 = sqlContext.range(1, -2, -2, 6).select("id") + assert(res4.count == 2) + assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) + + // start, end, step are negative + val res5 = sqlContext.range(-3, -8, -2, 1).select("id") + assert(res5.count == 3) + assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) + + // start, end are negative, step is positive + val res6 = sqlContext.range(-8, -4, 2, 1).select("id") + assert(res6.count == 2) + assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) + + val res7 = sqlContext.range(-10, -9, -20, 1).select("id") + assert(res7.count == 0) + + val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + assert(res8.count == 3) + assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) + + val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + assert(res9.count == 2) + assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + + // only end provided as argument + val res10 = sqlContext.range(10).select("id") + assert(res10.count == 10) + assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res11 = sqlContext.range(-1).select("id") + assert(res11.count == 0) } - test("intersect") { - checkAnswer( - lowerCaseData.intersect(lowerCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") } - test("udf") { - val foo = (a: Int, b: String) => a.toString + b + test("SPARK-8797: sort by float column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } - checkAnswer( - // SELECT *, foo(key, value) FROM testData - testData.select($"*", callUDF(foo, 'key, 'value)).limit(3), - Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil - ) + test("SPARK-8797: sort by double column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() } - test("apply on query results (SPARK-5462)") { - val df = testData.sqlContext.sql("select key from testData") - checkAnswer(df("key"), testData.select('key).collect().toSeq) + test("NaN is greater than all other non-NaN numeric values") { + val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Double.isNaN(maxDouble.getDouble(0))) + val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Float.isNaN(maxFloat.getFloat(0))) + } + + test("SPARK-8072: Better Exception for Duplicate Columns") { + // only one duplicate column present + val e = intercept[org.apache.spark.sql.AnalysisException] { + Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") + } + assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("parquet")) + assert(e.getMessage.contains("column1")) + assert(!e.getMessage.contains("column2")) + + // multiple duplicate columns present + val f = intercept[org.apache.spark.sql.AnalysisException] { + Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") + } + assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("JSON")) + assert(f.getMessage.contains("column1")) + assert(f.getMessage.contains("column3")) + assert(!f.getMessage.contains("column2")) + } + + test("SPARK-6941: Better error message for inserting into RDD-based Table") { + withTempDir { dir => + + val tempParquetFile = new File(dir, "tmp_parquet") + val tempJsonFile = new File(dir, "tmp_json") + + val df = Seq(Tuple1(1)).toDF() + val insertion = Seq(Tuple1(2)).toDF("col") + + // pass case: parquet table (HadoopFsRelation) + df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) + val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) + pdf.registerTempTable("parquet_base") + insertion.write.insertInto("parquet_base") + + // pass case: json table (InsertableRelation) + df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) + val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) + jdf.registerTempTable("json_base") + insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") + + // error cases: insert into an RDD + df.registerTempTable("rdd_base") + val e1 = intercept[AnalysisException] { + insertion.write.insertInto("rdd_base") + } + assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + + // error case: insert into a logical plan that is not a LeafNode + val indirectDS = pdf.select("_1").filter($"_1" > 5) + indirectDS.registerTempTable("indirect_ds") + val e2 = intercept[AnalysisException] { + insertion.write.insertInto("indirect_ds") + } + assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + + // error case: insert into an OneRowRelation + new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") + val e3 = intercept[AnalysisException] { + insertion.write.insertInto("one_row") + } + assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + } + } + + test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + test("SPARK-8609: local DataFrame with random columns should return same value after sort") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF() + checkAnswer(df.sort(rand(33)), df.sort(rand(33))) } + test("SPARK-9083: sort with non-deterministic expressions") { + import org.apache.spark.util.random.XORShiftRandom + + val seed = 33 + val df = (1 to 100).map(Tuple1.apply).toDF("i") + val random = new XORShiftRandom(seed) + val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) + val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) + assert(expected === actual) + } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + } + + test("SPARK-9950: correctly analyze grouping/aggregating on struct fields") { + val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b") + checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil) + } + + test("SPARK-10093: Avoid transformations on executors") { + val df = Seq((1, 1)).toDF("a", "b") + df.where($"a" === 1) + .select($"a", $"b", struct($"b")) + .orderBy("a") + .select(struct($"b")) + .collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala new file mode 100644 index 000000000000..77907e91363e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +/** + * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode. + * + * This is here for now so I can make sure Tungsten project is tested without refactoring existing + * end-to-end test infra. In the long run this should just go away. + */ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("test simple types") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } + } + + test("test struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } + } + + test("test nested struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala new file mode 100644 index 000000000000..9080c53c491a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Timestamp, Date} +import java.text.SimpleDateFormat + +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("function current_date") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + // This is a bad test. SPARK-9196 will fix it and re-enable it. + ignore("function current_timestamp") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) + + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } + + test("date format") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), + Row("2015", "2015", "2013")) + + checkAnswer( + df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), + Row("2015", "2015", "2013")) + } + + test("year") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(year($"a"), year($"b"), year($"c")), + Row(2015, 2015, 2013)) + + checkAnswer( + df.selectExpr("year(a)", "year(b)", "year(c)"), + Row(2015, 2015, 2013)) + } + + test("quarter") { + val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(quarter($"a"), quarter($"b"), quarter($"c")), + Row(2, 2, 4)) + + checkAnswer( + df.selectExpr("quarter(a)", "quarter(b)", "quarter(c)"), + Row(2, 2, 4)) + } + + test("month") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(month($"a"), month($"b"), month($"c")), + Row(4, 4, 4)) + + checkAnswer( + df.selectExpr("month(a)", "month(b)", "month(c)"), + Row(4, 4, 4)) + } + + test("dayofmonth") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(dayofmonth($"a"), dayofmonth($"b"), dayofmonth($"c")), + Row(8, 8, 8)) + + checkAnswer( + df.selectExpr("day(a)", "day(b)", "dayofmonth(c)"), + Row(8, 8, 8)) + } + + test("dayofyear") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(dayofyear($"a"), dayofyear($"b"), dayofyear($"c")), + Row(98, 98, 98)) + + checkAnswer( + df.selectExpr("dayofyear(a)", "dayofyear(b)", "dayofyear(c)"), + Row(98, 98, 98)) + } + + test("hour") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(hour($"a"), hour($"b"), hour($"c")), + Row(0, 13, 13)) + + checkAnswer( + df.selectExpr("hour(a)", "hour(b)", "hour(c)"), + Row(0, 13, 13)) + } + + test("minute") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(minute($"a"), minute($"b"), minute($"c")), + Row(0, 10, 10)) + + checkAnswer( + df.selectExpr("minute(a)", "minute(b)", "minute(c)"), + Row(0, 10, 10)) + } + + test("second") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(second($"a"), second($"b"), second($"c")), + Row(0, 15, 15)) + + checkAnswer( + df.selectExpr("second(a)", "second(b)", "second(c)"), + Row(0, 15, 15)) + } + + test("weekofyear") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(weekofyear($"a"), weekofyear($"b"), weekofyear($"c")), + Row(15, 15, 15)) + + checkAnswer( + df.selectExpr("weekofyear(a)", "weekofyear(b)", "weekofyear(c)"), + Row(15, 15, 15)) + } + + test("function date_add") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_add(col("d"), 1)), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + checkAnswer( + df.select(date_add(col("t"), 3)), + Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) + checkAnswer( + df.select(date_add(col("s"), 5)), + Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) + checkAnswer( + df.select(date_add(col("ss"), 7)), + Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) + + checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_ADD(d, 1)"""), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + } + + test("function date_sub") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_sub(col("d"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("t"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("s"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("ss"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(lit(null), 1)).limit(1), Row(null)) + + checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_SUB(d, 1)"""), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + } + + test("time_add") { + val t1 = Timestamp.valueOf("2015-07-31 23:59:59") + val t2 = Timestamp.valueOf("2015-12-31 00:00:00") + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-12-31") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d + $i"), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29")))) + checkAnswer( + df.selectExpr(s"t + $i"), + Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")), + Row(Timestamp.valueOf("2016-02-29 00:00:02")))) + } + + test("time_sub") { + val t1 = Timestamp.valueOf("2015-10-01 00:00:01") + val t2 = Timestamp.valueOf("2016-02-29 00:00:02") + val d1 = Date.valueOf("2015-09-30") + val d2 = Date.valueOf("2016-02-29") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d - $i"), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + checkAnswer( + df.selectExpr(s"t - $i"), + Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), + Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + } + + test("function add_months") { + val d1 = Date.valueOf("2015-08-31") + val d2 = Date.valueOf("2015-02-28") + val df = Seq((1, d1), (2, d2)).toDF("n", "d") + checkAnswer( + df.select(add_months(col("d"), 1)), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + checkAnswer( + df.selectExpr("add_months(d, -1)"), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + } + + test("function months_between") { + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-02-16") + val t1 = Timestamp.valueOf("2014-09-30 23:30:00") + val t2 = Timestamp.valueOf("2015-09-16 12:00:00") + val s1 = "2014-09-15 11:30:00" + val s2 = "2015-10-01 00:00:00" + val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") + checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + } + + test("function last_day") { + val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") + val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") + checkAnswer( + df1.select(last_day(col("d"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + checkAnswer( + df2.select(last_day(col("t"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + } + + test("function next_day") { + val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d") + val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t") + checkAnswer( + df1.select(next_day(col("d"), "MONDAY")), + Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27")))) + checkAnswer( + df2.select(next_day(col("t"), "th")), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) + } + + test("function to_date") { + val d1 = Date.valueOf("2015-07-22") + val d2 = Date.valueOf("2015-07-01") + val t1 = Timestamp.valueOf("2015-07-22 10:00:00") + val t2 = Timestamp.valueOf("2014-12-31 23:59:59") + val s1 = "2015-07-22 10:00:00" + val s2 = "2014-12-31" + val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s") + + checkAnswer( + df.select(to_date(col("t"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.select(to_date(col("d"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.select(to_date(col("s"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + + checkAnswer( + df.selectExpr("to_date(t)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.selectExpr("to_date(d)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.selectExpr("to_date(s)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + } + + test("function trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:00:00")), + (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select(trunc(col("t"), "YY")), + Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) + + checkAnswer( + df.selectExpr("trunc(t, 'Month')"), + Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) + } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + + test("unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + + test("datediff") { + val df = Seq( + (Date.valueOf("2015-07-24"), Timestamp.valueOf("2015-07-24 01:00:00"), + "2015-07-23", "2015-07-23 03:00:00"), + (Date.valueOf("2015-07-25"), Timestamp.valueOf("2015-07-25 02:00:00"), + "2015-07-24", "2015-07-24 04:00:00") + ).toDF("a", "b", "c", "d") + checkAnswer(df.select(datediff(col("a"), col("b"))), Seq(Row(0), Row(0))) + checkAnswer(df.select(datediff(col("a"), col("c"))), Seq(Row(1), Row(1))) + checkAnswer(df.select(datediff(col("d"), col("b"))), Seq(Row(-1), Row(-1))) + checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) + } + + test("from_utc_timestamp") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") + ).toDF("a", "b") + checkAnswer( + df.select(from_utc_timestamp(col("a"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-23 17:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-23 17:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") + ).toDF("a", "b") + checkAnswer( + df.select(to_utc_timestamp(col("a"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-25 07:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-25 07:00:00")))) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f0c939dbb195..f5c5046a8ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,39 +17,39 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext + +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ -class JoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. - TestData + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed - val planned = planner.HashJoin(join) + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val rdd = sql(sqlString) - val physical = rdd.queryExecution.sparkPlan + val df = sql(sqlString) + val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -59,7 +59,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -74,15 +74,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -90,26 +91,83 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } + + test("SortMergeJoin shouldn't work on unsortable columns") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + Seq( + ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } } test("broadcasted hash join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() sql("CACHE TABLE testData") + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } + } + sql("UNCACHE TABLE testData") + } - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - + test("broadcasted hash outer join operator selection") { + ctx.cacheManager.clearCache() + sql("CACHE TABLE testData") + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed - val planned = planner.HashJoin(join) + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -140,10 +198,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where($"a" === 1).as("y") checkAnswer( x.join(y).where($"x.a" === $"y.a"), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } @@ -384,39 +442,37 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() sql("CACHE TABLE testData") - val tmp = conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) sql("UNCACHE TABLE testData") } test("left semi join") { - val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") - checkAnswer(rdd, + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(df, Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala new file mode 100644 index 000000000000..045fea82e4c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("function get_json_object") { + val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") + checkAnswer( + df.selectExpr("get_json_object(a, '$.name')", "get_json_object(a, '$.age')"), + Row("alice", "5")) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala new file mode 100644 index 000000000000..babf8835d254 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -0,0 +1,85 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} + +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + + private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") + + before { + df.registerTempTable("ListTablesSuiteTable") + } + + after { + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + } + + test("get all tables") { + checkAnswer( + ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + checkAnswer( + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("getting all Tables with a database name has no impact on returned table names") { + checkAnswer( + ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + checkAnswer( + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("query the returned DataFrame of tables") { + val expectedSchema = StructType( + StructField("tableName", StringType, false) :: + StructField("isTemporary", BooleanType, false) :: Nil) + + Seq(ctx.tables(), sql("SHOW TABLes")).foreach { + case tableDF => + assert(expectedSchema === tableDF.schema) + + tableDF.registerTempTable("tables") + checkAnswer( + sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + Row(true, "ListTablesSuiteTable") + ) + checkAnswer( + ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + Row("tables", true)) + ctx.dropTempTable("tables") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 000000000000..30289c3c1d09 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.SharedSQLContext + +private object MathExpressionsTestData { + case class DoubleData(a: java.lang.Double, b: java.lang.Double) + case class NullDoubles(a: java.lang.Double) +} + +class MathExpressionsSuite extends QueryTest with SharedSQLContext { + import MathExpressionsTestData._ + import testImplicits._ + + private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() + + private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() + + private lazy val nullDoubles = + Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() + + private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = + { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + private def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDegrees") { + testOneToOneMathFunction(toDegrees, math.toDegrees) + checkAnswer( + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) + ) + } + + test("toRadians") { + testOneToOneMathFunction(toRadians, math.toRadians) + checkAnswer( + sql("SELECT radians(0), radians(1), radians(1.5)"), + Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) + ) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil and ceiling") { + testOneToOneMathFunction(ceil, math.ceil) + checkAnswer( + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + Row(0.0, 1.0, 2.0)) + } + + test("conv") { + val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") + checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) + checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) + checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) + checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) + checkAnswer(df.selectExpr("""conv("100", 2, 10)"""), Row("4")) + checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16")) + checkAnswer( + df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // for overflow + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("factorial") { + val df = (0 to 5).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(factorial('a)), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + checkAnswer( + df.selectExpr("factorial(a)"), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) + ) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum / sign") { + testOneToOneMathFunction[Double](signum, math.signum) + + checkAnswer( + sql("SELECT sign(10), signum(-11)"), + Row(1, -1)) + } + + test("pow / power") { + testTwoToOneMathFunction(pow, pow, math.pow) + + checkAnswer( + sql("SELECT pow(1, 2), power(2, 1)"), + Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) + ) + } + + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) + } + + test("unhex") { + val data = Seq(("1C", "737472696E67")).toDF("a", "b") + checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) + checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log / ln") { + testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) + checkAnswer( + sql("SELECT ln(0), ln(1), ln(1.5)"), + Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) + ) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + + test("shift left") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), + shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)", + "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + } + + test("shift right") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), + shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", + "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + } + + test("shift right unsigned") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1), + shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)", + "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + } + + test("binary log") { + val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") + checkAnswer( + df.select(org.apache.spark.sql.functions.log("a"), + org.apache.spark.sql.functions.log(2.0, "a"), + org.apache.spark.sql.functions.log("b")), + Row(math.log(123), math.log(123) / math.log(2), null)) + + checkAnswer( + df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), + Row(math.log(123), math.log(123) / math.log(2), null)) + } + + test("abs") { + val input = + Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) + checkAnswer( + input.toDF("key", "value").select(abs($"key").alias("a")).sort("a"), + input.map(pair => Row(pair._2))) + + checkAnswer( + input.toDF("key", "value").selectExpr("abs(key) a").sort("a"), + input.map(pair => Row(pair._2))) + } + + test("log2") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.select(log2("b") + log2("a")), + Row(1)) + + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + } + + test("sqrt") { + val df = Seq((1, 4)).toDF("a", "b") + checkAnswer( + df.select(sqrt("a"), sqrt("b")), + Row(1.0, 2.0)) + + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) + } + + test("negative") { + checkAnswer( + sql("SELECT negative(1), negative(0), negative(-1)"), + Row(-1, 0, 1)) + } + + test("positive") { + val df = Seq((1, -1, "abc")).toDF("a", "b", "c") + checkAnswer(df.selectExpr("positive(a)"), Row(1)) + checkAnswer(df.selectExpr("positive(b)"), Row(-1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a7f2faa3ecf7..3649c2a97b5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,103 +17,137 @@ package org.apache.spark.sql +import java.util.{Locale, TimeZone} + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[DataFrame]] to be executed + * @param df the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString + def checkExistence(df: DataFrame, exists: Boolean, keywords: String*) { + val outputs = df.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") + assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)") } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") + assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)") } } } /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + + /** + * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + */ + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } - if (!isSorted) converted.sortBy(_.toString) else converted + if (!isSorted) converted.sortBy(_.toString()) else converted } - val sparkAnswer = try rdd.collect().toSeq catch { + val sparkAnswer = try df.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: - |${rdd.queryExecution} + |${df.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: - |${rdd.logicalPlan} - |== Analyzed Plan == - |${rdd.queryExecution.analyzed} - |== Physical Plan == - |${rdd.queryExecution.executedPlan} + |${df.queryExecution} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) + return None } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(df, expectedAnswer.asScala) match { + case Some(errorMessage) => errorMessage + case None => null } } - - /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. - */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468da..795d4e983f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,32 +17,36 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends FunSuite { +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) - expected.update(0, 2147483647) - expected.update(1, "this is a string") - expected.update(2, false) - expected.update(3, null) + expected.setInt(0, 2147483647) + expected.update(1, UTF8String.fromString("this is a string")) + expected.setBoolean(2, false) + expected.setNullAt(3) + val actual1 = Row(2147483647, "this is a string", false, null) - assert(expected.size === actual1.size) + assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) - assert(expected(3) === actual1(3)) + assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) - assert(expected.size === actual2.size) + assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) - assert(expected(3) === actual2(3)) + assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { @@ -50,4 +54,35 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } + + test("get values by field name on Row created via .toDF") { + val row = Seq((1, Seq(1))).toDF("a", "b").first() + assert(row.getAs[Int]("a") === 1) + assert(row.getAs[Seq[Int]]("b") === Seq(1)) + + intercept[IllegalArgumentException]{ + row.getAs[Int]("c") + } + } + + test("float NaN == NaN") { + val r1 = Row(Float.NaN) + val r2 = Row(Float.NaN) + assert(r1 === r2) + } + + test("double NaN == NaN") { + val r1 = Row(Double.NaN) + val r2 = Row(Double.NaN) + assert(r1 === r2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala new file mode 100644 index 000000000000..2e33777f14ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf._ + +class SQLConfEntrySuite extends SparkFunSuite { + + val conf = new SQLConf + + test("intConf") { + val key = "spark.sql.SQLConfEntrySuite.int" + val confEntry = SQLConfEntry.intConf(key) + assert(conf.getConf(confEntry, 5) === 5) + + conf.setConf(confEntry, 10) + assert(conf.getConf(confEntry, 5) === 10) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5) === 20) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be int, but was abc") + } + + test("longConf") { + val key = "spark.sql.SQLConfEntrySuite.long" + val confEntry = SQLConfEntry.longConf(key) + assert(conf.getConf(confEntry, 5L) === 5L) + + conf.setConf(confEntry, 10L) + assert(conf.getConf(confEntry, 5L) === 10L) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5L) === 20L) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be long, but was abc") + } + + test("booleanConf") { + val key = "spark.sql.SQLConfEntrySuite.boolean" + val confEntry = SQLConfEntry.booleanConf(key) + assert(conf.getConf(confEntry, false) === false) + + conf.setConf(confEntry, true) + assert(conf.getConf(confEntry, false) === true) + + conf.setConfString(key, "true") + assert(conf.getConfString(key, "false") === "true") + assert(conf.getConfString(key) === "true") + assert(conf.getConf(confEntry, false) === true) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be boolean, but was abc") + } + + test("doubleConf") { + val key = "spark.sql.SQLConfEntrySuite.double" + val confEntry = SQLConfEntry.doubleConf(key) + assert(conf.getConf(confEntry, 5.0) === 5.0) + + conf.setConf(confEntry, 10.0) + assert(conf.getConf(confEntry, 5.0) === 10.0) + + conf.setConfString(key, "20.0") + assert(conf.getConfString(key, "5.0") === "20.0") + assert(conf.getConfString(key) === "20.0") + assert(conf.getConf(confEntry, 5.0) === 20.0) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be double, but was abc") + } + + test("stringConf") { + val key = "spark.sql.SQLConfEntrySuite.string" + val confEntry = SQLConfEntry.stringConf(key) + assert(conf.getConf(confEntry, "abc") === "abc") + + conf.setConf(confEntry, "abcd") + assert(conf.getConf(confEntry, "abc") === "abcd") + + conf.setConfString(key, "abcde") + assert(conf.getConfString(key, "abc") === "abcde") + assert(conf.getConfString(key) === "abcde") + assert(conf.getConf(confEntry, "abc") === "abcde") + } + + test("enumConf") { + val key = "spark.sql.SQLConfEntrySuite.enum" + val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a")) + assert(conf.getConf(confEntry) === "a") + + conf.setConf(confEntry, "b") + assert(conf.getConf(confEntry) === "b") + + conf.setConfString(key, "c") + assert(conf.getConfString(key, "a") === "c") + assert(conf.getConfString(key) === "c") + assert(conf.getConf(confEntry) === "c") + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "d") + } + assert(e.getMessage === s"The value of $key should be one of a, b, c, but was d") + } + + test("stringSeqConf") { + val key = "spark.sql.SQLConfEntrySuite.stringSeq" + val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq", + defaultValue = Some(Nil)) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c")) + + conf.setConf(confEntry, Seq("a", "b", "c", "d")) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d")) + + conf.setConfString(key, "a,b,c,d,e") + assert(conf.getConfString(key, "a,b,c") === "a,b,c,d,e") + assert(conf.getConfString(key) === "a,b,c,d,e") + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bf73d0c7074a..7699adadd9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,68 +17,71 @@ package org.apache.spark.sql -import org.scalatest.FunSuiteLike +import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.test._ -/* Implicits */ -import TestSQLContext._ - -class SQLConfSuite extends QueryTest with FunSuiteLike { - - val testKey = "test.key.0" - val testVal = "test.val.0" +class SQLConfSuite extends QueryTest with SharedSQLContext { + private val testKey = "test.key.0" + private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(TestSQLContext.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") == "true") + val newContext = new SQLContext(ctx.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - conf.clear() - assert(getAllConfs.size === 0) + ctx.conf.clear() + assert(ctx.getAllConfs.size === 0) - setConf(testKey, testVal) - assert(getConf(testKey) == testVal) - assert(getConf(testKey, testVal + "_") == testVal) - assert(getAllConfs.contains(testKey)) + ctx.setConf(testKey, testVal) + assert(ctx.getConf(testKey) === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(TestSQLContext.getConf(testKey) == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getAllConfs.contains(testKey)) + assert(ctx.getConf(testKey) == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) - conf.clear() + ctx.conf.clear() } test("parse SQL set commands") { - conf.clear() + ctx.conf.clear() sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(getConf("some.property", "0") == "20") + assert(ctx.getConf("some.property", "0") === "20") sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(getConf(key, "0") == vs) + assert(ctx.getConf(key, "0") === vs) sql(s"set $key=") - assert(getConf(key, "0") == "") + assert(ctx.getConf(key, "0") === "") - conf.clear() + ctx.conf.clear() } test("deprecated property") { - conf.clear() + ctx.conf.clear() sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + assert(ctx.conf.numShufflePartitions === 10) + } + + test("invalid conf value") { + ctx.conf.clear() + val e = intercept[IllegalArgumentException] { + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + } + assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala new file mode 100644 index 000000000000..007be1295077 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext + +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { + + override def afterAll(): Unit = { + try { + SQLContext.setLastInstantiatedContext(ctx) + } finally { + super.afterAll() + } + } + + test("getOrCreate instantiates SQLContext") { + SQLContext.clearLastInstantiatedContext() + val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + assert(sqlContext != null, "SQLContext.getOrCreate returned null") + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") + } + + test("getOrCreate gets last explicitly instantiated SQLContext") { + SQLContext.clearLastInstantiatedContext() + val sqlContext = new SQLContext(ctx.sparkContext) + assert(SQLContext.getOrCreate(ctx.sparkContext) != null, + "SQLContext.getOrCreate after explicitly created SQLContext returned null") + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e18ba287e868..9e172b2c264c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,43 +17,187 @@ package org.apache.spark.sql -import java.util.TimeZone +import java.math.MathContext +import java.sql.Timestamp + +import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ -import org.scalatest.BeforeAndAfterAll +/** A SQL Dialect for testing purpose, and it can not be nested type */ +class MyDialect extends DefaultParserDialect -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ -/* Implicits */ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ + setupTestData() + test("having clause") { + Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") + checkAnswer( + sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), + Row("one", 6) :: Row("three", 3) :: Nil) + } -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { - // Make sure the tables are loaded. - TestData + test("SPARK-8010: promote numeric to string") { + val df = Seq((1, 1)).toDF("key", "value") + df.registerTempTable("src") + val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") + val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") - var origZone: TimeZone = _ - override protected def beforeAll() { - origZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + checkAnswer(queryCaseWhen, Row("1.0") :: Nil) + checkAnswer(queryCoalesce, Row("1") :: Nil) } - override protected def afterAll() { - TimeZone.setDefault(origZone) + test("show functions") { + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + } + + test("describe functions") { + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql');", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + } + + test("SPARK-6743: no columns from cache") { + Seq( + (83, 0, 38), + (26, 0, 79), + (43, 81, 24) + ).toDF("a", "b", "c").registerTempTable("cachedData") + + sqlContext.cacheTable("cachedData") + checkAnswer( + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + Row(0) :: Row(81) :: Nil) + } + + test("self join with aliases") { + Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + + checkAnswer( + sql( + """ + |SELECT x.str, COUNT(*) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("support table.star") { + checkAnswer( + sql( + """ + |SELECT r.* + |FROM testData l join testData2 r on (l.key = r.a) + """.stripMargin), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } + + test("self join with alias in agg") { + Seq(1, 2, 3) + .map(i => (i, i.toString)) + .toDF("int", "str") + .groupBy("str") + .agg($"str", count("str").as("strCount")) + .registerTempTable("df") + + checkAnswer( + sql( + """ + |SELECT x.str, SUM(x.strCount) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("SPARK-8668 expr function") { + checkAnswer(Seq((1, "Bobby G.")) + .toDF("id", "name") + .select(expr("length(name)"), expr("abs(id)")), Row(8, 1)) + + checkAnswer(Seq((1, "building burrito tunnels"), (1, "major projects")) + .toDF("id", "saying") + .groupBy(expr("length(saying)")) + .count(), Row(24, 1) :: Row(14, 1) :: Nil) + } + + test("SQL Dialect Switching to a new SQL parser") { + val newContext = new SQLContext(sqlContext.sparkContext) + newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) + assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) + } + + test("SQL Dialect Switch to an invalid parser with alias") { + val newContext = new SQLContext(sqlContext.sparkContext) + newContext.sql("SET spark.sql.dialect=MyTestClass") + intercept[DialectException] { + newContext.sql("SELECT 1") + } + // test if the dialect set back to DefaultSQLDialect + assert(newContext.getSQLDialect().getClass === classOf[DefaultParserDialect]) } test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) + Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } + test("SPARK-7158 collect and take return different results") { + import java.util.UUID + + val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") + // we except the id is materialized once + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) + + val dfWithId = df.withColumn("id", idUDF()) + // Make a new DataFrame (actually the same reference to the old one) + val cached = dfWithId.cache() + // Trigger the cache + val d0 = dfWithId.collect() + val d1 = cached.collect() + val d2 = cached.collect() + + // Since the ID is only materialized once, then all of the records + // should come from the cache, not by re-computing. Otherwise, the ID + // will be different + assert(d0.map(_(0)) === d2.map(_(0))) + assert(d0.map(_(1)) === d2.map(_(1))) + + assert(d1.map(_(0)) === d2.map(_(0))) + assert(d1.map(_(1)) === d2.map(_(1))) + } + test("grouping on nested fields") { - jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.parallelize( + """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -69,23 +213,147 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(1, 1) :: Nil) } - test("SPARK-3176 Added Parser of SQL ABS()") { - checkAnswer( - sql("SELECT ABS(-1.3)"), - Row(1.3)) - checkAnswer( - sql("SELECT ABS(0.0)"), - Row(0.0)) + test("SPARK-6201 IN type conversion") { + sqlContext.read.json( + sqlContext.sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + .registerTempTable("d") + checkAnswer( - sql("SELECT ABS(2.5)"), - Row(2.5)) + sql("select * from d where d.a in (1,2)"), + Seq(Row("1"), Row("2"))) + } + + test("SPARK-8828 sum should return null if all input values are null") { + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + } + + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: aggregate.TungstenAggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) } test("aggregation with codegen") { - val originalValue = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") - sql("SELECT key FROM testData GROUP BY key").collect() - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + // Prepare a table that we can group some rows. + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) + .registerTempTable("testData3x") + + try { + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) + } finally { + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) + } } test("Add Parser of SQL COALESCE()") { @@ -94,7 +362,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(1)) checkAnswer( sql("SELECT COALESCE(null, 1, 1.5)"), - Row(1.toDouble)) + Row(BigDecimal(1))) checkAnswer( sql("SELECT COALESCE(null, null, null)"), Row(null)) @@ -142,27 +410,36 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3173 Timestamp support in the parser") { + (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") + checkAnswer(sql( - "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.001'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' - AND time>'1970-01-01 00:00:00.001'"""), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) + "SELECT time FROM timestamps WHERE '1969-12-31 16:00:00.001'=time"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + """SELECT time FROM timestamps WHERE time<'1969-12-31 16:00:00.003' + AND time>'1969-12-31 16:00:00.001'"""), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) + + checkAnswer(sql( + """ + |SELECT time FROM timestamps + |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') + """.stripMargin), + Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -178,7 +455,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3,1), Row(3,2)) + Seq(Row(3, 1), Row(3, 2)) + ) + } + + test("left semi greater than predicate and equal operator") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + Seq(Row(3, 1), Row(3, 2)) + ) + + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) ) } @@ -195,16 +484,33 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1,3), Row(2,3), Row(3,3))) + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1,2), Row(2,2), Row(3,2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + def literalInAggTest(): Unit = { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + literalInAggTest() + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + literalInAggTest() + } } test("aggregates with nulls") { @@ -226,22 +532,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1")) } - def sortTest() = { + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), @@ -269,17 +575,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { + sortTest() + } } test("external sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { + sortTest() + } + } + + test("SPARK-6927 sorting with codegen on") { + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true") { + sortTest() + } + } + + test("SPARK-6927 external sorting with codegen on") { + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true") { + sortTest() + } } test("limit") { @@ -296,9 +614,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("CTE feature") { + checkAnswer( + sql("with q1 as (select * from testData limit 10) select * from q1"), + testData.take(10).toSeq) + + checkAnswer( + sql(""" + |with q1 as (select * from testData where key= '5'), + |q2 as (select * from testData where key = '4') + |select * from q1 union all select * from q2""".stripMargin), + Row(5, "5") :: Row(4, "4") :: Nil) + + } + + test("Allow only a single WITH clause per query") { + intercept[RuntimeException] { + sql( + "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + } + } + + test("date row") { + checkAnswer(sql( + """select cast("2015-01-28" as date) from testData limit 1"""), + Row(java.sql.Date.valueOf("2015-01-28")) + ) + } + test("from follow multiple brackets") { checkAnswer(sql( - "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), + """ + |select key from ((select * from testData limit 1) + | union all (select * from testData limit 1)) x limit 1 + """.stripMargin), Row(1) ) @@ -308,7 +657,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) checkAnswer(sql( - "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), + """ + |select key from + | (select * from testData limit 1 union all select * from testData limit 1) x + | limit 1 + """.stripMargin), Row(1) ) } @@ -322,7 +675,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq(Row(2147483645.0,1), Row(2.0,2))) + Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { @@ -355,10 +708,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1, 0), Row(2, 1))) checkAnswer( - sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), + sql( + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } + test("count of empty table") { + withTempTable("t") { + Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") + checkAnswer( + sql("select count(a) from t"), + Row(0)) + } + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), @@ -386,10 +749,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil) } test("inner join, no matches") { @@ -421,7 +784,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - ignore("cartesian product join") { + test("cartesian product join") { checkAnswer( testData3.join(testData3), Row(1, null, 1, null) :: @@ -592,7 +955,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. - intercept[TreeNodeException[_]] { + intercept[AnalysisException] { sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() } } @@ -622,7 +985,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SET commands semantics using sql()") { - conf.clear() + sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -634,27 +997,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Row(s"$testKey=$testVal"), - Row(s"${testKey + testKey}=${testVal + testVal}")) + Row(testKey, testVal), + Row(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) checkAnswer( sql(s"SET $nonexistentKey"), - Row(s"$nonexistentKey=") + Row(nonexistentKey, "") ) - conf.clear() + sqlContext.conf.clear() + } + + test("SET commands with illegal or inappropriate argument") { + sqlContext.conf.clear() + // Set negative mapred.reduce.tasks for automatically determing + // the number of reducers is not supported + intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) + intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) + intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) + sqlContext.conf.clear() } test("apply schema") { @@ -672,7 +1045,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = applySchema(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -702,7 +1075,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = applySchema(rowRDD2, schema2) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -727,7 +1100,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = applySchema(rowRDD3, schema2) + val df3 = sqlContext.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -772,7 +1145,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person.rdd, schemaWithMeta) + val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -783,11 +1156,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { - udf.register("len", (s: String) => s.length) + sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -808,10 +1182,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("throw errors for non-aggregate attributes with aggregation") { def checkAggregation(query: String, isInvalidQuery: Boolean = true) { if (isInvalidQuery) { - val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) - assert( - e.getMessage.startsWith("Expression not in GROUP BY"), - "Non-aggregate attribute(s) not detected\n") + val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) + assert(e.getMessage contains "group by") } else { // Should not throw sql(query).queryExecution.analyzed @@ -841,19 +1213,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -886,11 +1258,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) checkAnswer( - sql("SELECT -5.2"), Row(-5.2) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( @@ -970,9 +1342,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - jsonRDD(data).registerTempTable("records") - sql("SELECT `key?number1` FROM records") + val data = sqlContext.sparkContext.parallelize( + Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + sqlContext.read.json(data).registerTempTable("records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { @@ -1012,45 +1385,336 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-4322 Grouping field with struct field as sub expression") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - dropTempTable("data") + sqlContext.dropTempTable("data") - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json( + sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - dropTempTable("data") + sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) + Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } test("Supporting relational operator '<=>' in Spark SQL") { - val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil - val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) - rdd1.registerTempTable("nulldata1") - val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil - val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) - rdd2.registerTempTable("nulldata2") + val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil + val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + rdd1.toDF().registerTempTable("nulldata1") + val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil + val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) } test("Multi-column COUNT(DISTINCT ...)") { - val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.registerTempTable("distinctData") + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } + + test("SPARK-4699 case sensitivity SQL query") { + sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + rdd.toDF().registerTempTable("testTable1") + checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) + sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) + } + + test("SPARK-6145: ORDER BY test for nested fields") { + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + .registerTempTable("nestedOrder") + + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + } + + test("SPARK-6145: special cases") { + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) + } + + test("SPARK-6898: complete support for special chars in column names") { + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } + + test("SPARK-6583 order by aggregated function") { + Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2) + .toDF("a", "b").registerTempTable("orderByData") + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil) + + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + 1 + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + } + + test("SPARK-7952: fix the equality check between boolean and numeric types") { + withTempTable("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").registerTempTable("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } + } + + test("SPARK-7067: order by queries for complex ExtractValue chain") { + withTempTable("t") { + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) + } + } + + test("SPARK-8782: ORDER BY NULL") { + withTempTable("t") { + Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") + checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + } + } + + test("SPARK-8837: use keyword in column name") { + withTempTable("t") { + val df = Seq(1 -> "a").toDF("count", "sort") + checkAnswer(df.filter("count > 0"), Row(1, "a")) + df.registerTempTable("t") + checkAnswer(sql("select count, sort from t"), Row(1, "a")) + } + } + + test("SPARK-8753: add interval type") { + import org.apache.spark.unsafe.types.CalendarInterval + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds") + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) + withTempPath(f => { + // Currently we don't yet support saving out values of interval data type. + val e = intercept[AnalysisException] { + df.write.json(f.getCanonicalPath) + } + e.message.contains("Cannot save interval data type into external storage") + }) + + def checkIntervalParseError(s: String): Unit = { + val e = intercept[AnalysisException] { + sql(s) + } + e.message.contains("at least one time unit should be given for interval literal") + } + + checkIntervalParseError("select interval") + // Currently we don't yet support nanosecond + checkIntervalParseError("select interval 23 nanosecond") + } + + test("SPARK-8945: add and subtract expressions for interval type") { + import org.apache.spark.unsafe.types.CalendarInterval + import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) + + checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) + + checkAnswer(df.select(df("i") - new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) + + // unary minus + checkAnswer(df.select(-df("i")), + Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) + } + + test("aggregation with codegen updates peak execution memory") { + withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + } + + test("decimal precision with multiply/division") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + Row(null)) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + + test("SPARK-10215 Div of Decimal returns null") { + val d = Decimal(1.12321) + val df = Seq((d, 1)).toDF("a", "b") + + checkAnswer( + df.selectExpr("b * a / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a / b / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a + b"), + Seq(Row(BigDecimal(2.12321)))) + checkAnswer( + df.selectExpr("b * a - b"), + Seq(Row(BigDecimal(0.12321)))) + checkAnswer( + df.selectExpr("b * a * b"), + Seq(Row(d.toBigDecimal))) + } + + test("precision smaller than scale") { + checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) + checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + + test("external sorting updates peak execution memory") { + withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + sortTest() + } + } + } + + test("SPARK-9511: error with table starting with number") { + withTempTable("1one") { + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + .toDF("num", "str") + .registerTempTable("1one") + checkAnswer(sql("select count(num) from 1one"), Row(10)) + } + } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } + + test("SPARK-10130 type coercion for IF should have children resolved first") { + val df = Seq((1, 1), (-1, 1)).toDF("key", "value") + df.registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index a015884bae28..295f02f9a7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,10 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -74,40 +72,42 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends FunSuite { +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectData") + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) + Seq(data).toDF().registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), + new Timestamp(12345), Seq(1, 2, 3))) } test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectNullData") + Seq(data).toDF().registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectNullData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectOptionalData") + Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { - val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + val result = sql("SELECT data FROM reflectBinary") + .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -123,20 +123,19 @@ class ScalaReflectionRelationSuite extends FunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectComplexData") + Seq(data).toDF().registerTempTable("reflectComplexData") assert(sql("SELECT * FROM reflectComplexData").collect().head === - new GenericRow(Array[Any]( + Row( Seq(1, 2, 3), Seq(1, 2, null), Map(1 -> 10L, 2 -> 20L), Map(1 -> 10L, 2 -> 20L, 3 -> null), - new GenericRow(Array[Any]( + Row( Seq(10, 20, 30), Seq(10, 20, null), Map(10 -> 100L, 20 -> 200L), Map(10 -> 100L, 20 -> 200L, 30 -> null), - new GenericRow(Array[Any](null, "abc"))))))) + Row(null, "abc")))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala new file mode 100644 index 000000000000..45d0ee4a8e74 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext + +class SerializationSuite extends SparkFunSuite with SharedSQLContext { + + test("[SPARK-5235] SQLContext should be serializable") { + val _sqlContext = new SQLContext(sqlContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala new file mode 100644 index 000000000000..b91438baea06 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.Decimal + + +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("string concat") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat($"a", $"b"), concat($"a", $"b", $"c")), + Row("ab", null)) + + checkAnswer( + df.selectExpr("concat(a, b)", "concat(a, b, c)"), + Row("ab", null)) + } + + test("string concat_ws") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat_ws("||", $"a", $"b", $"c")), + Row("a||b")) + + checkAnswer( + df.selectExpr("concat_ws('||', a, b, c)"), + Row("a||b")) + } + + test("string Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + + test("string regex_replace / regex_extract") { + val df = Seq( + ("100-200", "(\\d+)-(\\d+)", "300"), + ("100-200", "(\\d+)-(\\d+)", "400"), + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") + + checkAnswer( + df.select( + regexp_replace($"a", "(\\d+)", "num"), + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), + Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil) + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection followed by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + checkAnswer( + df.filter("isnotnull(a)").selectExpr( + "regexp_replace(a, b, c)", + "regexp_extract(a, b, 1)"), + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) + } + + test("string ascii function") { + val df = Seq(("abc", "")).toDF("a", "b") + checkAnswer( + df.select(ascii($"a"), ascii($"b")), + Row(97, 0)) + + checkAnswer( + df.selectExpr("ascii(a)", "ascii(b)"), + Row(97, 0)) + } + + test("string base64/unbase64 function") { + val bytes = Array[Byte](1, 2, 3, 4) + val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") + checkAnswer( + df.select(base64($"a"), unbase64($"b")), + Row("AQIDBA==", bytes)) + + checkAnswer( + df.selectExpr("base64(a)", "unbase64(b)"), + Row("AQIDBA==", bytes)) + } + + test("string / binary substring function") { + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + val df = Seq(("1世3", Array[Byte](1, 2, 3, 4))).toDF("a", "b") + checkAnswer(df.select(substring($"a", 1, 2)), Row("1世")) + checkAnswer(df.select(substring($"b", 2, 2)), Row(Array[Byte](2,3))) + checkAnswer(df.selectExpr("substring(a, 1, 2)"), Row("1世")) + // scalastyle:on + } + + test("string encode/decode function") { + val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") + checkAnswer( + df.select(encode($"a", "utf-8"), decode($"c", "utf-8")), + Row(bytes, "大千世界")) + + checkAnswer( + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), + Row(bytes, "大千世界")) + // scalastyle:on + } + + test("string translate") { + val df = Seq(("translate", "")).toDF("a", "b") + checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) + checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae")) + } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(format_string("aa%d%s", $"b", $"c")), + Row("aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("soundex function") { + val df = Seq(("MARY", "SU")).toDF("l", "r") + checkAnswer( + df.select(soundex($"l"), soundex($"r")), Row("M600", "S000")) + + checkAnswer( + df.selectExpr("SoundEx(l)", "SoundEx(r)"), Row("M600", "S000")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", "aa")), + Row(1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string substring_index function") { + val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c") + checkAnswer( + df.select(substring_index($"a", ".", 2)), + Row("www.apache")) + checkAnswer( + df.selectExpr("substring_index(a, '.', 2)"), + Row("www.apache") + ) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select(locate("aa", $"a"), locate("aa", $"a", 1)), + Row(1, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select(lpad($"a", 1, "c"), lpad($"a", 5, "??"), rpad($"a", 1, "c"), rpad($"a", 5, "??")), + Row("h", "???hi", "h", "hi???")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select(repeat($"a", 2)), + Row("hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse($"b")), + Row("ih", "ihhh")) + + checkAnswer( + df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length($"b")), + Row(3, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("initcap function") { + val df = Seq(("ab", "a B")).toDF("l", "r") + checkAnswer( + df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B")) + + checkAnswer( + df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B")) + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select(format_number($"f", 4)), + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection follows by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + val df2 = Seq((5L, 4), (4L, 3), (4L, 3), (4L, 3), (3L, 2)).toDF("a", "b") + checkAnswer( + df2.filter("b>0").selectExpr("format_number(a, b)"), + Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index dd781169ca57..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.Timestamp - -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.test._ - -/* Implicits */ -import org.apache.spark.sql.test.TestSQLContext._ - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDataFrame - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDataFrame - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDataFrame - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDataFrame - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDataFrame - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDataFrame - testData3.registerTempTable("testData3") - - val emptyTableData = logical.LocalRelation($"a".int, $"b".int) - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDataFrame - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDataFrame - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: - ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ) - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil) - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDataFrame - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => - TimestampField(new Timestamp(i)) - }) - timestamps.registerTempTable("timestamps") - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil) - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil) - salary.registerTempTable("salary") - - case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) - :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toDataFrame - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 95923f9aad93..eb275af101e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,37 +17,178 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl.StringToColumn -import org.apache.spark.sql.test._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -/* Implicits */ -import TestSQLContext._ +private case class FunctionResult(f1: String, f2: String) -case class FunctionResult(f1: String, f2: String) +class UDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ -class UDFSuite extends QueryTest { + test("built-in fixed arity expressions") { + val df = ctx.emptyDataFrame + df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") + } + + test("built-in vararg expressions") { + val df = Seq((1, 2)).toDF("a", "b") + df.selectExpr("array(a, b)") + df.selectExpr("struct(a, b)") + } + + test("built-in expressions with multiple constructors") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect() + } + + test("count") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(a)") + } + + test("count distinct") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(distinct a)") + } + + test("SPARK-8003 spark_partition_id") { + val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") + df.registerTempTable("tmp_table") + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + ctx.dropTempTable("tmp_table") + } + + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) + ctx.dropTempTable("test_table") + } + } + + test("error reporting for incorrect number of arguments") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("substr('abcd', 2, 3, 4)") + } + assert(e.getMessage.contains("arguments")) + } + + test("error reporting for undefined functions") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("a_function_that_does_not_exist()") + } + assert(e.getMessage.contains("undefined function")) + } test("Simple UDF") { - udf.register("strLenScala", (_: String).length) + ctx.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - udf.register("random0", () => { Math.random()}) + ctx.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_:Int)) + ctx.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } + test("UDF in a WHERE") { + ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + + val df = ctx.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("integerData") + + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } + + test("UDF in a HAVING") { + ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } + + test("UDF in a GROUP BY") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } + + test("UDFs everywhere") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) + ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) + ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } + test("struct UDF") { - udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } + + test("udf that is transformed") { + ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } + + test("type coercion for udf inputs") { + ctx.udf.register("intExpected", (x: Int) => x) + // pass a decimal to intExpected. + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala new file mode 100644 index 000000000000..219435dff5bc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.ByteArrayOutputStream + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.unsafe.types.UTF8String + +class UnsafeRowSuite extends SparkFunSuite { + + test("bitset width calculation") { + assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) + assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(32) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(64) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(65) === 16) + assert(UnsafeRow.calculateBitSetWidthInBytes(128) === 16) + } + + test("writeToStream") { + val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) + val arrayBackedUnsafeRow: UnsafeRow = + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) + assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) + val bytesFromArrayBackedRow: Array[Byte] = { + val baos = new ByteArrayOutputStream() + arrayBackedUnsafeRow.writeToStream(baos, null) + baos.toByteArray + } + val bytesFromOffheapRow: Array[Byte] = { + val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) + try { + Platform.copyMemory( + arrayBackedUnsafeRow.getBaseObject, + arrayBackedUnsafeRow.getBaseOffset, + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + arrayBackedUnsafeRow.getSizeInBytes + ) + val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + offheapUnsafeRow.pointTo( + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + 3, // num fields + arrayBackedUnsafeRow.getSizeInBytes + ) + assert(offheapUnsafeRow.getBaseObject === null) + val baos = new ByteArrayOutputStream() + val writeBuffer = new Array[Byte](1024) + offheapUnsafeRow.writeToStream(baos, writeBuffer) + baos.toByteArray + } finally { + MemoryAllocator.UNSAFE.free(offheapRowPage) + } + } + + assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } + + test("createFromByteArray and copyFrom") { + val row = InternalRow(1, UTF8String.fromString("abc")) + val converter = UnsafeProjection.create(Array[DataType](IntegerType, StringType)) + val unsafeRow = converter.apply(row) + + val emptyRow = UnsafeRow.createFromByteArray(64, 2) + val buffer = emptyRow.getBaseObject + + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + // make sure we reuse the buffer. + assert(emptyRow.getBaseObject === buffer) + + // make sure we really copied the input row. + unsafeRow.setInt(0, 2) + assert(emptyRow.getInt(0) === 1) + + val longString = UTF8String.fromString((1 to 100).map(_ => "abc").reduce(_ + _)) + val row2 = InternalRow(3, longString) + val unsafeRow2 = converter.apply(row2) + + // make sure we can resize. + emptyRow.copyFrom(unsafeRow2) + assert(emptyRow.getSizeInBytes() === unsafeRow2.getSizeInBytes) + assert(emptyRow.getInt(0) === 3) + assert(emptyRow.getUTF8String(1) === longString) + // make sure we really resized. + assert(emptyRow.getBaseObject != buffer) + + // make sure we can still handle small rows after resize. + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 0696a2335e63..b6d279ae4726 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -19,10 +19,15 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} +import com.clearspring.analytics.stream.cardinality.HyperLogLog + import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.OpenHashSet @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) @@ -43,29 +48,31 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): ArrayData = { obj match { case features: MyDenseVector => - features.data.toSeq + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } } override def deserialize(datum: Any): MyDenseVector = { datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + case data: ArrayData => + new MyDenseVector(data.toDoubleArray()) } } - override def userClass = classOf[MyDenseVector] + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] + + private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + private lazy val pointsRDD = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -83,10 +90,71 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } + + + test("UDTs with Parquet") { + val tempDir = Utils.createTempDir() + tempDir.delete() + pointsRDD.write.parquet(tempDir.getCanonicalPath) + } + + test("Repartition UDTs with Parquet") { + val tempDir = Utils.createTempDir() + tempDir.delete() + pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) + } + + // Tests to make sure that all operators correctly convert types on the way out. + test("Local UDTs") { + val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") + df.collect()(0).getAs[MyDenseVector](1) + df.take(1)(0).getAs[MyDenseVector](1) + df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + } + + test("HyperLogLogUDT") { + val hyperLogLogUDT = HyperLogLogUDT + val hyperLogLog = new HyperLogLog(0.4) + (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) + + val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) + assert(actual.cardinality() === hyperLogLog.cardinality()) + assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) + } + + test("OpenHashSetUDT") { + val openHashSetUDT = new OpenHashSetUDT(IntegerType) + val set = new OpenHashSet[Int] + (1 to 10).foreach(i => set.add(i)) + + val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) + assert(actual.iterator.toSet === set.iterator.toSet) + } + + test("UDTs with JSON") { + val data = Seq( + "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", + "{\"id\":2,\"vec\":[2.25,4.5,8.75]}" + ) + val schema = StructType(Seq( + StructField("id", IntegerType, false), + StructField("vec", new MyDenseVectorUDT, false) + )) + + val stringRDD = ctx.sparkContext.parallelize(data) + val jsonRDD = ctx.read.schema(schema).json(stringRDD) + checkAnswer( + jsonRDD, + Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: + Row(2, new MyDenseVector(Array(2.25, 4.5, 8.75))) :: + Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index be2b34de077c..d0430d2a60e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -17,52 +17,94 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ -class ColumnStatsSuite extends FunSuite { - testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) - - def testColumnStats[T <: NativeType, U <: ColumnStats]( +class ColumnStatsSuite extends SparkFunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, + createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) + testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + + def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) - val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) - assertResult(10, "Wrong null count")(stats(2)) - assertResult(20, "Wrong row count")(stats(3)) - assertResult(stats(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 87e608a8853d..8f024690efd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,23 +18,29 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite +import com.esotericsoftware.kryo.io.{Input, Output} +import com.esotericsoftware.kryo.{Kryo, Serializer} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String -class ColumnTypeSuite extends FunSuite with Logging { - val DEFAULT_BUFFER_SIZE = 512 + +class ColumnTypeSuite extends SparkFunSuite with Logging { + private val DEFAULT_BUFFER_SIZE = 512 + private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, - STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, + MAP_GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -44,8 +50,8 @@ class ColumnTypeSuite extends FunSuite with Logging { } test("actualSize") { - def checkActualSize[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def checkActualSize[JvmType]( + columnType: ColumnType[JvmType], value: JvmType, expected: Int): Unit = { @@ -56,26 +62,24 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) - checkActualSize(SHORT, Short.MaxValue, 2) - checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, new Date(0L), 8) - checkActualSize(TIMESTAMP, new Timestamp(0L), 12) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) + checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) + checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -83,23 +87,34 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(INT)(_.putInt(_), _.getInt) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, - (buffer: ByteBuffer, string: String) => { + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - val bytes = string.getBytes("utf-8") + testNativeColumnType(FIXED_DECIMAL(15, 10))( + (buffer: ByteBuffer, decimal: Decimal) => { + buffer.putLong(decimal.toUnscaledLong) + }, + (buffer: ByteBuffer) => { + Decimal(buffer.getLong(), 15, 10) + }) + + + testNativeColumnType(STRING)( + (buffer: ByteBuffer, string: UTF8String) => { + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, @@ -107,10 +122,10 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String.fromBytes(bytes) }) - testColumnType[BinaryType.type, Array[Byte]]( + testColumnType[Array[Byte]]( BINARY, (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) @@ -127,7 +142,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val obj = Map(1 -> "spark", 2 -> "sql") val serializedObj = SparkSqlSerializer.serialize(obj) - GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) + MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) buffer.rewind() val length = buffer.getInt() @@ -144,20 +159,55 @@ class ColumnTypeSuite extends FunSuite with Logging { assertResult(obj, "Deserialized object didn't equal to the original object") { buffer.rewind() - SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) } } - def testNativeColumnType[T <: NativeType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType): Unit = { + test("CUSTOM") { + val conf = new SparkConf() + conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") + val serializer = new SparkSqlSerializer(conf).newInstance() - testColumnType[T, T#JvmType](columnType, putter, getter) + val buffer = ByteBuffer.allocate(512) + val obj = CustomClass(Int.MaxValue, Long.MaxValue) + val serializedObj = serializer.serialize(obj).array() + + MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) + buffer.rewind() + + val length = buffer.getInt + assert(length === serializedObj.length) + assert(13 == length) // id (1) + int (4) + long (8) + + val genericSerializedObj = SparkSqlSerializer.serialize(obj) + assert(length != genericSerializedObj.length) + assert(length < genericSerializedObj.length) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + serializer.deserialize(ByteBuffer.wrap(bytes)) + } + + buffer.rewind() + buffer.putInt(serializedObj.length).put(serializedObj) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + buffer.rewind() + serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) + } + } + + def testNativeColumnType[T <: AtomicType]( + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, + getter: (ByteBuffer) => T#InternalType): Unit = { + + testColumnType[T#InternalType](columnType, putter, getter) } - def testColumnType[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def testColumnType[JvmType]( + columnType: ColumnType[JvmType], putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType): Unit = { @@ -206,4 +256,36 @@ class ColumnTypeSuite extends FunSuite with Logging { if (sb.nonEmpty) sb.setLength(sb.length - 1) sb.toString() } + + test("column type for decimal types with different precision") { + (1 to 18).foreach { i => + assertResult(FIXED_DECIMAL(i, 0)) { + ColumnType(DecimalType(i, 0)) + } + } + + assertResult(GENERIC(DecimalType(19, 0))) { + ColumnType(DecimalType(19, 0)) + } + } +} + +private[columnar] final case class CustomClass(a: Int, b: Long) + +private[columnar] object CustomerSerializer extends Serializer[CustomClass] { + override def write(kryo: Kryo, output: Output, t: CustomClass) { + output.writeInt(t.a) + output.writeLong(t.b) + } + override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { + val a = input.readInt() + val b = input.readLong() + CustomClass(a, b) + } +} + +private[columnar] final class Registrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[CustomClass], CustomerSerializer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index f941465fa3e3..79bb7d072feb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -19,21 +19,19 @@ package org.apache.spark.sql.columnar import scala.collection.immutable.HashSet import scala.util.Random - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{DataType, NativeType} +import org.apache.spark.sql.types.{DataType, Decimal, AtomicType} +import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { - def makeNullRow(length: Int) = { + def makeNullRow(length: Int): GenericMutableRow = { val row = new GenericMutableRow(length) (0 until length).foreach(row.setNullAt) row } - def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = { + def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) @@ -41,20 +39,18 @@ object ColumnarTestUtils { } (columnType match { - case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte - case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort - case INT => Random.nextInt() - case LONG => Random.nextLong() - case FLOAT => Random.nextFloat() - case DOUBLE => Random.nextDouble() - case STRING => Random.nextString(Random.nextInt(32)) - case BOOLEAN => Random.nextBoolean() - case BINARY => randomBytes(Random.nextInt(32)) - case DATE => new Date(Random.nextLong()) - case TIMESTAMP => - val timestamp = new Timestamp(Random.nextLong()) - timestamp.setNanos(Random.nextInt(999999999)) - timestamp + case BOOLEAN => Random.nextBoolean() + case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte + case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort + case INT => Random.nextInt() + case DATE => Random.nextInt() + case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() + case FLOAT => Random.nextFloat() + case DOUBLE => Random.nextDouble() + case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) + case BINARY => randomBytes(Random.nextInt(32)) + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) @@ -62,15 +58,15 @@ object ColumnarTestUtils { } def makeRandomValues( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) - def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = { + def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } - def makeUniqueRandomValues[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def makeUniqueRandomValues[JvmType]( + columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => @@ -79,10 +75,10 @@ object ColumnarTestUtils { } def makeRandomRow( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) - def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = { + def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value @@ -90,9 +86,9 @@ object ColumnarTestUtils { row } - def makeUniqueValuesAndSingleValueRows[T <: NativeType]( + def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], - count: Int) = { + count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 3d33484ab0eb..952637c5f9cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.columnar -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.{QueryTest, TestData} +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. - TestData +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + setupTestData() test("simple columnar query") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -37,15 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst") - cacheTable("sizeTst") + ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + .toDF().registerTempTable("sizeTst") + ctx.cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.logical.statistics.sizeInBytes > - conf.autoBroadcastJoinThreshold) + ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + ctx.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -54,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -66,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("repeatedData") + ctx.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -78,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("nullableRepeatedData") + ctx.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -86,15 +89,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-2729 regression: timestamp data type") { + val timestamps = (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time") + timestamps.registerTempTable("timestamps") + checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) - cacheTable("timestamps") + ctx.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) } test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { @@ -102,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - cacheTable("withEmptyParts") + ctx.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -115,4 +121,74 @@ class InMemoryColumnarQuerySuite extends QueryTest { complexData.count() complexData.unpersist() } + + test("decimal type") { + // Casting is required here because ScalaReflection can't capture decimal precision information. + val df = (1 to 10) + .map(i => Tuple1(Decimal(i, 15, 10))) + .toDF("dec") + .select($"dec" cast DecimalType(15, 10)) + + assert(df.schema.head.dataType === DecimalType(15, 10)) + + df.cache().registerTempTable("test_fixed_decimal") + checkAnswer( + sql("SELECT * FROM test_fixed_decimal"), + (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) + } + + test("test different data types") { + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD for the schema + val rdd = + ctx.sparkContext.parallelize((1 to 100), 10).map { i => + Row( + s"str${i}: test cache.", + s"binary${i}: test cache.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i), + (1 to i).toSeq, + (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + Row((i - 0.25).toFloat, Seq(true, false, null))) + } + ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + // Cache the table. + sql("cache table InMemoryCache_different_data_types") + // Make sure the table is indeed cached. + val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + assert( + ctx.isCached("InMemoryCache_different_data_types"), + "InMemoryCache_different_data_types should be cached.") + // Issue a query and check the results. + checkAnswer( + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + ctx.table("InMemoryCache_different_data_types").collect()) + ctx.dropTempTable("InMemoryCache_different_data_types") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index f95c895587f3..f4f6c7649bfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -19,36 +19,37 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} -class TestNullableColumnAccessor[T <: DataType, JvmType]( +class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, - columnType: ColumnType[T, JvmType]) + columnType: ColumnType[JvmType]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) + : TestNullableColumnAccessor[JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) } } -class NullableColumnAccessorSuite extends FunSuite { +class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + .foreach { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType]): Unit = { + def testNullableColumnAccessor[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) @@ -74,7 +75,7 @@ class NullableColumnAccessorSuite extends FunSuite { (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row(0) === randomRow(0)) + assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 80bd5c94570c..241d09ea205e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -17,34 +17,35 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) +class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder } } -class NullableColumnBuilderSuite extends FunSuite { +class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + .foreach { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType]): Unit = { + def testNullableColumnBuilder[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -91,13 +92,14 @@ class NullableColumnBuilderSuite extends FunSuite { // For non-null values (0 until 4).foreach { _ => - val actual = if (columnType == GENERIC) { - SparkSqlSerializer.deserialize[Any](GENERIC.extract(buffer)) + val actual = if (columnType.isInstanceOf[GENERIC]) { + SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) } else { columnType.extract(buffer) } - assert(actual === randomRow(0), "Extracted value didn't equal to the original one") + assert(actual === randomRow.get(0, columnType.dataType), + "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index fe9a69edbb92..ab2644eb4581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,40 +17,43 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ + +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ -class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { - val originalColumnBatchSize = conf.columnBatchSize - val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) - }, 5) + }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + // Enable in-memory table scan accumulators + ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + ctx.cacheTable("pruningData") } override protected def afterAll(): Unit = { - setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) - } - - before { - cacheTable("pruningData") - } - - after { - uncacheTable("pruningData") + try { + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + ctx.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 8b518f094174..9a2948c59ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - -import org.apache.spark.sql.Row +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} -class BooleanBitSetSuite extends FunSuite { +class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ def skeleton(count: Int) { @@ -33,8 +32,8 @@ class BooleanBitSetSuite extends FunSuite { // ------------- val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) - val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_(0)) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) + val values = rows.map(_.getBoolean(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index c82d9799359c..acfab6586c0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -19,19 +19,18 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -class DictionaryEncodingSuite extends FunSuite { - testDictionaryEncoding(new IntColumnStats, INT) - testDictionaryEncoding(new LongColumnStats, LONG) +class DictionaryEncodingSuite extends SparkFunSuite { + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) - def testDictionaryEncoding[T <: NativeType]( + def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 88011631ee4e..2111e9fbe62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType -class IntegralDeltaSuite extends FunSuite { - testIntegralDelta(new IntColumnStats, INT, IntDelta) +class IntegralDeltaSuite extends SparkFunSuite { + testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( @@ -33,7 +32,7 @@ class IntegralDeltaSuite extends FunSuite { columnType: NativeColumnType[I], scheme: CompressionScheme) { - def skeleton(input: Seq[I#JvmType]) { + def skeleton(input: Seq[I#InternalType]) { // ------------- // Tests encoder // ------------- @@ -116,17 +115,17 @@ class IntegralDeltaSuite extends FunSuite { test(s"$scheme: simple case") { val input = columnType match { - case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } test(s"$scheme: long random series") { // Have to workaround with `Any` since no `ClassTag[I#JvmType]` available here. val input = Array.fill[Any](10000)(makeRandomValue(columnType)) - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 08df1db37509..67ec08f594a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -17,22 +17,21 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -class RunLengthEncodingSuite extends FunSuite { +class RunLengthEncodingSuite extends SparkFunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) - testRunLengthEncoding(new ByteColumnStats, BYTE) - testRunLengthEncoding(new ShortColumnStats, SHORT) - testRunLengthEncoding(new IntColumnStats, INT) - testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) - def testRunLengthEncoding[T <: NativeType]( + def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 0b18b4119268..5268dfe0aa03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -class TestCompressibleColumnBuilder[T <: NativeType]( +class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) @@ -32,10 +32,10 @@ class TestCompressibleColumnBuilder[T <: NativeType]( } object TestCompressibleColumnBuilder { - def apply[T <: NativeType]( + def apply[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T], - scheme: CompressionScheme) = { + scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { val builder = new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) builder.initialize(0, "", useCompression = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala new file mode 100644 index 000000000000..8998f5111124 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext + +class ExchangeSuite extends SparkPlanTest with SharedSQLContext { + test("shuffling UnsafeRows in exchange") { + val input = (1 to 1000).map(Tuple1.apply) + checkAnswer( + input.toDF(), + plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + input.map(Row.fromTuple) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index df108a9d262b..fad93b014c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,21 +17,44 @@ package org.apache.spark.sql.execution -import org.scalatest.FunSuite - -import org.apache.spark.sql.{SQLConf, execution} -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.SparkFunSuite +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class PlannerSuite extends FunSuite { +class PlannerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + setupTestData() + + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val _ctx = ctx + import _ctx.planner._ + val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val planned = + plannedOption.getOrElse( + fail(s"Could query play aggregation query $query. Is it an aggregation query?")) + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } + + // For the new aggregation code path, there will be three aggregate operator for + // distinct aggregations. + assert( + aggregations.size == 2 || aggregations.size == 3, + s"The plan of query $query does not have partial aggregations.") + } + test("unions are collapsed") { + val _ctx = ctx + import _ctx.planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -43,35 +66,30 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed - val planned = HashAggregation(query).head - val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - - assert(aggregations.size === 2) + testPartialAggregationPlan(query) } test("count distinct is partially aggregated") { val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed - val planned = HashAggregation(query) - assert(planned.nonEmpty) + testPartialAggregationPlan(query) } test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed - val planned = HashAggregation(query) - assert(planned.nonEmpty) + testPartialAggregationPlan(query) } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - applySchema(rowRDD, schema).registerTempTable("testLimit") + val rowRDD = ctx.sparkContext.parallelize(row :: Nil) + ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -85,10 +103,10 @@ class PlannerSuite extends FunSuite { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - dropTempTable("testLimit") + ctx.dropTempTable("testLimit") } - val origThreshold = conf.autoBroadcastJoinThreshold + val origThreshold = ctx.conf.autoBroadcastJoinThreshold val simpleTypes = NullType :: @@ -100,7 +118,7 @@ class PlannerSuite extends FunSuite { FloatType :: DoubleType :: DecimalType(10, 5) :: - DecimalType.Unlimited :: + DecimalType.SYSTEM_DEFAULT :: DateType :: TimestampType :: StringType :: @@ -120,18 +138,18 @@ class PlannerSuite extends FunSuite { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + val origThreshold = ctx.conf.autoBroadcastJoinThreshold + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") val a = testData.as("a") - val b = table("tiny").as("b") + val b = ctx.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } @@ -140,6 +158,214 @@ class PlannerSuite extends FunSuite { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + } + + test("efficient limit -> project -> sort") { + { + val query = + testData.select('key, 'value).sort('key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) + } + + { + // We need to make sure TakeOrderedAndProject's output is correct when we push a project + // into it. + val query = + testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + } } + + test("PartitioningCollection") { + withTempTable("normal", "small", "tiny") { + testData.registerTempTable("normal") + testData.limit(10).registerTempTable("small") + testData.limit(3).registerTempTable("tiny") + + // Disable broadcast join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + { + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (small.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + { + // This second query joins on different keys: + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (normal.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + } + } + } + + // --- Unit tests of EnsureRequirements --------------------------------------------------------- + + // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, + // there two dimensions that need to be considered: are the child partitionings compatible and + // do they satisfy the distribution requirements? As a result, we need at least four test cases. + + private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { + if (outputPlan.children.length > 1 + && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { + val childPartitionings = outputPlan.children.map(_.outputPartitioning) + if (!Partitioning.allCompatible(childPartitionings)) { + fail(s"Partitionings are not compatible: $childPartitionings") + } + } + outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { + case (child, requiredDist) => + assert(child.outputPartitioning.satisfies(requiredDist), + s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan") + } + } + + test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { + // Consider an operator that requires inputs that are clustered by two expressions (e.g. + // sort merge join where there are multiple columns in the equi-join condition) + val clusteringA = Literal(1) :: Nil + val clusteringB = Literal(2) :: Nil + val distribution = ClusteredDistribution(clusteringA ++ clusteringB) + // Say that the left and right inputs are each partitioned by _one_ of the two join columns: + val leftPartitioning = HashPartitioning(clusteringA, 1) + val rightPartitioning = HashPartitioning(clusteringB, 1) + // Individually, each input's partitioning satisfies the clustering distribution: + assert(leftPartitioning.satisfies(distribution)) + assert(rightPartitioning.satisfies(distribution)) + // However, these partitionings are not compatible with each other, so we still need to + // repartition both inputs prior to performing the join: + assert(!leftPartitioning.compatibleWith(rightPartitioning)) + assert(!rightPartitioning.compatibleWith(leftPartitioning)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = leftPartitioning), + DummySparkPlan(outputPartitioning = rightPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with child partitionings with different numbers of output partitions") { + // This is similar to the previous test, except it checks that partitionings are not compatible + // unless they produce the same number of partitions. + val clustering = Literal(1) :: Nil + val distribution = ClusteredDistribution(clustering) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)), + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2)) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + } + + test("EnsureRequirements with compatible child partitionings that do not satisfy distribution") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // The left and right inputs have compatible partitionings but they do not satisfy the + // distribution because they are clustered on different columns. Thus, we need to shuffle. + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with compatible child partitionings that satisfy distribution") { + // In this case, all requirements are satisfied and no exchange should be added. + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"Exchange should not have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-9703 + test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { + // Consider an operator that imposes both output distribution and ordering requirements on its + // children, such as sort sort merge join. If the distribution requirements are satisfied but + // the output ordering requirements are unsatisfied, then the planner should only add sorts and + // should not need to add additional shuffles / exchanges. + val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = SinglePartition), + DummySparkPlan(outputPartitioning = SinglePartition) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(outputOrdering, outputOrdering) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"No Exchanges should have been added:\n$outputPlan") + } + } + + // --------------------------------------------------------------------------------------------- +} + +// Used for unit-testing EnsureRequirements +private case class DummySparkPlan( + override val children: Seq[SparkPlan] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val requiredChildDistribution: Seq[Distribution] = Nil, + override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil + ) extends SparkPlan { + override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala new file mode 100644 index 000000000000..ef6ad59b71fb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { + + private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { + case c: ConvertToUnsafe => c + case c: ConvertToSafe => c + } + + private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + assert(!outputsSafe.outputsUnsafeRows) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + assert(outputsUnsafe.outputsUnsafeRows) + + test("planner should insert unsafe->safe conversions when required") { + val plan = Limit(10, outputsUnsafe) + val preparedPlan = ctx.prepareForExecution.execute(plan) + assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) + } + + test("filter can process unsafe rows") { + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) + val preparedPlan = ctx.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("filter can process safe rows") { + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) + val preparedPlan = ctx.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("execute() fails an assertion if inputs rows are of different formats") { + val e = intercept[AssertionError] { + Union(Seq(outputsSafe, outputsUnsafe)).execute() + } + assert(e.getMessage.contains("format")) + } + + test("union requires all of its input rows' formats to agree") { + val plan = Union(Seq(outputsSafe, outputsUnsafe)) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = ctx.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("union can process safe rows") { + val plan = Union(Seq(outputsSafe, outputsSafe)) + val preparedPlan = ctx.prepareForExecution.execute(plan) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("union can process unsafe rows") { + val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) + val preparedPlan = ctx.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("round trip with ConvertToUnsafe and ConvertToSafe") { + val input = Seq(("hello", 1), ("world", 2)) + checkAnswer( + ctx.createDataFrame(input), + plan => ConvertToSafe(ConvertToUnsafe(plan)), + input.map(Row.fromTuple) + ) + } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SparkPlan.currentContext.set(ctx) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala new file mode 100644 index 000000000000..8fa77b0fcb7b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext + +class SortSuite extends SparkPlanTest with SharedSQLContext { + + // This test was originally added as an example of how to use [[SparkPlanTest]]; + // it's not designed to be a comprehensive test of ExternalSort. + test("basic sorting using ExternalSort") { + + val input = Seq( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), + sortAnswers = false) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), + sortAnswers = false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala new file mode 100644 index 000000000000..3a87f374d94b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.util._ + +/** + * Base class for writing tests for individual physical operators. For an example of how this + * class's test helper methods can be used, see [[SortSuite]]. + */ +private[sql] abstract class SparkPlanTest extends SparkFunSuite { + protected def _sqlContext: SQLContext + + /** + * Creates a DataFrame from a local Seq of Product. + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + _sqlContext.implicits.localSeqToDataFrameHolder(data) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + input :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans.head), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkAnswer2( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to + * instantiate the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def doCheckAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the result produced by a reference plan. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkThatPlansAgree( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer( + input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} + +/** + * Helper methods for writing tests of individual physical operators. + */ +object SparkPlanTest { + + /** + * Runs the plan and makes sure the answer matches the result produced by a reference plan. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. + */ + def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean, + _sqlContext: SQLContext): Option[String] = { + + val outputPlan = planFunction(input.queryExecution.sparkPlan) + val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) + + val expectedAnswer: Seq[Row] = try { + executePlan(expectedOutputPlan, _sqlContext) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan to calculate expected answer: + | $expectedOutputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + val actualAnswer: Seq[Row] = try { + executePlan(outputPlan, _sqlContext) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match. + | Actual result Spark plan: + | $outputPlan + | Expected result Spark plan: + | $expectedOutputPlan + | $errorMessage + """.stripMargin + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean, + _sqlContext: SQLContext): Option[String] = { + + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) + + val sparkAnswer: Seq[Row] = try { + executePlan(outputPlan, _sqlContext) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match for Spark plan: + | $outputPlan + | $errorMessage + """.stripMargin + } + } + + private def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } + + private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = _sqlContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) + resolvedPlan.executeCollect().toSeq + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala new file mode 100644 index 000000000000..48c3938ff87b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.shuffle.ShuffleMemoryManager + +/** + * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. + */ +class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) { + private var oom = false + + override def tryToAcquire(numBytes: Long): Long = { + if (oom) { + oom = false + 0 + } else { + // Uncomment the following to trace memory allocations. + // println(s"tryToAcquire $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + val acquired = super.tryToAcquire(numBytes) + acquired + } + } + + override def release(numBytes: Long): Unit = { + // Uncomment the following to trace memory releases. + // println(s"release $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + super.release(numBytes) + } + + def markAsOutOfMemory(): Unit = { + oom = true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala new file mode 100644 index 000000000000..3158458edb83 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +/** + * A test suite that generates randomized data to test the [[TungstenSort]] operator. + */ +class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { + + override def beforeAll(): Unit = { + super.beforeAll() + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + try { + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } finally { + super.afterAll() + } + } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + test("sorting updates peak execution memory") { + val sc = ctx.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = ctx.createDataFrame( + ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + assert(TungstenSort.supportsSchema(inputDf.schema)) + checkThatPlansAgree( + inputDf, + plan => ConvertToSafe( + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 000000000000..d1f0b2b1fc52 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.control.NonFatal +import scala.collection.mutable +import scala.util.{Try, Random} + +import org.scalatest.Matchers + +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Test suite for [[UnsafeFixedWidthAggregationMap]]. + * + * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. + */ +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with SharedSQLContext { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: InternalRow = InternalRow(0) + private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes + + private var taskMemoryManager: TaskMemoryManager = null + private var shuffleMemoryManager: TestShuffleMemoryManager = null + + def testWithMemoryLeakDetection(name: String)(f: => Unit) { + def cleanup(): Unit = { + if (taskMemoryManager != null) { + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() + assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) + assert(leakedShuffleMemory === 0) + taskMemoryManager = null + } + TaskContext.unset() + } + + test(name) { + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + shuffleMemoryManager = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = Random.nextInt(10000), + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = null, + internalAccumulators = Seq.empty)) + + try { + f + } catch { + case NonFatal(e) => + Try(cleanup()) + throw e + } + cleanup() + } + } + + private def randomStrings(n: Int): Seq[String] = { + val rand = new Random(42) + Seq.fill(512) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + }.distinct + } + + testWithMemoryLeakDetection("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + testWithMemoryLeakDetection("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 1024, // initial capacity, + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + assert(!map.iterator().next()) + map.free() + } + + testWithMemoryLeakDetection("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 1024, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val groupKey = InternalRow(UTF8String.fromString("cats")) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + assert(map.getAggregationBuffer(groupKey) != null) + val iter = map.iterator() + assert(iter.next()) + iter.getKey.getString(0) should be ("cats") + iter.getValue.getInt(0) should be (0) + assert(!iter.next()) + + // Modifications to rows retrieved from the map should update the values in the map + iter.getValue.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + testWithMemoryLeakDetection("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null) + } + + val seenKeys = new mutable.HashSet[String] + val iter = map.iterator() + while (iter.next()) { + seenKeys += iter.getKey.getString(0) + } + assert(seenKeys.size === groupKeys.size) + assert(seenKeys === groupKeys) + map.free() + } + + testWithMemoryLeakDetection("test external sorting") { + // Memory consumption in the beginning of the task. + val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + val keys = randomStrings(1024).take(512) + keys.foreach { keyString => + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) + buf.setInt(0, keyString.length) + assert(buf != null) + } + + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) + out += iter.getKey.getString(0) + } + + assert(out === (keys ++ additionalKeys).sorted) + + map.free() + } + + testWithMemoryLeakDetection("test external sorting with an empty map") { + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) + } + + assert(out === (additionalKeys).sorted) + + map.free() + } + + testWithMemoryLeakDetection("test external sorting with empty records") { + + // Memory consumption in the beginning of the task. + val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + StructType(Nil), + StructType(Nil), + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + (1 to 10).foreach { i => + val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) + assert(buf != null) + } + + // Convert the map into a sorter. Right now, it contains one record. + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. + (1 to 4096).foreach { i => + sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + var count = 0 + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + iter.getKey.copy() + iter.getValue.copy() + count += 1; + } + + // 1 record was from the map and 4096 records were explicitly inserted. + assert(count === 4097) + + map.free() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala new file mode 100644 index 000000000000..d3be568a8758 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} + +/** + * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. + */ +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { + private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) + private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) + + testKVSorter(new StructType, new StructType, spill = true) + testKVSorter(new StructType().add("c1", IntegerType), new StructType, spill = true) + testKVSorter(new StructType, new StructType().add("c1", IntegerType), spill = true) + + private val rand = new Random(42) + for (i <- 0 until 6) { + val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes) + val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes) + testKVSorter(keySchema, valueSchema, spill = i > 3) + } + + + /** + * Create a test case using randomly generated data for the given key and value schema. + * + * The approach works as follows: + * + * - Create input by randomly generating data based on the given schema + * - Run [[UnsafeKVExternalSorter]] on the generated data + * - Collect the output from the sorter, and make sure the keys are sorted in ascending order + * - Sort the input by both key and value, and sort the sorter output also by both key and value. + * Compare the sorted input and sorted output together to make sure all the key/values match. + * + * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. + */ + private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = { + // Create the data converters + val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + val kConverter = UnsafeProjection.create(keySchema) + val vConverter = UnsafeProjection.create(valueSchema) + + val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get + val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get + + val inputData = Seq.fill(1024) { + val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow]) + val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) + } + + val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]") + val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]") + + test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") { + testKVSorter( + keySchema, + valueSchema, + inputData, + pageSize = 16 * 1024 * 1024, + spill + ) + } + } + + /** + * Create a test case using the given input data for the given key and value schema. + * + * The approach works as follows: + * + * - Create input by randomly generating data based on the given schema + * - Run [[UnsafeKVExternalSorter]] on the input data + * - Collect the output from the sorter, and make sure the keys are sorted in ascending order + * - Sort the input by both key and value, and sort the sorter output also by both key and value. + * Compare the sorted input and sorted output together to make sure all the key/values match. + * + * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. + */ + private def testKVSorter( + keySchema: StructType, + valueSchema: StructType, + inputData: Seq[(InternalRow, InternalRow)], + pageSize: Long, + spill: Boolean): Unit = { + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 98456, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null, + internalAccumulators = Seq.empty)) + + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize) + + // Insert the keys and values into the sorter + inputData.foreach { case (k, v) => + sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) + // 1% chance we will spill + if (rand.nextDouble() < 0.01 && spill) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + // Collect the sorted output + val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)] + val iter = sorter.sortedIterator() + while (iter.next()) { + out += Tuple2(iter.getKey.copy(), iter.getValue.copy()) + } + sorter.cleanupResources() + + val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType)) + val kvOrdering = new Ordering[(InternalRow, InternalRow)] { + override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { + keyOrdering.compare(x._1, y._1) match { + case 0 => valueOrdering.compare(x._2, y._2) + case cmp => cmp + } + } + } + + // Testing to make sure output from the sorter is sorted by key + var prevK: InternalRow = null + out.zipWithIndex.foreach { case ((k, v), i) => + if (prevK != null) { + assert(keyOrdering.compare(prevK, k) <= 0, + s""" + |key is not in sorted order: + |previous key: $prevK + |current key : $k + """.stripMargin) + } + prevK = k + } + + // Testing to make sure the key/value in output matches input + assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering)) + + // Make sure there is no memory leak + val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory + if (shuffleMemMgr != null) { + val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask() + assert(0L === leakedShuffleMemory) + } + assert(0 === leakedUnsafeMemory) + TaskContext.unset() + } + + test("kv sorting with records that exceed page size") { + val pageSize = 128 + + val schema = StructType(StructField("b", BinaryType) :: Nil) + val externalConverter = CatalystTypeConverters.createToCatalystConverter(schema) + val converter = UnsafeProjection.create(schema) + + val rand = new Random() + val inputData = Seq.fill(1024) { + val kBytes = new Array[Byte](rand.nextInt(pageSize)) + val vBytes = new Array[Byte](rand.nextInt(pageSize)) + rand.nextBytes(kBytes) + rand.nextBytes(vBytes) + val k = converter(externalConverter.apply(Row(kBytes)).asInstanceOf[InternalRow]) + val v = converter(externalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) + } + + testKVSorter( + schema, + schema, + inputData, + pageSize, + spill = true + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala new file mode 100644 index 000000000000..bd02c73a26ac --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.types._ + + +/** + * used to test close InputStream in UnsafeRowSerializer + */ +class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) { + var closed: Boolean = false + override def close(): Unit = { + closed = true + super.close() + } +} + +class UnsafeRowSerializerSuite extends SparkFunSuite { + + private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { + val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] + val converter = UnsafeProjection.create(schema) + converter.apply(internalRow) + } + + test("toUnsafeRow() test helper method") { + // This currently doesnt work because the generic getter throws an exception. + val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) + assert(row.getInt(1) === unsafeRow.getInt(1)) + } + + test("basic row serialization") { + val rows = Seq(Row("Hello", 1), Row("World", 2)) + val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType))) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val baos = new ByteArrayOutputStream() + val serializerStream = serializer.serializeStream(baos) + for (unsafeRow <- unsafeRows) { + serializerStream.writeKey(0) + serializerStream.writeValue(unsafeRow) + } + serializerStream.close() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + for (expectedRow <- unsafeRows) { + val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 + assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) + assert(expectedRow.getString(0) === actualRow.getString(0)) + assert(expectedRow.getInt(1) === actualRow.getInt(1)) + } + assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 000000000000..5fdb82b06772 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { + + test("memory acquired on construction") { + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala new file mode 100644 index 000000000000..1174b27732f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -0,0 +1,1166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import java.io.{File, StringWriter} +import java.sql.{Date, Timestamp} + +import com.fasterxml.jackson.core.JsonFactory +import org.apache.spark.rdd.RDD +import org.scalactic.Tolerance._ + +import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { + import testImplicits._ + + test("Type promotion") { + def checkTypePromotion(expected: Any, actual: Any) { + assert(expected.getClass == actual.getClass, + s"Failed to promote ${actual.getClass} to ${expected.getClass}.") + assert(expected == actual, + s"Promoted value ${actual}(${actual.getClass}) does not equal the expected value " + + s"${expected}(${expected.getClass}).") + } + + val factory = new JsonFactory() + def enforceCorrectType(value: Any, dataType: DataType): Any = { + val writer = new StringWriter() + val generator = factory.createGenerator(writer) + generator.writeObject(value) + generator.flush() + + val parser = factory.createParser(writer.toString) + parser.nextToken() + JacksonParser.convertField(factory, parser, dataType) + } + + val intNumber: Int = 2147483647 + checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) + checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) + checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) + checkTypePromotion( + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT)) + + val longNumber: Long = 9223372036854775807L + checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) + checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) + checkTypePromotion( + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT)) + + val doubleNumber: Double = 1.7976931348623157E308d + checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) + + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), + enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), + enforceCorrectType(intNumber.toLong, TimestampType)) + val strTime = "2014-09-30 12:34:56" + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType)) + + val strDate = "2014-10-15" + checkTypePromotion( + DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) + + val ISO8601Time1 = "1970-01-01T01:00:01.0Z" + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(DateTimeUtils.millisToDays(3601000), + enforceCorrectType(ISO8601Time1, DateType)) + val ISO8601Time2 = "1970-01-01T02:00:01-01:00" + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), + enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(DateTimeUtils.millisToDays(10801000), + enforceCorrectType(ISO8601Time2, DateType)) + } + + test("Get compatible type") { + def checkDataType(t1: DataType, t2: DataType, expected: DataType) { + var actual = compatibleType(t1, t2) + assert(actual == expected, + s"Expected $expected as the most general data type for $t1 and $t2, found $actual") + actual = compatibleType(t2, t1) + assert(actual == expected, + s"Expected $expected as the most general data type for $t1 and $t2, found $actual") + } + + // NullType + checkDataType(NullType, BooleanType, BooleanType) + checkDataType(NullType, IntegerType, IntegerType) + checkDataType(NullType, LongType, LongType) + checkDataType(NullType, DoubleType, DoubleType) + checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) + checkDataType(NullType, StringType, StringType) + checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) + checkDataType(NullType, StructType(Nil), StructType(Nil)) + checkDataType(NullType, NullType, NullType) + + // BooleanType + checkDataType(BooleanType, BooleanType, BooleanType) + checkDataType(BooleanType, IntegerType, StringType) + checkDataType(BooleanType, LongType, StringType) + checkDataType(BooleanType, DoubleType, StringType) + checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType) + checkDataType(BooleanType, StringType, StringType) + checkDataType(BooleanType, ArrayType(IntegerType), StringType) + checkDataType(BooleanType, StructType(Nil), StringType) + + // IntegerType + checkDataType(IntegerType, IntegerType, IntegerType) + checkDataType(IntegerType, LongType, LongType) + checkDataType(IntegerType, DoubleType, DoubleType) + checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) + checkDataType(IntegerType, StringType, StringType) + checkDataType(IntegerType, ArrayType(IntegerType), StringType) + checkDataType(IntegerType, StructType(Nil), StringType) + + // LongType + checkDataType(LongType, LongType, LongType) + checkDataType(LongType, DoubleType, DoubleType) + checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) + checkDataType(LongType, StringType, StringType) + checkDataType(LongType, ArrayType(IntegerType), StringType) + checkDataType(LongType, StructType(Nil), StringType) + + // DoubleType + checkDataType(DoubleType, DoubleType, DoubleType) + checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DoubleType) + checkDataType(DoubleType, StringType, StringType) + checkDataType(DoubleType, ArrayType(IntegerType), StringType) + checkDataType(DoubleType, StructType(Nil), StringType) + + // DecimalType + checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT, + DecimalType.SYSTEM_DEFAULT) + checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType) + + // StringType + checkDataType(StringType, StringType, StringType) + checkDataType(StringType, ArrayType(IntegerType), StringType) + checkDataType(StringType, StructType(Nil), StringType) + + // ArrayType + checkDataType(ArrayType(IntegerType), ArrayType(IntegerType), ArrayType(IntegerType)) + checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) + checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) + checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) + + // StructType + checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) + checkDataType( + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(StructField("f1", IntegerType, true) :: Nil)) + checkDataType( + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(Nil), + StructType(StructField("f1", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: + StructField("f2", IntegerType, true) :: Nil), + StructType(StructField("f1", LongType, true) :: Nil) , + StructType( + StructField("f1", LongType, true) :: + StructField("f2", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: Nil), + StructType( + StructField("f2", IntegerType, true) :: Nil), + StructType( + StructField("f1", IntegerType, true) :: + StructField("f2", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: Nil), + DecimalType.SYSTEM_DEFAULT, + StringType) + } + + test("Complex field and type inferring with null in sampling") { + val jsonDF = ctx.read.json(jsonNullStruct) + val expectedSchema = StructType( + StructField("headers", StructType( + StructField("Charset", StringType, true) :: + StructField("Host", StringType, true) :: Nil) + , true) :: + StructField("ip", StringType, true) :: + StructField("nullstr", StringType, true):: Nil) + + assert(expectedSchema === jsonDF.schema) + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select nullstr, headers.Host from jsonTable"), + Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row(null, null)) + ) + } + + test("Primitive field and type inferring") { + val jsonDF = ctx.read.json(primitiveFieldAndType) + + val expectedSchema = StructType( + StructField("bigInteger", DecimalType(20, 0), true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", LongType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + + test("Complex field and type inferring") { + val jsonDF = ctx.read.json(complexFieldAndType1) + + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType(21, 0), true), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: + StructField("arrayOfInteger", ArrayType(LongType, true), true) :: + StructField("arrayOfLong", ArrayType(LongType, true), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType, true), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil), true), true) :: + StructField("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType(20, 0), true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( + StructField("field1", ArrayType(LongType, true), true) :: + StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + Row("str1", "str2", null) + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from jsonTable"), + Row(Seq(null, null, null, null)) + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + Row(Seq("1", "2", "3"), Seq("str1", "str2")) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + Row("str2", 2.1) + ) + + // Access elements of an array of structs. + checkAnswer( + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + "from jsonTable"), + Row( + Row(true, "str1", null), + Row(false, null, null), + Row(null, null, null), + null) + ) + + // Access a struct and fields inside of it. + checkAnswer( + sql("select struct, struct.field1, struct.field2 from jsonTable"), + Row( + Row(true, new java.math.BigDecimal("92233720368547758070")), + true, + new java.math.BigDecimal("92233720368547758070")) :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + Row(Seq(4, 5, 6), Seq("str1", "str2")) + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + Row(5, null) + ) + } + + test("GetField operation on complex data type") { + val jsonDF = ctx.read.json(complexFieldAndType1) + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + Row(true, "str1") + ) + + // Getting all values of a specific field from an array of structs. + checkAnswer( + sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), + Row(Seq(true, false, null), Seq("str1", null, null)) + ) + } + + test("Type conflict in primitive field values") { + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + + val expectedSchema = StructType( + StructField("num_bool", StringType, true) :: + StructField("num_num_1", LongType, true) :: + StructField("num_num_2", DoubleType, true) :: + StructField("num_num_3", DoubleType, true) :: + StructField("num_str", StringType, true) :: + StructField("str_bool", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row("true", 11L, null, 1.1, "13.1", "str1") :: + Row("12", null, 21474836470.9, null, null, "true") :: + Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") :: + Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil + ) + + // Number and Boolean conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_bool - 10 from jsonTable where num_bool > 11"), + Row(2) + ) + + // Widening to LongType + checkAnswer( + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), + Row(21474836370L) :: Row(21474836470L) :: Nil + ) + + checkAnswer( + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), + Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil + ) + + // Widening to DecimalType + checkAnswer( + sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + Row(21474836472.2) :: + Row(92233720368547758071.3) :: Nil + ) + + // Widening to Double + checkAnswer( + sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), + Row(101.2) :: Row(21474836471.2) :: Nil + ) + + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 14"), + Row(BigDecimal("92233720368547758071.2")) + ) + + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + Row(new java.math.BigDecimal("92233720368547758071.2")) + ) + + // String and Boolean conflict: resolve the type as string. + checkAnswer( + sql("select * from jsonTable where str_bool = 'str1'"), + Row("true", 11L, null, 1.1, "13.1", "str1") + ) + } + + ignore("Type conflict in primitive field values (Ignored)") { + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + jsonDF.registerTempTable("jsonTable") + + // Right now, the analyzer does not promote strings in a boolean expression. + // Number and Boolean conflict: resolve the type as boolean in this query. + checkAnswer( + sql("select num_bool from jsonTable where NOT num_bool"), + Row(false) + ) + + checkAnswer( + sql("select str_bool from jsonTable where NOT str_bool"), + Row(false) + ) + + // Right now, the analyzer does not know that num_bool should be treated as a boolean. + // Number and Boolean conflict: resolve the type as boolean in this query. + checkAnswer( + sql("select num_bool from jsonTable where num_bool"), + Row(true) + ) + + checkAnswer( + sql("select str_bool from jsonTable where str_bool"), + Row(false) + ) + + // The plan of the following DSL is + // Project [(CAST(num_str#65:4, DoubleType) + 1.2) AS num#78] + // Filter (CAST(CAST(num_str#65:4, DoubleType), DecimalType) > 92233720368547758060) + // ExistingRdd [num_bool#61,num_num_1#62L,num_num_2#63,num_num_3#64,num_str#65,str_bool#66] + // We should directly cast num_str to DecimalType and also need to do the right type promotion + // in the Project. + checkAnswer( + jsonDF. + where('num_str >= BigDecimal("92233720368547758060")). + select(('num_str + 1.2).as("num")), + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue()) + ) + + // The following test will fail. The type of num_str is StringType. + // So, to evaluate num_str + 1.2, we first need to use Cast to convert the type. + // In our test data, one value of num_str is 13.1. + // The result of (CAST(num_str#65:4, DoubleType) + 1.2) for this value is 14.299999999999999, + // which is not 14.3. + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 13"), + Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil + ) + } + + test("Type conflict in complex field values") { + val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + + val expectedSchema = StructType( + StructField("array", ArrayType(LongType, true), true) :: + StructField("num_struct", StringType, true) :: + StructField("str_array", StringType, true) :: + StructField("struct", StructType( + StructField("field", StringType, true) :: Nil), true) :: + StructField("struct_array", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: + Row(null, """{"field":false}""", null, null, "{}") :: + Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: + Row(Seq(7), "{}", """["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil + ) + } + + test("Type conflict in array elements") { + val jsonDF = ctx.read.json(arrayElementTypeConflict) + + val expectedSchema = StructType( + StructField("array1", ArrayType(StringType, true), true) :: + StructField("array2", ArrayType(StructType( + StructField("field", LongType, true) :: Nil), true), true) :: + StructField("array3", ArrayType(StringType, true), true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", + """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: + Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: + Row(null, null, Seq("1", "2", "3")) :: Nil + ) + + // Treat an element as a number. + checkAnswer( + sql("select array1[0] + 1 from jsonTable where array1 is not null"), + Row(2) + ) + } + + test("Handling missing fields") { + val jsonDF = ctx.read.json(missingFields) + + val expectedSchema = StructType( + StructField("a", BooleanType, true) :: + StructField("b", LongType, true) :: + StructField("c", ArrayType(LongType, true), true) :: + StructField("d", StructType( + StructField("field", BooleanType, true) :: Nil), true) :: + StructField("e", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + } + + test("jsonFile should be based on JSONRelation") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalFile.toURI.toString + ctx.sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + + val analyzed = jsonDF.queryExecution.analyzed + assert( + analyzed.isInstanceOf[LogicalRelation], + "The DataFrame returned by jsonFile should be based on LogicalRelation.") + val relation = analyzed.asInstanceOf[LogicalRelation].relation + assert( + relation.isInstanceOf[JSONRelation], + "The DataFrame returned by jsonFile should be based on JSONRelation.") + assert(relation.asInstanceOf[JSONRelation].paths === Array(path)) + assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) + + val schema = StructType(StructField("a", LongType, true) :: Nil) + val logicalRelation = + ctx.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] + val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] + assert(relationWithSchema.paths === Array(path)) + assert(relationWithSchema.schema === schema) + assert(relationWithSchema.samplingRatio > 0.99) + } + + test("Loading a JSON dataset from a text file") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + val jsonDF = ctx.read.json(path) + + val expectedSchema = StructType( + StructField("bigInteger", DecimalType(20, 0), true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", LongType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + + test("Loading a JSON dataset from a text file with SQL") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTableSQL + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + + checkAnswer( + sql("select * from jsonTableSQL"), + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + + test("Applying schemas") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val schema = StructType( + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + val jsonDF1 = ctx.read.schema(schema).json(path) + + assert(schema === jsonDF1.schema) + + jsonDF1.registerTempTable("jsonTable1") + + checkAnswer( + sql("select * from jsonTable1"), + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") + ) + + val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + + assert(schema === jsonDF2.schema) + + jsonDF2.registerTempTable("jsonTable2") + + checkAnswer( + sql("select * from jsonTable2"), + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + + test("Applying schemas with MapType") { + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + + jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") + + checkAnswer( + sql("select map from jsonWithSimpleMap"), + Row(Map("a" -> 1)) :: + Row(Map("b" -> 2)) :: + Row(Map("c" -> 3)) :: + Row(Map("c" -> 1, "d" -> 4)) :: + Row(Map("e" -> null)) :: Nil + ) + + checkAnswer( + sql("select map['c'] from jsonWithSimpleMap"), + Row(null) :: + Row(null) :: + Row(3) :: + Row(1) :: + Row(null) :: Nil + ) + + val innerStruct = StructType( + StructField("field1", ArrayType(IntegerType, true), true) :: + StructField("field2", IntegerType, true) :: Nil) + val schemaWithComplexMap = StructType( + StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) + + val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + + jsonWithComplexMap.registerTempTable("jsonWithComplexMap") + + checkAnswer( + sql("select map from jsonWithComplexMap"), + Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: + Row(Map("b" -> Row(null, 2))) :: + Row(Map("c" -> Row(Seq(), 4))) :: + Row(Map("c" -> Row(null, 3), "d" -> Row(Seq(null), null))) :: + Row(Map("e" -> null)) :: + Row(Map("f" -> Row(null, null))) :: Nil + ) + + checkAnswer( + sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + Row(Seq(1, 2, 3, null), null) :: + Row(null, null) :: + Row(null, 4) :: + Row(null, 3) :: + Row(null, null) :: + Row(null, null) :: Nil + ) + } + + test("SPARK-2096 Correctly parse dot notations") { + val jsonDF = ctx.read.json(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + Row(true, "str1") + ) + checkAnswer( + sql( + """ + |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] + |from jsonTable + """.stripMargin), + Row("str2", 6) + ) + } + + test("SPARK-3390 Complex arrays") { + val jsonDF = ctx.read.json(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql( + """ + |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] + |from jsonTable + """.stripMargin), + Row(5, 7, 8) + ) + checkAnswer( + sql( + """ + |select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0], + |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 + |from jsonTable + """.stripMargin), + Row("str1", Nil, "str4", 2) + ) + } + + test("SPARK-3308 Read top level JSON arrays") { + val jsonDF = ctx.read.json(jsonArray) + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql( + """ + |select a, b, c + |from jsonTable + """.stripMargin), + Row("str_a_1", null, null) :: + Row("str_a_2", null, null) :: + Row(null, "str_b_3", null) :: + Row("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + } + + test("Corrupt records") { + // Test if we can query corrupt records. + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val jsonDF = ctx.read.json(corruptRecords) + jsonDF.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) + + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + } + + test("SPARK-4068: nulls in arrays") { + val jsonDF = ctx.read.json(nullsInArrays) + jsonDF.registerTempTable("jsonTable") + + val schema = StructType( + StructField("field1", + ArrayType(ArrayType(ArrayType(ArrayType(StringType, true), true), true), true), true) :: + StructField("field2", + ArrayType(ArrayType( + StructType(StructField("Test", LongType, true) :: Nil), true), true), true) :: + StructField("field3", + ArrayType(ArrayType( + StructType(StructField("Test", StringType, true) :: Nil), true), true), true) :: + StructField("field4", + ArrayType(ArrayType(ArrayType(LongType, true), true), true), true) :: Nil) + + assert(schema === jsonDF.schema) + + checkAnswer( + sql( + """ + |SELECT field1, field2, field3, field4 + |FROM jsonTable + """.stripMargin), + Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: + Row(null, Seq(null, Seq(Row(1))), null, null) :: + Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) :: + Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil + ) + } + + test("SPARK-4228 DataFrame to JSON") { + val schema1 = StructType( + StructField("f1", IntegerType, false) :: + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", ArrayType(StringType), nullable = true) :: + StructField("f5", IntegerType, true) :: Nil) + + val rowRDD1 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v5 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) + } + + val df1 = ctx.createDataFrame(rowRDD1, schema1) + df1.registerTempTable("applySchema1") + val df2 = df1.toDF + val result = df2.toJSON.collect() + // scalastyle:off + assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") + assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + // scalastyle:on + + val schema2 = StructType( + StructField("f1", StructType( + StructField("f11", IntegerType, false) :: + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) + + val rowRDD2 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) + } + + val df3 = ctx.createDataFrame(rowRDD2, schema2) + df3.registerTempTable("applySchema2") + val df4 = df3.toDF + val result2 = df4.toJSON.collect() + + assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") + assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") + + val jsonDF = ctx.read.json(primitiveFieldAndType) + val primTable = ctx.read.json(jsonDF.toJSON) + primTable.registerTempTable("primativeTable") + checkAnswer( + sql("select * from primativeTable"), + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + "this is a simple string.") + ) + + val complexJsonDF = ctx.read.json(complexFieldAndType1) + val compTable = ctx.read.json(complexJsonDF.toJSON) + compTable.registerTempTable("complexTable") + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), + Row("str1", "str2", null) + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from complexTable"), + Row(Seq(null, null, null, null)) + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + " from complexTable"), + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), + Row(Seq("1", "2", "3"), Seq("str1", "str2")) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), + Row("str2", 2.1) + ) + + // Access a struct and fields inside of it. + checkAnswer( + sql("select struct, struct.field1, struct.field2 from complexTable"), + Row( + Row(true, new java.math.BigDecimal("92233720368547758070")), + true, + new java.math.BigDecimal("92233720368547758070")) :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), + Row(Seq(4, 5, 6), Seq("str1", "str2")) + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + "from complexTable"), + Row(5, null) + ) + } + + test("JSONRelation equality test") { + val relation0 = new JSONRelation( + Some(empty), + 1.0, + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + None, None)(ctx) + val logicalRelation0 = LogicalRelation(relation0) + val relation1 = new JSONRelation( + Some(singleRow), + 1.0, + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + None, None)(ctx) + val logicalRelation1 = LogicalRelation(relation1) + val relation2 = new JSONRelation( + Some(singleRow), + 0.5, + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + None, None)(ctx) + val logicalRelation2 = LogicalRelation(relation2) + val relation3 = new JSONRelation( + Some(singleRow), + 1.0, + Some(StructType(StructField("b", IntegerType, true) :: Nil)), + None, None)(ctx) + val logicalRelation3 = LogicalRelation(relation3) + + assert(relation0 !== relation1) + assert(!logicalRelation0.sameResult(logicalRelation1), + s"$logicalRelation0 and $logicalRelation1 should be considered not having the same result.") + + assert(relation1 === relation2) + assert(logicalRelation1.sameResult(logicalRelation2), + s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.") + + assert(relation1 !== relation3) + assert(!logicalRelation1.sameResult(logicalRelation3), + s"$logicalRelation1 and $logicalRelation3 should be considered not having the same result.") + + assert(relation2 !== relation3) + assert(!logicalRelation2.sameResult(logicalRelation3), + s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.") + + withTempPath(dir => { + val path = dir.getCanonicalFile.toURI.toString + ctx.sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + + val d1 = ResolvedDataSource( + ctx, + userSpecifiedSchema = None, + partitionColumns = Array.empty[String], + provider = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)) + + val d2 = ResolvedDataSource( + ctx, + userSpecifiedSchema = None, + partitionColumns = Array.empty[String], + provider = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)) + assert(d1 === d2) + }) + } + + test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { + // This is really a test that it doesn't throw an exception + val emptySchema = InferSchema(empty, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } + + test("SPARK-7565 MapType in JsonRDD") { + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + try { + val temp = Utils.createTempDir().getPath + + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + df.write.mode("overwrite").parquet(temp) + // order of MapType is not defined + assert(ctx.read.parquet(temp).count() == 5) + + val df2 = ctx.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(temp) + checkAnswer(ctx.read.parquet(temp), df2.collect()) + } finally { + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + } + } + + test("SPARK-8093 Erase empty structs") { + val emptySchema = InferSchema(emptyRecords, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } + + test("JSON with Partition") { + def makePartition(rdd: RDD[String], parent: File, partName: String, partValue: Any): File = { + val p = new File(parent, s"$partName=${partValue.toString}") + rdd.saveAsTextFile(p.getCanonicalPath) + p + } + + withTempPath(root => { + val d1 = new File(root, "d1=1") + // root/dt=1/col1=abc + val p1_col1 = makePartition( + ctx.sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), + d1, + "col1", + "abc") + + // root/dt=1/col1=abd + val p2 = makePartition( + ctx.sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), + d1, + "col1", + "abd") + + ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala new file mode 100644 index 000000000000..2864181cf91d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext + +private[json] trait TestJsonData { + protected def _sqlContext: SQLContext + + def primitiveFieldAndType: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"string":"this is a simple string.", + "integer":10, + "long":21474836470, + "bigInteger":92233720368547758070, + "double":1.7976931348623157E308, + "boolean":true, + "null":null + }""" :: Nil) + + def primitiveFieldValueTypeConflict: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, + "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: + """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, + "num_bool":12, "num_str":null, "str_bool":true}""" :: + """{"num_num_1":21474836470, "num_num_2":92233720368547758070, "num_num_3": 100, + "num_bool":false, "num_str":"str1", "str_bool":false}""" :: + """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, + "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) + + def jsonNullStruct: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: + """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: + """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: + """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) + + def complexFieldValueTypeConflict: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"num_struct":11, "str_array":[1, 2, 3], + "array":[], "struct_array":[], "struct": {}}""" :: + """{"num_struct":{"field":false}, "str_array":null, + "array":null, "struct_array":{}, "struct": null}""" :: + """{"num_struct":null, "str_array":"str", + "array":[4, 5, 6], "struct_array":[7, 8, 9], "struct": {"field":null}}""" :: + """{"num_struct":{}, "str_array":["str1", "str2", 33], + "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) + + def arrayElementTypeConflict: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], + "array2": [{"field":214748364700}, {"field":1}]}""" :: + """{"array3": [{"field":"str"}, {"field":1}]}""" :: + """{"array3": [1, 2, 3]}""" :: Nil) + + def missingFields: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"a":true}""" :: + """{"b":21474836470}""" :: + """{"c":[33, 44]}""" :: + """{"d":{"field":true}}""" :: + """{"e":"str"}""" :: Nil) + + def complexFieldAndType1: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"struct":{"field1": true, "field2": 92233720368547758070}, + "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, + "arrayOfString":["str1", "str2"], + "arrayOfInteger":[1, 2147483647, -2147483648], + "arrayOfLong":[21474836470, 9223372036854775807, -9223372036854775808], + "arrayOfBigInteger":[922337203685477580700, -922337203685477580800], + "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], + "arrayOfBoolean":[true, false, true], + "arrayOfNull":[null, null, null, null], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], + "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], + "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] + }""" :: Nil) + + def complexFieldAndType2: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], + "complexArrayOfStruct": [ + { + "field1": [ + { + "inner1": "str1" + }, + { + "inner2": ["str2", "str22"] + }], + "field2": [[1, 2], [3, 4]] + }, + { + "field1": [ + { + "inner2": ["str3", "str33"] + }, + { + "inner1": "str4" + }], + "field2": [[5, 6], [7, 8]] + }], + "arrayOfArray1": [ + [ + [5] + ], + [ + [6, 7], + [8] + ]], + "arrayOfArray2": [ + [ + [ + { + "inner1": "str1" + } + ] + ], + [ + [], + [ + {"inner2": ["str3", "str33"]}, + {"inner2": ["str4"], "inner1": "str11"} + ] + ], + [ + [ + {"inner3": [[{"inner4": 2}]]} + ] + ]] + }""" :: Nil) + + def mapType1: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"map": {"a": 1}}""" :: + """{"map": {"b": 2}}""" :: + """{"map": {"c": 3}}""" :: + """{"map": {"c": 1, "d": 4}}""" :: + """{"map": {"e": null}}""" :: Nil) + + def mapType2: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: + """{"map": {"b": {"field2": 2}}}""" :: + """{"map": {"c": {"field1": [], "field2": 4}}}""" :: + """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" :: + """{"map": {"e": null}}""" :: + """{"map": {"f": {"field1": null}}}""" :: Nil) + + def nullsInArrays: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{"field1":[[null], [[["Test"]]]]}""" :: + """{"field2":[null, [{"Test":1}]]}""" :: + """{"field3":[[null], [{"Test":"2"}]]}""" :: + """{"field4":[[null, [1,2,3]]]}""" :: Nil) + + def jsonArray: RDD[String] = + _sqlContext.sparkContext.parallelize( + """[{"a":"str_a_1"}]""" :: + """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """[]""" :: Nil) + + def corruptRecords: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a":1, b:2}""" :: + """{"a":{, b:3}""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """]""" :: Nil) + + def emptyRecords: RDD[String] = + _sqlContext.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a": {}}""" :: + """{"a": {"b": {}}}""" :: + """{"b": [{"c": {}}]}""" :: + """]""" :: Nil) + + + lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala new file mode 100644 index 000000000000..bd7cf8c10abe --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord +import org.apache.hadoop.fs.Path +import org.apache.parquet.avro.AvroParquetWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.parquet.test.avro._ +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit): Unit = { + logInfo( + s"""Writing Avro records with the following Avro schema into Parquet file: + | + |${schema.toString(true)} + """.stripMargin) + + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() + } + + test("required primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroPrimitives](path, AvroPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write( + AvroPrimitives.newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setStringColumn(s"val_$i") + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + }) + } + } + + test("optional primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroOptionalPrimitives](path, AvroOptionalPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = if (i % 3 == 0) { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(null) + .setMaybeIntColumn(null) + .setMaybeLongColumn(null) + .setMaybeFloatColumn(null) + .setMaybeDoubleColumn(null) + .setMaybeBinaryColumn(null) + .setMaybeStringColumn(null) + .build() + } else { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeIntColumn(i) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeFloatColumn(i.toFloat + 0.1f) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setMaybeStringColumn(s"val_$i") + .build() + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + if (i % 3 == 0) { + Row.apply(Seq.fill(7)(null): _*) + } else { + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + } + }) + } + } + + test("non-nullable arrays") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroNonNullableArrays](path, AvroNonNullableArrays.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = { + val builder = + AvroNonNullableArrays.newBuilder() + .setStringsColumn(Seq.tabulate(3)(i => s"val_$i").asJava) + + if (i % 3 == 0) { + builder.setMaybeIntsColumn(null).build() + } else { + builder.setMaybeIntsColumn(Seq.tabulate(3)(Int.box).asJava).build() + } + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(i => s"val_$i"), + if (i % 3 == 0) null else Seq.tabulate(3)(identity)) + }) + } + } + + ignore("nullable arrays (parquet-avro 1.7.0 does not properly support this)") { + // TODO Complete this test case after upgrading to parquet-mr 1.8+ + } + + test("SPARK-10136 array of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroArrayOfArray](path, AvroArrayOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroArrayOfArray.newBuilder() + .setIntArraysColumn( + Seq.tabulate(3, 3)((i, j) => i * 3 + j: Integer).map(_.asJava).asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) + }) + } + } + + test("map of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroMapOfArray](path, AvroMapOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroMapOfArray.newBuilder() + .setStringToIntsColumn( + Seq.tabulate(3) { i => + i.toString -> Seq.tabulate(3)(j => i + j: Integer).asJava + }.toMap.asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) + }) + } + } + + test("various complex types") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + } + + def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { + def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Nested + .newBuilder() + .setNestedIntsColumn(Seq.tabulate(3)(j => i + j + m: Integer).asJava) + .setNestedStringColumn(s"val_${i + m}") + .build() + }.asJava + }.toMap.asJava + } + + ParquetAvroCompat + .newBuilder() + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}").asJava) + .setStringToIntColumn(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap.asJava) + .setComplexColumn(makeComplexColumn(i)) + .build() + } + + test("SPARK-9407 Push down predicates involving Parquet ENUM columns") { + import testImplicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala new file mode 100644 index 000000000000..df68432faeeb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.MessageType + +import org.apache.spark.sql.QueryTest + +/** + * Helper class for testing Parquet compatibility. + */ +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { + protected def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) + } + + protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { + val fsPath = new Path(path) + val fs = fsPath.getFileSystem(configuration) + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq + + val footers = + ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true) + footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema + } + + protected def logParquetSchema(path: String): Unit = { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} + """.stripMargin) + } +} + +object ParquetCompatibilityTest { + def makeNullable[T <: AnyRef](i: Int)(f: => T): T = { + if (i % 3 == 0) null.asInstanceOf[T] else f + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala new file mode 100644 index 000000000000..f067112cfca9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.parquet.filter2.predicate.Operators._ +import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} + +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite that tests Parquet filter2 API based filter pushdown optimization. + * + * NOTE: + * + * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the + * `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)` + * results in a `GtEq` filter predicate rather than a `Not`. + * + * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred + * data type is nullable. + */ +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + + private def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + filterClass: Class[_ <: FilterPredicate], + checker: (DataFrame, Seq[Row]) => Unit, + expected: Seq[Row]): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + val analyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters + }.flatten + assert(analyzedPredicate.nonEmpty) + + val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate) + assert(selectedFilters.nonEmpty) + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) + assert(f.getClass === filterClass) + } + } + checker(query, expected) + } + } + + private def checkFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit df: DataFrame): Unit = { + checkFilterPredicate(df, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) + } + + private def checkFilterPredicate[T] + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) + (implicit df: DataFrame): Unit = { + checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) + } + + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit df: DataFrame): Unit = { + def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { + df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + } + } + + checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected) + } + + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit df: DataFrame): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) + } + + test("filter pushdown - boolean") { + withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) + } + } + + test("filter pushdown - integer") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - long") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - float") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - double") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - string") { + withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + } + } + + test("filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes("UTF-8") + } + + withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) + + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + } + } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.read.parquet(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala new file mode 100644 index 000000000000..08d2b9dee99b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -0,0 +1,506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.util.Collections + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.metadata.{BlockMetaData, CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport +// with an empty configuration (it is after all not intended to be used in this way?) +// and members are private so we need to make our own in order to pass the schema +// to the writer. +private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { + var groupWriter: GroupWriter = null + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + groupWriter = new GroupWriter(recordConsumer, schema) + } + + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, new java.util.HashMap[String, String]()) + } + + override def write(record: Group) { + groupWriter.write(record) + } +} + +/** + * A test suite that tests basic Parquet I/O. + */ +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ + + /** + * Writes `data` to a Parquet file, reads it back and check file contents. + */ + protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { + withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) + } + + test("basic data types (without binary)") { + val data = (1 to 4).map { i => + (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + } + checkParquetFile(data) + } + + test("raw binary") { + val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) + withParquetDataFrame(data) { df => + assertResult(data.map(_._1.mkString(",")).sorted) { + df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + } + } + } + + test("string") { + val data = (1 to 4).map(i => Tuple1(i.toString)) + // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL + // as we store Spark SQL schema in the extra metadata. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) + } + + test("fixed-length decimals") { + def makeDecimalRDD(decimal: DecimalType): DataFrame = + sqlContext.sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(i / 100.0)) + .toDF() + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") + + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + } + } + } + + test("date type") { + def makeDateRDD(): DataFrame = + sqlContext.sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) + .toDF() + .select($"_1") + + withTempPath { dir => + val data = makeDateRDD() + data.write.parquet(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + } + } + + test("map") { + val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) + checkParquetFile(data) + } + + test("array") { + val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) + checkParquetFile(data) + } + + test("array and double") { + val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) + checkParquetFile(data) + } + + test("struct") { + val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) + withParquetDataFrame(data) { df => + // Structs are converted to `Row`s + checkAnswer(df, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) + } + } + + test("nested struct with array of array as field") { + val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) + withParquetDataFrame(data) { df => + // Structs are converted to `Row`s + checkAnswer(df, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) + } + } + + test("nested map with struct as value type") { + val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) + withParquetDataFrame(data) { df => + checkAnswer(df, data.map { case Tuple1(m) => + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + }) + } + } + + test("nulls") { + val allNulls = ( + null.asInstanceOf[java.lang.Boolean], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[java.lang.Float], + null.asInstanceOf[java.lang.Double]) + + withParquetDataFrame(allNulls :: Nil) { df => + val rows = df.collect() + assert(rows.length === 1) + assert(rows.head === Row(Seq.fill(5)(null): _*)) + } + } + + test("nones") { + val allNones = ( + None.asInstanceOf[Option[Int]], + None.asInstanceOf[Option[Long]], + None.asInstanceOf[Option[String]]) + + withParquetDataFrame(allNones :: Nil) { df => + val rows = df.collect() + assert(rows.length === 1) + assert(rows.head === Row(Seq.fill(3)(null): _*)) + } + } + + test("compression codec") { + def compressionCodecFor(path: String): String = { + val codecs = ParquetTypesConverter + .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala + .flatMap(_.getColumns.asScala) + .map(_.getCodec.name()) + .distinct + + assert(codecs.size === 1) + codecs.head + } + + val data = (0 until 10).map(i => (i, i.toString)) + + def checkCompressionCodec(codec: CompressionCodecName): Unit = { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { + withParquetFile(data) { path => + assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { + compressionCodecFor(path) + } + } + } + } + + // Checks default compression codec + checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) + + checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) + checkCompressionCodec(CompressionCodecName.GZIP) + checkCompressionCodec(CompressionCodecName.SNAPPY) + } + + test("read raw Parquet file") { + def makeRawParquetFile(path: Path): Unit = { + val schema = MessageTypeParser.parseMessageType( + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + |} + """.stripMargin) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + (0 until 10).foreach { i => + val record = new SimpleGroup(schema) + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + writer.write(record) + } + + writer.close() + } + + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + makeRawParquetFile(path) + checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) + } + } + + test("write metadata") { + withTempPath { file => + val path = new Path(file.toURI.toString) + val fs = FileSystem.getLocal(configuration) + val attributes = ScalaReflection.attributesFor[(Int, String)] + ParquetTypesConverter.writeMetaData(attributes, path, configuration) + + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + + val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val actualSchema = metaData.getFileMetaData.getSchema + val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + + actualSchema.checkContains(expectedSchema) + expectedSchema.checkContains(actualSchema) + } + } + + test("save - overwrite") { + withParquetFile((1 to 10).map(i => (i, i.toString))) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) + checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) + } + } + + test("save - ignore") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) + checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) + } + } + + test("save - throw") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + val errorMessage = intercept[Throwable] { + newData.toDF().write.format("parquet").mode(SaveMode.ErrorIfExists).save(file) + }.getMessage + assert(errorMessage.contains("already exists")) + } + } + + test("save - append") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) + checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) + } + } + + test("SPARK-6315 regression test") { + // Spark 1.1 and prior versions write Spark schema as case class string into Parquet metadata. + // This has been deprecated by JSON format since 1.2. Notice that, 1.3 further refactored data + // types API, and made StructType.fields an array. This makes the result of StructType.toString + // different from prior versions: there's no "Seq" wrapping the fields part in the string now. + val sparkSchema = + "StructType(Seq(StructField(a,BooleanType,false),StructField(b,IntegerType,false)))" + + // The Parquet schema is intentionally made different from the Spark schema. Because the new + // Parquet data source simply falls back to the Parquet schema once it fails to parse the Spark + // schema. By making these two different, we are able to assert the old style case class string + // is parsed successfully. + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c; + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Collections.singletonMap( + CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new Path(location.getCanonicalPath) + + ParquetFileWriter.writeMetadataFile( + sqlContext.sparkContext.hadoopConfiguration, + path, + Collections.singletonList( + new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) + + assertResult(sqlContext.read.parquet(path.toString).schema) { + StructType( + StructField("a", BooleanType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: + Nil) + } + } + } + + test("SPARK-6352 DirectParquetOutputCommitter") { + val clonedConf = new Configuration(configuration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + classOf[DirectParquetOutputCommitter].getCanonicalName) + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(configuration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { + withTempPath { dir => + val clonedConf = new Configuration(configuration) + + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) + + configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) + + try { + val message = intercept[SparkException] { + sqlContext.range(0, 1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(message === "Intentional exception for testing purposes") + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + } + + test("SPARK-6330 regression test") { + // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: + // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// + intercept[Throwable] { + sqlContext.read.parquet("file:///nonexistent") + } + val errorMessage = intercept[Throwable] { + sqlContext.read.parquet("hdfs://nonexistent") + }.toString + assert(errorMessage.contains("UnknownHostException")) + } + + test("SPARK-7837 Do not close output writer twice when commitTask() fails") { + val clonedConf = new Configuration(configuration) + + // Using a output committer that always fail when committing a task, so that both + // `commitTask()` and `abortTask()` are invoked. + configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) + + try { + // Before fixing SPARK-7837, the following code results in an NPE because both + // `commitTask()` and `abortTask()` try to close output writers. + + withTempPath { dir => + val m1 = intercept[SparkException] { + sqlContext.range(1).coalesce(1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m1.contains("Intentional exception for testing purposes")) + } + + withTempPath { dir => + val m2 = intercept[SparkException] { + val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + df.write.partitionBy("a").parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m2.contains("Intentional exception for testing purposes")) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } +} + +class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} + +class TaskCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitTask(context: TaskAttemptContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala new file mode 100644 index 000000000000..ed8bafb10c60 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -0,0 +1,600 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File +import java.math.BigInteger +import java.sql.Timestamp + +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { + import PartitioningUtils._ + import testImplicits._ + + val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" + + test("column type inference") { + def check(raw: String, literal: Literal): Unit = { + assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) + } + + check("10", Literal.create(10, IntegerType)) + check("1000000000000000", Literal.create(1000000000000000L, LongType)) + check("1.5", Literal.create(1.5, DoubleType)) + check("hello", Literal.create("hello", StringType)) + check(defaultPartitionName, Literal.create(null, NullType)) + } + + test("parse partition") { + def check(path: String, expected: Option[PartitionValues]): Unit = { + assert(expected === parsePartition(new Path(path), defaultPartitionName, true)) + } + + def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { + val message = intercept[T] { + parsePartition(new Path(path), defaultPartitionName, true).get + }.getMessage + + assert(message.contains(expected)) + } + + check("file://path/a=10", Some { + PartitionValues( + ArrayBuffer("a"), + ArrayBuffer(Literal.create(10, IntegerType))) + }) + + check("file://path/a=10/b=hello/c=1.5", Some { + PartitionValues( + ArrayBuffer("a", "b", "c"), + ArrayBuffer( + Literal.create(10, IntegerType), + Literal.create("hello", StringType), + Literal.create(1.5, DoubleType))) + }) + + check("file://path/a=10/b_hello/c=1.5", Some { + PartitionValues( + ArrayBuffer("c"), + ArrayBuffer(Literal.create(1.5, DoubleType))) + }) + + check("file:///", None) + check("file:///path/_temporary", None) + check("file:///path/_temporary/c=1.5", None) + check("file:///path/_temporary/path", None) + check("file://path/a=10/_temporary/c=1.5", None) + check("file://path/a=10/c=1.5/_temporary", None) + + checkThrows[AssertionError]("file://path/=10", "Empty partition column name") + checkThrows[AssertionError]("file://path/a=", "Empty partition column value") + } + + test("parse partitions") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq(Partition(InternalRow(10, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", DoubleType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(10, UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(10.5, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", DoubleType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(10, UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(10.5, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(10, UTF8String.fromString("20")), + s"hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(null, UTF8String.fromString("hello")), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", DoubleType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(InternalRow(10.5, null), + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) + } + + test("parse partitions with type inference disabled") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq(Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + s"hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(null, UTF8String.fromString("hello")), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), null), + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(InternalRow(UTF8String.fromString("10.5"), null), + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + val dir = makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps) + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + dir) + // Introduce _temporary dir to test the robustness of the schema discovery process. + new File(dir.toString, "_temporary").mkdir() + } + // Introduce _temporary dir to the base dir the robustness of the schema discovery process. + new File(base.getCanonicalPath, "_temporary").mkdir() + + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } + + test("read partitioned table - merging compatible schemas") { + withTempDir { base => + makeParquetFile( + (1 to 10).map(i => Tuple1(i)).toDF("intField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 1)) + + makeParquetFile( + (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 2)) + + sqlContext + .read + .option("mergeSchema", "true") + .format("parquet") + .load(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) + } + } + } + + test("SPARK-7749 Non-partitioned table should have empty partition spec") { + withTempPath { dir => + (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) + val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution + queryExecution.analyzed.collectFirst { + case LogicalRelation(relation: ParquetRelation) => + assert(relation.partitionSpec === PartitionSpec.emptySpec) + }.getOrElse { + fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") + } + } + } + + test("SPARK-7847: Dynamic partition directory path escaping and unescaping") { + withTempPath { dir => + val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") + df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + } + } + + test("Various partition value types") { + val row = + Row( + 100.toByte, + 40000.toShort, + Int.MaxValue, + Long.MaxValue, + 1.5.toFloat, + 4.5, + new java.math.BigDecimal(new BigInteger("212500"), 5), + new java.math.BigDecimal(2.125), + java.sql.Date.valueOf("2015-05-23"), + new Timestamp(0), + "This is a string, /[]?=:", + "This is not a partition column") + + // BooleanType is not supported yet + val partitionColumnTypes = + Seq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(10, 5), + DecimalType.SYSTEM_DEFAULT, + DateType, + TimestampType, + StringType) + + val partitionColumns = partitionColumnTypes.zipWithIndex.map { + case (t, index) => StructField(s"p_$index", t) + } + + val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + } + } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + } + } + + test("listConflictingPartitionColumns") { + def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { + val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => + s"\tPartition column name list #$index: $list" + }.mkString("\n", "\n", "\n") + + // scalastyle:off + s"""Conflicting partition column names detected: + |$conflictingColNameLists + |For partitioned table directories, data files should only live in leaf directories. + |And directories at the same level should have the same partition column name. + |Please check the following directories for unexpected files or inconsistent partition column names: + |${paths.map("\t" + _).mkString("\n", "\n", "")} + """.stripMargin.trim + // scalastyle:on + } + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim === + makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim === + makeExpectedMessage( + Seq("a"), + Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), + PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1/b=foo"), + PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim === + makeExpectedMessage( + Seq("a", "a, b"), + Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) + } + + test("Parallel partition discovery") { + withTempPath { dir => + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + val path = dir.getCanonicalPath + val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + df.write.partitionBy("b", "c").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala new file mode 100644 index 000000000000..b290429c2a02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + + private def readParquetProtobufFile(name: String): DataFrame = { + val url = Thread.currentThread().getContextClassLoader.getResource(name) + sqlContext.read.parquet(url.toString) + } + + test("unannotated array of primitive type") { + checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + } + + test("unannotated array of struct") { + checkAnswer( + readParquetProtobufFile("old-repeated-message.parquet"), + Row( + Seq( + Row("First inner", null, null), + Row(null, "Second inner", null), + Row(null, null, "Third inner")))) + + checkAnswer( + readParquetProtobufFile("proto-repeated-struct.parquet"), + Row( + Seq( + Row("0 - 1", "0 - 2", "0 - 3"), + Row("1 - 1", "1 - 2", "1 - 3")))) + + checkAnswer( + readParquetProtobufFile("proto-struct-with-array-many.parquet"), + Seq( + Row( + Seq( + Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"), + Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))), + Row( + Seq( + Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"), + Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))), + Row( + Seq( + Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"), + Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3"))))) + } + + test("struct with unannotated array") { + checkAnswer( + readParquetProtobufFile("proto-struct-with-array.parquet"), + Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) + } + + test("unannotated array of struct with unannotated array") { + checkAnswer( + readParquetProtobufFile("nested-array-struct.parquet"), + Seq( + Row(2, Seq(Row(1, Seq(Row(3))))), + Row(5, Seq(Row(4, Seq(Row(6))))), + Row(8, Seq(Row(7, Seq(Row(9))))))) + } + + test("unannotated array of string") { + checkAnswer( + readParquetProtobufFile("proto-repeated-string.parquet"), + Seq( + Row(Seq("hello", "world")), + Row(Seq("good", "bye")), + Row(Seq("one", "two", "three")))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala new file mode 100644 index 000000000000..b7b70c2bbbd5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * A test suite that tests various Parquet queries. + */ +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { + + test("simple select queries") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) + } + ctx.catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) + } + ctx.catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } + + test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = ctx.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. + Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + + test("Enabling/disabling schema merging") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + testSchemaMerging(3) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + testSchemaMerging(2) + } + } + + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + + // Disables the global SQL option for schema merging + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + assertResult(2) { + // Disables schema merging via data source option + ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length + } + + assertResult(3) { + // Enables schema merging via data source option + ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length + } + } + } + } + + test("SPARK-9119 Decimal should be correctly written into parquet") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) + val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val df = sqlContext.createDataFrame(rowRDD, schema) + df.write.parquet(basePath) + + val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + assert(Decimal("67123.45") === Decimal(decimal)) + } + } + + test("SPARK-10005 Schema merging for nested struct") { + val sqlContext = _sqlContext + import sqlContext.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + def append(df: DataFrame): Unit = { + df.write.mode(SaveMode.Append).parquet(path) + } + + // Note that both the following two DataFrames contain a single struct column with multiple + // nested fields. + append((1 to 2).map(i => Tuple1((i, i))).toDF()) + append((1 to 2).map(i => Tuple1((i, i, i))).toDF()) + + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer( + sqlContext.read.option("mergeSchema", "true").parquet(path), + Seq( + Row(Row(1, 1, null)), + Row(Row(2, 2, null)), + Row(Row(1, 1, 1)), + Row(Row(2, 2, 2)))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala new file mode 100644 index 000000000000..9dcbc1a047be --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -0,0 +1,944 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.parquet.schema.MessageTypeParser + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { + + /** + * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. + */ + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + actual.checkContains(expected) + expected.checkContains(actual) + } + } + + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( + "basic types", + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + | optional binary _6; + |} + """.stripMargin, + binaryAsString = false) + + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( + "logical integral types", + """ + |message root { + | required int32 _1 (INT_8); + | required int32 _2 (INT_16); + | required int32 _3 (INT_32); + | required int64 _4 (INT_64); + | optional int32 _5 (DATE); + |} + """.stripMargin) + + testSchemaInference[Tuple1[String]]( + "string", + """ + |message root { + | optional binary _1 (UTF8); + |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group bag { + | optional int32 array_element; + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Pair[Int, String]]]( + "struct", + """ + |message root { + | optional group _1 { + | required int32 _1; + | optional binary _2 (UTF8); + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", + """ + |message root { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 key; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group bag { + | optional group array_element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", + """ + |message root { + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. + ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} + +class ParquetSchemaSuite extends ParquetSchemaTest { + test("DataType string parser compatibility") { + // This is the generated string from previous versions of the Spark SQL, using the following: + // val schema = StructType(List( + // StructField("c1", IntegerType, false), + // StructField("c2", BinaryType, true))) + val caseClassString = + "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + + // scalastyle:off + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" + // scalastyle:on + + val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) + val fromJson = ParquetTypesConverter.convertFromString(jsonString) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + assert(a.nullable === b.nullable) + } + } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 7 - " + + "parquet-protobuf primitive lists", + new StructType() + .add("f1", ArrayType(IntegerType, containsNull = false), nullable = false), + """message root { + | repeated int32 f1; + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 8 - " + + "parquet-protobuf non-primitive lists", + { + val elementType = + new StructType() + .add("c1", StringType, nullable = true) + .add("c2", IntegerType, nullable = false) + + new StructType() + .add("f1", ArrayType(elementType, containsNull = false), nullable = false) + }, + """message root { + | repeated group f1 { + | optional binary c1 (UTF8); + | required int32 c2; + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 array_element; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala new file mode 100644 index 000000000000..5dbc7d1630f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} + +/** + * A helper trait that provides convenient facilities for Parquet testing. + * + * NOTE: Considering classes `Tuple1` ... `Tuple22` all extend `Product`, it would be more + * convenient to use tuples rather than special case classes when writing test cases/suites. + * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. + */ +private[sql] trait ParquetTest extends SQLTestUtils { + protected def _sqlContext: SQLContext + + /** + * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withParquetFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. + */ + protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) + } + + /** + * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Parquet file will be dropped/deleted after `f` returns. + */ + protected def withParquetTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withParquetDataFrame(data) { df => + _sqlContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala new file mode 100644 index 000000000000..b789c5a106e5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + import ParquetCompatibilityTest._ + + private val parquetFilePath = + Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") + + test("Read Parquet file generated by parquet-thrift") { + logInfo( + s"""Schema of the Parquet file written by parquet-thrift: + |${readParquetSchema(parquetFilePath.toString)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") + + Row( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always + // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume + // Thrift `STRING`s are encoded using UTF-8. + s"val_$i", + s"val_$i", + // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings + suits(i % 4), + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i.toByte: java.lang.Byte), + nullable((i + 1).toShort: java.lang.Short), + nullable(i + 2: Integer), + nullable((i * 10).toLong: java.lang.Long), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i"), + nullable(s"val_$i"), + nullable(suits(i % 4)), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + // Thrift `SET`s are converted to Parquet `LIST`s + Seq(i), + Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, + Seq.tabulate(3) { n => + (i + n) -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4e9472c60249..22189477d277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -17,17 +17,12 @@ package org.apache.spark.sql.execution.debug -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { -class DebuggingSuite extends FunSuite { test("DataFrame.debug()") { testData.debug() } - - test("DataFrame.typeCheck()") { - testData.typeCheck() - } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 000000000000..53a0e53fd771 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,87 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.joins + +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators with unsafe enabled. + * + * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode. + */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + private var sc: SparkContext = null + private var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + sc.stop() + sc = null + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator. + */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2aad01ded1ac..4c9187a9a710 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,47 +17,121 @@ package org.apache.spark.sql.execution.joins -import org.scalatest.FunSuite +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} -import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends FunSuite { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { - override def apply(row: Row): Row = row + override def apply(row: InternalRow): InternalRow = row } test("GeneralHashedRelation") { - val data = Array(Row(0), Row(1), Row(2), Row(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) - assert(hashed.get(Row(10)) === null) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(InternalRow(10)) === null) - val data2 = CompactBuffer[Row](data(2)) + val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { - val data = Array(Row(0), Row(1), Row(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) - assert(hashed.get(Row(10)) === null) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) + assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(Row(10)) == null) + assert(uniqHashed.getValue(data(0)) === data(0)) + assert(uniqHashed.getValue(data(1)) === data(1)) + assert(uniqHashed.getValue(data(2)) === data(2)) + assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) + } + + test("UnsafeHashedRelation") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val toUnsafe = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafe(_).copy()).toArray + + val buildKey = Seq(BoundReference(0, IntegerType, false)) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) + assert(hashed.isInstanceOf[UnsafeHashedRelation]) + + assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(toUnsafe(InternalRow(10))) === null) + + val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) + data2 += unsafeData(2).copy() + assert(hashed.get(unsafeData(2)) === data2) + + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(toUnsafe(InternalRow(10))) === null) + assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2) + out2.flush() + // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same + // as they are inserted + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) + } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 000000000000..cc649b9bd4c4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val myUpperCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + private lazy val myLowerCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + private lazy val myTestData = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testInnerJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } + + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala new file mode 100644 index 000000000000..a1a617d7b739 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} + +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testOuterJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + joinType: JoinType, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + } + } + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala new file mode 100644 index 000000000000..baa86e320d98 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testLeftSemiJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala new file mode 100644 index 000000000000..80006bf077fe --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -0,0 +1,577 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") + val f = () => { + l += 1L + l.add(1L) + } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. + val l = ctx.sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") + } + } + + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = ctx.listener.executionIdToData.keySet + df.collect() + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = ctx.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = ctx.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + assert(expectedMetrics === actualMetrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) + } + } + + test("TungstenProject metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) + } + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("Aggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortBasedAggregate metrics") { + // Because SortBasedAggregate may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> + // SortBasedAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) + // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 3L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("TungstenAggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + } + + test("BroadcastHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("ShuffledHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> ShuffledHashOuterJoin(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(df2, $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + + val df4 = df1.join(df2, $"key" === $"key2", "outer") + testSparkPlanMetrics(df4, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 7L))) + ) + } + } + + test("BroadcastHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + } + + test("BroadcastNestedLoopJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("CartesianProduct metrics") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 12L, // right is read 6 times + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = ctx.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = ctx.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = ctx.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq(2L)) + } + } + +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. + return new MethodVisitor(ASM4) {} + } + + new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): Option[ClassReader] = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + // ASM4 doesn't support Java 8 classes, which requires ASM5. + // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), + // then ClassReader will throw IllegalArgumentException, + // However, since this is only for testing, it's safe to skip these classes. + try { + Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) + } catch { + case _: IllegalArgumentException => None + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala new file mode 100644 index 000000000000..80d1e8895694 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.Properties + +import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.scheduler._ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.test.SharedSQLContext + +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + private def createTestDataFrame: DataFrame = { + Seq( + (1, 1), + (2, 2) + ).toDF().filter("_1 > 1") + } + + private def createProperties(executionId: Long): Properties = { + val properties = new Properties() + properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString) + properties + } + + private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = new StageInfo( + stageId = stageId, + attemptId = attemptId, + // The following fields are not used in tests + name = "", + numTasks = 0, + rddInfos = Nil, + parentIds = Nil, + details = "" + ) + + private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + taskId = taskId, + attempt = attempt, + // The following fields are not used in tests + index = 0, + launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false + ) + + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { + val metrics = new TaskMetrics + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) + metrics.updateAccumulators() + metrics + } + + test("basic") { + val listener = new SQLListener(ctx) + val executionId = 0 + val df = createTestDataFrame + val accumulatorIds = + SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + // Assume all accumulators are long + var accumulatorValue = 0L + val accumulatorUpdates = accumulatorIds.map { id => + accumulatorValue += 1L + (id, accumulatorValue) + }.toMap + + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + + val executionUIData = listener.executionIdToData(0) + + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq( + createStageInfo(0, 0), + createStageInfo(1, 0) + ), + createProperties(executionId))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + + assert(listener.getExecutionMetrics(0).isEmpty) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3)) + + // Retrying a stage should reset the metrics + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + // Ignore the task end for the first attempt + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5)) + + // Summit a new stage + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + + assert(executionUIData.runningJobs === Seq(0)) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs.isEmpty) + + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + } + + test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { + val listener = new SQLListener(ctx) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { + val listener = new SQLListener(ctx) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + listener.onJobStart(SparkListenerJobStart( + jobId = 1, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 1, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.sorted === Seq(0, 1)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before onJobEnd(JobFailed)") { + val listener = new SQLListener(ctx) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq.empty, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobFailed(new RuntimeException("Oops")) + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs === Seq(0)) + } + + ignore("no memory leak") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new SparkContext(conf) + try { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. + for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } + } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala deleted file mode 100644 index f332cb389f33..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import scala.collection.mutable.MutableList - -import com.spotify.docker.client._ - -/** - * A factory and morgue for DockerClient objects. In the DockerClient we use, - * calling close() closes the desired DockerClient but also renders all other - * DockerClients inoperable. This is inconvenient if we have more than one - * open, such as during tests. - */ -object DockerClientFactory { - var numClients: Int = 0 - val zombies = new MutableList[DockerClient]() - - def get(): DockerClient = { - this.synchronized { - numClients = numClients + 1 - DefaultDockerClient.fromEnv.build() - } - } - - def close(dc: DockerClient) { - this.synchronized { - numClients = numClients - 1 - zombies += dc - if (numClients == 0) { - zombies.foreach(_.close()) - zombies.clear() - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d25c1390db15..0edac0848c3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,40 +18,73 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import org.apache.spark.sql.test._ -import org.scalatest.{FunSuite, BeforeAndAfter} import java.sql.DriverManager -import TestSQLContext._ +import java.util.{Calendar, GregorianCalendar, Properties} + +import org.h2.jdbc.JdbcSQLException +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ -class JDBCSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + val testH2Dialect = new JdbcDialect { + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + Some(StringType) + } + before { - Class.forName("org.h2.Driver") - conn = DriverManager.getConnection(url) + Utils.classForName("org.h2.Driver") + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() - conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() - conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate() + conn.prepareStatement( + "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() sql( s""" |CREATE TEMPORARY TABLE foobar |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + sql( + s""" + |CREATE TEMPORARY TABLE fetchtwo + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | fetchSize '2') """.stripMargin.replaceAll("\n", " ")) sql( s""" |CREATE TEMPORARY TABLE parts |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', - |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " @@ -65,12 +98,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE inttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.INTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), " + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate() - var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") + val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") stmt.setBytes(1, testBytes) stmt.setString(2, "Sensitive") stmt.setString(3, "Insensitive") @@ -82,23 +115,25 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE strtypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.STRTYPES') + |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" ).executeUpdate() conn.prepareStatement("insert into test.timetypes values ('12:34:56', " + "'1996-01-01', '2002-02-20 11:22:33.543543543')").executeUpdate() + conn.prepareStatement("insert into test.timetypes values ('12:34:56', " + + "null, '2002-02-20 11:22:33.543543543')").executeUpdate() conn.commit() sql( s""" |CREATE TEMPORARY TABLE timetypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES') + |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))" + conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() conn.prepareStatement("insert into test.flttypes values (" + "1.0000000000000002220446049250313080847263336181640625, " @@ -109,7 +144,24 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE flttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + conn.prepareStatement( + s""" + |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), + |f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP, + |m DOUBLE, n REAL, o DECIMAL(38, 18)) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement("insert into test.nulltypes values (" + + "null, null, null, null, null, null, null, null, null, " + + "null, null, null, null, null, null)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE nulltypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. @@ -120,29 +172,52 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT *") { - assert(sql("SELECT * FROM foobar").collect().size == 3) + assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + } + + test("SELECT * WHERE (quoted strings)") { + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) - assert(names.size == 3) + assert(names.size === 3) + assert(names(0).equals("fred")) + assert(names(1).equals("joe 'foo' \"bar\"")) + assert(names(2).equals("mary")) + } + + test("SELECT first field when fetchSize is two") { + val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) + assert(names.size === 3) assert(names(0).equals("fred")) - assert(names(1).equals("joe")) + assert(names(1).equals("joe 'foo' \"bar\"")) assert(names(2).equals("mary")) } test("SELECT second field") { val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) + } + + test("SELECT second field when fetchSize is two") { + val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("SELECT * partitioned") { @@ -150,46 +225,72 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT WHERE (simple predicates) partitioned") { - assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) } test("SELECT second field partitioned") { val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) + } + + test("Register JDBC query with renamed fields") { + // Regression test for bug SPARK-7345 + sql( + s""" + |CREATE TEMPORARY TABLE renamed + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(select NAME as NAME1, NAME as NAME2 from TEST.PEOPLE)', + |user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + val df = sql("SELECT * FROM renamed") + assert(df.schema.fields.size == 2) + assert(df.schema.fields(0).name == "NAME1") + assert(df.schema.fields(1).name == "NAME2") } test("Basic API") { - assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE").collect.size == 3) + assert(ctx.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) + } + + test("Basic API with FetchSize") { + val properties = new Properties + properties.setProperty("fetchSize", "2") + assert(ctx.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { - val parts = JDBCPartitioningInfo("THEID", 0, 4, 3) - assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3) + assert( + ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3) + assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + .collect().length === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size == 1) - assert(rows(0).getInt(0) == 1) - assert(rows(0).getBoolean(1) == false) - assert(rows(0).getInt(2) == 3) - assert(rows(0).getInt(3) == 4) - assert(rows(0).getLong(4) == 1234567890123L) + assert(rows.length === 1) + assert(rows(0).getInt(0) === 1) + assert(rows(0).getBoolean(1) === false) + assert(rows(0).getInt(2) === 3) + assert(rows(0).getInt(3) === 4) + assert(rows(0).getLong(4) === 1234567890123L) } test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size == 1) + assert(rows.length === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -209,40 +310,137 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("H2 time types") { val rows = sql("SELECT * FROM timetypes").collect() - assert(rows(0).getAs[java.sql.Timestamp](0).getHours == 12) - assert(rows(0).getAs[java.sql.Timestamp](0).getMinutes == 34) - assert(rows(0).getAs[java.sql.Timestamp](0).getSeconds == 56) - assert(rows(0).getAs[java.sql.Date](1).getYear == 96) - assert(rows(0).getAs[java.sql.Date](1).getMonth == 0) - assert(rows(0).getAs[java.sql.Date](1).getDate == 1) - assert(rows(0).getAs[java.sql.Timestamp](2).getYear == 102) - assert(rows(0).getAs[java.sql.Timestamp](2).getMonth == 1) - assert(rows(0).getAs[java.sql.Timestamp](2).getDate == 20) - assert(rows(0).getAs[java.sql.Timestamp](2).getHours == 11) - assert(rows(0).getAs[java.sql.Timestamp](2).getMinutes == 22) - assert(rows(0).getAs[java.sql.Timestamp](2).getSeconds == 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543) + val cal = new GregorianCalendar(java.util.Locale.ROOT) + cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) + assert(cal.get(Calendar.HOUR_OF_DAY) === 12) + assert(cal.get(Calendar.MINUTE) === 34) + assert(cal.get(Calendar.SECOND) === 56) + cal.setTime(rows(0).getAs[java.sql.Timestamp](1)) + assert(cal.get(Calendar.YEAR) === 1996) + assert(cal.get(Calendar.MONTH) === 0) + assert(cal.get(Calendar.DAY_OF_MONTH) === 1) + cal.setTime(rows(0).getAs[java.sql.Timestamp](2)) + assert(cal.get(Calendar.YEAR) === 2002) + assert(cal.get(Calendar.MONTH) === 1) + assert(cal.get(Calendar.DAY_OF_MONTH) === 20) + assert(cal.get(Calendar.HOUR) === 11) + assert(cal.get(Calendar.MINUTE) === 22) + assert(cal.get(Calendar.SECOND) === 33) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543000) + } + + test("test DATE types") { + val rows = ctx.read.jdbc( + urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(rows(1).getAs[java.sql.Date](1) === null) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + + test("test DATE types in cache") { + val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().registerTempTable("mycached_date") + val cachedRows = sql("select * from mycached_date").collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + + test("test types for null value") { + val rows = ctx.read.jdbc( + urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() + assert((0 to 14).forall(i => rows(0).isNullAt(i))) } test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==. - assert(rows(0).getAs[BigDecimal](2) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) + assert(rows(0).getDouble(0) === 1.00000000000000022) + assert(rows(0).getDouble(1) === 1.00000011920928955) + assert(rows(0).getAs[BigDecimal](2) === + new BigDecimal("123456789012345.543215432154321000")) + assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) + val result = sql("SELECT C FROM flttypes where C > C - 1").collect() + assert(result(0).getAs[BigDecimal](0) === + new BigDecimal("123456789012345.543215432154321000")) } - test("SQL query as table name") { sql( s""" |CREATE TEMPORARY TABLE hack |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)') + |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', + | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() - assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } + + test("Pass extra properties via OPTIONS") { + // We set rowId to false during setup, which means that _ROWID_ column should be absent from + // all tables. If rowId is true (default), the query below doesn't throw an exception. + intercept[JdbcSQLException] { + sql( + s""" + |CREATE TEMPORARY TABLE abc + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + } + + test("Remap types via JdbcDialects") { + JdbcDialects.registerDialect(testH2Dialect) + val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) + val rows = df.collect() + assert(rows(0).get(0).isInstanceOf[String]) + assert(rows(0).get(1).isInstanceOf[String]) + JdbcDialects.unregisterDialect(testH2Dialect) + } + + test("Default jdbc dialect registration") { + assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) + assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("test.invalid") == NoopDialect) + } + + test("quote column names by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + + val columns = Seq("abc", "key") + val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) + val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + assert(MySQLColumns === Seq("`abc`", "`key`")) + assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + } + + test("Dialect unregister") { + JdbcDialects.registerDialect(testH2Dialect) + JdbcDialects.unregisterDialect(testH2Dialect) + assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect) + } + + test("Aggregated dialects") { + val agg = new AggregatedDialect(List(new JdbcDialect { + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + if (sqlType % 2 == 0) { + Some(LongType) + } else { + None + } + }, testH2Dialect)) + assert(agg.canHandle("jdbc:h2:xxx")) + assert(!agg.canHandle("jdbc:h2")) + assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e581ac9b50c2..5dc3a2c07b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,91 +17,140 @@ package org.apache.spark.sql.jdbc -import java.math.BigDecimal -import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ -import org.apache.spark.sql.test._ -import org.scalatest.{FunSuite, BeforeAndAfter} import java.sql.DriverManager -import TestSQLContext._ +import java.util.Properties + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { -class JDBCWriteSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null + val url1 = "jdbc:h2:mem:testdb3" + var conn1: java.sql.Connection = null + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() + + conn1 = DriverManager.getConnection(url1, properties) + conn1.prepareStatement("create schema test").executeUpdate() + conn1.prepareStatement("drop table if exists test.people").executeUpdate() + conn1.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn1.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() + conn1.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() + conn1.prepareStatement("drop table if exists test.people1").executeUpdate() + conn1.prepareStatement( + "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn1.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE PEOPLE + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + sql( + s""" + |CREATE TEMPORARY TABLE PEOPLE1 + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) } after { conn.close() + conn1.close() } - val sc = TestSQLContext.sparkContext + private lazy val sc = ctx.sparkContext - val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - val arr1x2 = Array[Row](Row.apply("fred", 3)) - val schema2 = StructType( + private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) + private lazy val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) - val schema3 = StructType( + private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + private lazy val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").collect()(0).length) + df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) - srdd.createJDBCTable(url, "TEST.DROPTEST", false) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count) - assert(3 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length) + df.write.jdbc(url1, "TEST.DROPTEST", properties) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) - srdd2.createJDBCTable(url, "TEST.DROPTEST", true) - assert(1 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) + assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) - srdd.createJDBCTable(url, "TEST.APPENDTEST", false) - srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false) - assert(3 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").collect()(0).length) + df.write.jdbc(url, "TEST.APPENDTEST", new Properties) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) + assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) - srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false) - srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true) - assert(1 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").count) - assert(2 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").collect()(0).length) + df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) + assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2) - val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) + df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { - srdd2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) } } + test("INSERT to JDBC Datasource") { + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } + + test("INSERT to JDBC Datasource with overwrite") { + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala deleted file mode 100644 index 89920f2650c3..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} -import com.spotify.docker.client.{DefaultDockerClient, DockerClient} -import com.spotify.docker.client.messages.ContainerConfig -import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore} - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ -import org.apache.spark.sql._ -import org.apache.spark.sql.test._ -import TestSQLContext._ - -import org.apache.spark.sql.jdbc._ - -class MySQLDatabase { - val docker: DockerClient = DockerClientFactory.get() - val containerId = { - println("Pulling mysql") - docker.pull("mysql") - println("Configuring container") - val config = (ContainerConfig.builder().image("mysql") - .env("MYSQL_ROOT_PASSWORD=rootpass") - .build()) - println("Creating container") - val id = docker.createContainer(config).id - println("Starting container " + id) - docker.startContainer(id) - id - } - val ip = docker.inspectContainer(containerId).networkSettings.ipAddress - - def close() { - try { - println("Killing container " + containerId) - docker.killContainer(containerId) - println("Removing container " + containerId) - docker.removeContainer(containerId) - println("Closing docker client") - DockerClientFactory.close(docker) - } catch { - case e: Exception => { - println(e) - println("You may need to clean this up manually.") - throw e - } - } - } -} - -@Ignore class MySQLIntegration extends FunSuite with BeforeAndAfterAll { - var ip: String = null - - def url(ip: String): String = url(ip, "mysql") - def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass" - - def waitForDatabase(ip: String, maxMillis: Long) { - println("Waiting for database to start up.") - val before = System.currentTimeMillis() - var lastException: java.sql.SQLException = null - while (true) { - if (System.currentTimeMillis() > before + maxMillis) { - throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException) - } - try { - val conn = java.sql.DriverManager.getConnection(url(ip)) - conn.close() - println("Database is up.") - return; - } catch { - case e: java.sql.SQLException => { - lastException = e - java.lang.Thread.sleep(250) - } - } - } - } - - def setupDatabase(ip: String) { - val conn = java.sql.DriverManager.getConnection(url(ip)) - try { - conn.prepareStatement("CREATE DATABASE foo").executeUpdate() - conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate() - conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate() - conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate() - - conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), " - + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " - + "dbl DOUBLE)").executeUpdate() - conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', " - + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " - + "42.75, 1.0000000000000002)").executeUpdate() - - conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " - + "yr YEAR)").executeUpdate() - conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', " - + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() - - // TODO: Test locale conversion for strings. - conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " - + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" - ).executeUpdate() - conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() - } finally { - conn.close() - } - } - - var db: MySQLDatabase = null - - override def beforeAll() { - // If you load the MySQL driver here, DriverManager will deadlock. The - // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres - // and H2 drivers. - //Class.forName("com.mysql.jdbc.Driver") - - db = new MySQLDatabase() - waitForDatabase(db.ip, 60000) - setupDatabase(db.ip) - ip = db.ip - } - - override def afterAll() { - db.close() - } - - test("Basic test") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "tbl") - val rows = rdd.collect - assert(rows.length == 2) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 2) - assert(types(0).equals("class java.lang.Integer")) - assert(types(1).equals("class java.lang.String")) - } - - test("Numeric types") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 9) - println(types(1)) - assert(types(0).equals("class java.lang.Boolean")) - assert(types(1).equals("class java.lang.Long")) - assert(types(2).equals("class java.lang.Integer")) - assert(types(3).equals("class java.lang.Integer")) - assert(types(4).equals("class java.lang.Integer")) - assert(types(5).equals("class java.lang.Long")) - assert(types(6).equals("class java.math.BigDecimal")) - assert(types(7).equals("class java.lang.Double")) - assert(types(8).equals("class java.lang.Double")) - assert(rows(0).getBoolean(0) == false) - assert(rows(0).getLong(1) == 0x225) - assert(rows(0).getInt(2) == 17) - assert(rows(0).getInt(3) == 77777) - assert(rows(0).getInt(4) == 123456789) - assert(rows(0).getLong(5) == 123456789012345L) - val bd = new BigDecimal("123456789012345.12345678901234500000") - assert(rows(0).getAs[BigDecimal](6).equals(bd)) - assert(rows(0).getDouble(7) == 42.75) - assert(rows(0).getDouble(8) == 1.0000000000000002) - } - - test("Date types") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 5) - assert(types(0).equals("class java.sql.Date")) - assert(types(1).equals("class java.sql.Timestamp")) - assert(types(2).equals("class java.sql.Timestamp")) - assert(types(3).equals("class java.sql.Timestamp")) - assert(types(4).equals("class java.sql.Date")) - assert(rows(0).getAs[Date](0).equals(new Date(91, 10, 9))) - assert(rows(0).getAs[Timestamp](1).equals(new Timestamp(70, 0, 1, 13, 31, 24, 0))) - assert(rows(0).getAs[Timestamp](2).equals(new Timestamp(96, 0, 1, 1, 23, 45, 0))) - assert(rows(0).getAs[Timestamp](3).equals(new Timestamp(109, 1, 13, 23, 31, 30, 0))) - assert(rows(0).getAs[Date](4).equals(new Date(101, 0, 1))) - } - - test("String types") { - val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 9) - assert(types(0).equals("class java.lang.String")) - assert(types(1).equals("class java.lang.String")) - assert(types(2).equals("class java.lang.String")) - assert(types(3).equals("class java.lang.String")) - assert(types(4).equals("class java.lang.String")) - assert(types(5).equals("class java.lang.String")) - assert(types(6).equals("class [B")) - assert(types(7).equals("class [B")) - assert(types(8).equals("class [B")) - assert(rows(0).getString(0).equals("the")) - assert(rows(0).getString(1).equals("quick")) - assert(rows(0).getString(2).equals("brown")) - assert(rows(0).getString(3).equals("fox")) - assert(rows(0).getString(4).equals("jumps")) - assert(rows(0).getString(5).equals("over")) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) - } - - test("Basic write test") { - val rdd1 = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers") - val rdd2 = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates") - val rdd3 = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings") - rdd1.createJDBCTable(url(ip, "foo"), "numberscopy", false) - rdd2.createJDBCTable(url(ip, "foo"), "datescopy", false) - rdd3.createJDBCTable(url(ip, "foo"), "stringscopy", false) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala deleted file mode 100644 index c174d7adb720..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.math.BigDecimal -import org.apache.spark.sql.test._ -import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore} -import java.sql.DriverManager -import TestSQLContext._ -import com.spotify.docker.client.{DefaultDockerClient, DockerClient} -import com.spotify.docker.client.messages.ContainerConfig - -class PostgresDatabase { - val docker: DockerClient = DockerClientFactory.get() - val containerId = { - println("Pulling postgres") - docker.pull("postgres") - println("Configuring container") - val config = (ContainerConfig.builder().image("postgres") - .env("POSTGRES_PASSWORD=rootpass") - .build()) - println("Creating container") - val id = docker.createContainer(config).id - println("Starting container " + id) - docker.startContainer(id) - id - } - val ip = docker.inspectContainer(containerId).networkSettings.ipAddress - - def close() { - try { - println("Killing container " + containerId) - docker.killContainer(containerId) - println("Removing container " + containerId) - docker.removeContainer(containerId) - println("Closing docker client") - DockerClientFactory.close(docker) - } catch { - case e: Exception => { - println(e) - println("You may need to clean this up manually.") - throw e - } - } - } -} - -@Ignore class PostgresIntegration extends FunSuite with BeforeAndAfterAll { - lazy val db = new PostgresDatabase() - - def url(ip: String) = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass" - - def waitForDatabase(ip: String, maxMillis: Long) { - val before = System.currentTimeMillis() - var lastException: java.sql.SQLException = null - while (true) { - if (System.currentTimeMillis() > before + maxMillis) { - throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", - lastException) - } - try { - val conn = java.sql.DriverManager.getConnection(url(ip)) - conn.close() - println("Database is up.") - return; - } catch { - case e: java.sql.SQLException => { - lastException = e - java.lang.Thread.sleep(250) - } - } - } - } - - def setupDatabase(ip: String) { - val conn = DriverManager.getConnection(url(ip)) - try { - conn.prepareStatement("CREATE DATABASE foo").executeUpdate() - conn.setCatalog("foo") - conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " - + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() - conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " - + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() - } finally { - conn.close() - } - } - - override def beforeAll() { - println("Waiting for database to start up.") - waitForDatabase(db.ip, 60000) - println("Setting up database.") - setupDatabase(db.ip) - } - - override def afterAll() { - db.close() - } - - test("Type mapping for various types") { - val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar") - val rows = rdd.collect - assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 10) - assert(types(0).equals("class java.lang.String")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Double")) - assert(types(3).equals("class java.lang.Long")) - assert(types(4).equals("class java.lang.Boolean")) - assert(types(5).equals("class [B")) - assert(types(6).equals("class [B")) - assert(types(7).equals("class java.lang.Boolean")) - assert(types(8).equals("class java.lang.String")) - assert(types(9).equals("class java.lang.String")) - assert(rows(0).getString(0).equals("hello")) - assert(rows(0).getInt(1) == 42) - assert(rows(0).getDouble(2) == 1.25) - assert(rows(0).getLong(3) == 123456789012345L) - assert(rows(0).getBoolean(4) == false) - // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49,48,48,48,49,48,48,49,48,49))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) - assert(rows(0).getBoolean(7) == true) - assert(rows(0).getString(8) == "172.16.0.42") - assert(rows(0).getString(9) == "192.168.0.0/16") - } - - test("Basic write test") { - val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar") - rdd.createJDBCTable(url(db.ip), "public.barcopy", false) - // Test only that it doesn't bomb out. - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala deleted file mode 100644 index cb615388da0c..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ /dev/null @@ -1,927 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} - -class JsonSuite extends QueryTest { - import org.apache.spark.sql.json.TestJsonData._ - - TestJsonData - - test("Type promotion") { - def checkTypePromotion(expected: Any, actual: Any) { - assert(expected.getClass == actual.getClass, - s"Failed to promote ${actual.getClass} to ${expected.getClass}.") - assert(expected == actual, - s"Promoted value ${actual}(${actual.getClass}) does not equal the expected value " + - s"${expected}(${expected.getClass}).") - } - - val intNumber: Int = 2147483647 - checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) - checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) - checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) - checkTypePromotion( - Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) - - val longNumber: Long = 9223372036854775807L - checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) - checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) - checkTypePromotion( - Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) - - val doubleNumber: Double = 1.7976931348623157E308d - checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion( - Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) - - checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), - enforceCorrectType(intNumber.toLong, TimestampType)) - val strTime = "2014-09-30 12:34:56" - checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) - - val strDate = "2014-10-15" - checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) - - val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) - val ISO8601Time2 = "1970-01-01T02:00:01-01:00" - checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) - } - - test("Get compatible type") { - def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2) - assert(actual == expected, - s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1) - assert(actual == expected, - s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - } - - // NullType - checkDataType(NullType, BooleanType, BooleanType) - checkDataType(NullType, IntegerType, IntegerType) - checkDataType(NullType, LongType, LongType) - checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(NullType, StringType, StringType) - checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) - checkDataType(NullType, StructType(Nil), StructType(Nil)) - checkDataType(NullType, NullType, NullType) - - // BooleanType - checkDataType(BooleanType, BooleanType, BooleanType) - checkDataType(BooleanType, IntegerType, StringType) - checkDataType(BooleanType, LongType, StringType) - checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType.Unlimited, StringType) - checkDataType(BooleanType, StringType, StringType) - checkDataType(BooleanType, ArrayType(IntegerType), StringType) - checkDataType(BooleanType, StructType(Nil), StringType) - - // IntegerType - checkDataType(IntegerType, IntegerType, IntegerType) - checkDataType(IntegerType, LongType, LongType) - checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(IntegerType, StringType, StringType) - checkDataType(IntegerType, ArrayType(IntegerType), StringType) - checkDataType(IntegerType, StructType(Nil), StringType) - - // LongType - checkDataType(LongType, LongType, LongType) - checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(LongType, StringType, StringType) - checkDataType(LongType, ArrayType(IntegerType), StringType) - checkDataType(LongType, StructType(Nil), StringType) - - // DoubleType - checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(DoubleType, StringType, StringType) - checkDataType(DoubleType, ArrayType(IntegerType), StringType) - checkDataType(DoubleType, StructType(Nil), StringType) - - // DoubleType - checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(DecimalType.Unlimited, StringType, StringType) - checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) - checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) - - // StringType - checkDataType(StringType, StringType, StringType) - checkDataType(StringType, ArrayType(IntegerType), StringType) - checkDataType(StringType, StructType(Nil), StringType) - - // ArrayType - checkDataType(ArrayType(IntegerType), ArrayType(IntegerType), ArrayType(IntegerType)) - checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) - checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) - checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) - checkDataType( - ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true)) - checkDataType( - ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true)) - checkDataType( - ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) - checkDataType( - ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, true)) - checkDataType( - ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) - checkDataType( - ArrayType(IntegerType, false), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) - - // StructType - checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) - checkDataType( - StructType(StructField("f1", IntegerType, true) :: Nil), - StructType(StructField("f1", IntegerType, true) :: Nil), - StructType(StructField("f1", IntegerType, true) :: Nil)) - checkDataType( - StructType(StructField("f1", IntegerType, true) :: Nil), - StructType(Nil), - StructType(StructField("f1", IntegerType, true) :: Nil)) - checkDataType( - StructType( - StructField("f1", IntegerType, true) :: - StructField("f2", IntegerType, true) :: Nil), - StructType(StructField("f1", LongType, true) :: Nil) , - StructType( - StructField("f1", LongType, true) :: - StructField("f2", IntegerType, true) :: Nil)) - checkDataType( - StructType( - StructField("f1", IntegerType, true) :: Nil), - StructType( - StructField("f2", IntegerType, true) :: Nil), - StructType( - StructField("f1", IntegerType, true) :: - StructField("f2", IntegerType, true) :: Nil)) - checkDataType( - StructType( - StructField("f1", IntegerType, true) :: Nil), - DecimalType.Unlimited, - StringType) - } - - test("Complex field and type inferring with null in sampling") { - val jsonDF = jsonRDD(jsonNullStruct) - val expectedSchema = StructType( - StructField("headers", StructType( - StructField("Charset", StringType, true) :: - StructField("Host", StringType, true) :: Nil) - , true) :: - StructField("ip", StringType, true) :: - StructField("nullstr", StringType, true):: Nil) - - assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select nullstr, headers.Host from jsonTable"), - Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row(null, null)) - ) - } - - test("Primitive field and type inferring") { - val jsonDF = jsonRDD(primitiveFieldAndType) - - val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: - StructField("boolean", BooleanType, true) :: - StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: - StructField("long", LongType, true) :: - StructField("null", StringType, true) :: - StructField("string", StringType, true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select * from jsonTable"), - Row(new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") - ) - } - - test("Complex field and type inferring") { - val jsonDF = jsonRDD(complexFieldAndType1) - - val expectedSchema = StructType( - StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: - StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: - StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: - StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: - StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: - StructField("arrayOfLong", ArrayType(LongType, false), true) :: - StructField("arrayOfNull", ArrayType(StringType, true), true) :: - StructField("arrayOfString", ArrayType(StringType, false), true) :: - StructField("arrayOfStruct", ArrayType( - StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: - StructField("field3", StringType, true) :: Nil), false), true) :: - StructField("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: - StructField("structWithArrayFields", StructType( - StructField("field1", ArrayType(IntegerType, false), true) :: - StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - - // Access elements of a primitive array. - checkAnswer( - sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), - Row("str1", "str2", null) - ) - - // Access an array of null values. - checkAnswer( - sql("select arrayOfNull from jsonTable"), - Row(Seq(null, null, null, null)) - ) - - // Access elements of a BigInteger array (we use DecimalType internally). - checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), - Row(new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) - ) - - // Access elements of an array of arrays. - checkAnswer( - sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), - Row(Seq("1", "2", "3"), Seq("str1", "str2")) - ) - - // Access elements of an array of arrays. - checkAnswer( - sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), - Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) - ) - - // Access elements of an array inside a filed with the type of ArrayType(ArrayType). - checkAnswer( - sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), - Row("str2", 2.1) - ) - - // Access elements of an array of structs. - checkAnswer( - sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + - "from jsonTable"), - Row( - Row(true, "str1", null), - Row(false, null, null), - Row(null, null, null), - null) - ) - - // Access a struct and fields inside of it. - checkAnswer( - sql("select struct, struct.field1, struct.field2 from jsonTable"), - Row( - Row(true, new java.math.BigDecimal("92233720368547758070")), - true, - new java.math.BigDecimal("92233720368547758070")) :: Nil - ) - - // Access an array field of a struct. - checkAnswer( - sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), - Row(Seq(4, 5, 6), Seq("str1", "str2")) - ) - - // Access elements of an array field of a struct. - checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), - Row(5, null) - ) - } - - ignore("Complex field and type inferring (Ignored)") { - val jsonDF = jsonRDD(complexFieldAndType1) - jsonDF.registerTempTable("jsonTable") - - // Right now, "field1" and "field2" are treated as aliases. We should fix it. - checkAnswer( - sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - Row(true, "str1") - ) - - // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. - // Getting all values of a specific field from an array of structs. - checkAnswer( - sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - Row(Seq(true, false), Seq("str1", null)) - ) - } - - test("Type conflict in primitive field values") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) - - val expectedSchema = StructType( - StructField("num_bool", StringType, true) :: - StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.Unlimited, true) :: - StructField("num_num_3", DoubleType, true) :: - StructField("num_str", StringType, true) :: - StructField("str_bool", StringType, true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select * from jsonTable"), - Row("true", 11L, null, 1.1, "13.1", "str1") :: - Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil - ) - - // Number and Boolean conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_bool - 10 from jsonTable where num_bool > 11"), - Row(2) - ) - - // Widening to LongType - checkAnswer( - sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), - Row(21474836370L) :: Row(21474836470L) :: Nil - ) - - checkAnswer( - sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), - Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil - ) - - // Widening to DecimalType - checkAnswer( - sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil - ) - - // Widening to DoubleType - checkAnswer( - sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), - Row(101.2) :: Row(21474836471.2) :: Nil - ) - - // Number and String conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(92233720368547758071.2) - ) - - // Number and String conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) - ) - - // String and Boolean conflict: resolve the type as string. - checkAnswer( - sql("select * from jsonTable where str_bool = 'str1'"), - Row("true", 11L, null, 1.1, "13.1", "str1") - ) - } - - ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) - jsonDF.registerTempTable("jsonTable") - - // Right now, the analyzer does not promote strings in a boolean expreesion. - // Number and Boolean conflict: resolve the type as boolean in this query. - checkAnswer( - sql("select num_bool from jsonTable where NOT num_bool"), - Row(false) - ) - - checkAnswer( - sql("select str_bool from jsonTable where NOT str_bool"), - Row(false) - ) - - // Right now, the analyzer does not know that num_bool should be treated as a boolean. - // Number and Boolean conflict: resolve the type as boolean in this query. - checkAnswer( - sql("select num_bool from jsonTable where num_bool"), - Row(true) - ) - - checkAnswer( - sql("select str_bool from jsonTable where str_bool"), - Row(false) - ) - - // The plan of the following DSL is - // Project [(CAST(num_str#65:4, DoubleType) + 1.2) AS num#78] - // Filter (CAST(CAST(num_str#65:4, DoubleType), DecimalType) > 92233720368547758060) - // ExistingRdd [num_bool#61,num_num_1#62L,num_num_2#63,num_num_3#64,num_str#65,str_bool#66] - // We should directly cast num_str to DecimalType and also need to do the right type promotion - // in the Project. - checkAnswer( - jsonDF. - where('num_str > BigDecimal("92233720368547758060")). - select(('num_str + 1.2).as("num")), - Row(new java.math.BigDecimal("92233720368547758061.2")) - ) - - // The following test will fail. The type of num_str is StringType. - // So, to evaluate num_str + 1.2, we first need to use Cast to convert the type. - // In our test data, one value of num_str is 13.1. - // The result of (CAST(num_str#65:4, DoubleType) + 1.2) for this value is 14.299999999999999, - // which is not 14.3. - // Number and String conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Row(14.3) :: Row(92233720368547758071.2) :: Nil - ) - } - - test("Type conflict in complex field values") { - val jsonDF = jsonRDD(complexFieldValueTypeConflict) - - val expectedSchema = StructType( - StructField("array", ArrayType(IntegerType, false), true) :: - StructField("num_struct", StringType, true) :: - StructField("str_array", StringType, true) :: - StructField("struct", StructType( - StructField("field", StringType, true) :: Nil), true) :: - StructField("struct_array", StringType, true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select * from jsonTable"), - Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: - Row(null, """{"field":false}""", null, null, "{}") :: - Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: - Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil - ) - } - - test("Type conflict in array elements") { - val jsonDF = jsonRDD(arrayElementTypeConflict) - - val expectedSchema = StructType( - StructField("array1", ArrayType(StringType, true), true) :: - StructField("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil), false), true) :: - StructField("array3", ArrayType(StringType, false), true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select * from jsonTable"), - Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", - """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: - Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: - Row(null, null, Seq("1", "2", "3")) :: Nil - ) - - // Treat an element as a number. - checkAnswer( - sql("select array1[0] + 1 from jsonTable where array1 is not null"), - Row(2) - ) - } - - test("Handling missing fields") { - val jsonDF = jsonRDD(missingFields) - - val expectedSchema = StructType( - StructField("a", BooleanType, true) :: - StructField("b", LongType, true) :: - StructField("c", ArrayType(IntegerType, false), true) :: - StructField("d", StructType( - StructField("field", BooleanType, true) :: Nil), true) :: - StructField("e", StringType, true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - } - - test("Loading a JSON dataset from a text file") { - val file = getTempFilePath("json") - val path = file.toString - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = jsonFile(path) - - val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: - StructField("boolean", BooleanType, true) :: - StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: - StructField("long", LongType, true) :: - StructField("null", StringType, true) :: - StructField("string", StringType, true) :: Nil) - - assert(expectedSchema === jsonDF.schema) - - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select * from jsonTable"), - Row(new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") - ) - } - - test("Loading a JSON dataset from a text file with SQL") { - val file = getTempFilePath("json") - val path = file.toString - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - - sql( - s""" - |CREATE TEMPORARY TABLE jsonTableSQL - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - - checkAnswer( - sql("select * from jsonTableSQL"), - Row(new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") - ) - } - - test("Applying schemas") { - val file = getTempFilePath("json") - val path = file.toString - primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - - val schema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: - StructField("boolean", BooleanType, true) :: - StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: - StructField("long", LongType, true) :: - StructField("null", StringType, true) :: - StructField("string", StringType, true) :: Nil) - - val jsonDF1 = jsonFile(path, schema) - - assert(schema === jsonDF1.schema) - - jsonDF1.registerTempTable("jsonTable1") - - checkAnswer( - sql("select * from jsonTable1"), - Row(new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") - ) - - val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) - - assert(schema === jsonDF2.schema) - - jsonDF2.registerTempTable("jsonTable2") - - checkAnswer( - sql("select * from jsonTable2"), - Row(new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") - ) - } - - test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = jsonRDD(complexFieldAndType2) - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - Row(true, "str1") - ) - checkAnswer( - sql( - """ - |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] - |from jsonTable - """.stripMargin), - Row("str2", 6) - ) - } - - test("SPARK-3390 Complex arrays") { - val jsonDF = jsonRDD(complexFieldAndType2) - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql( - """ - |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] - |from jsonTable - """.stripMargin), - Row(5, 7, 8) - ) - checkAnswer( - sql( - """ - |select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0], - |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 - |from jsonTable - """.stripMargin), - Row("str1", Nil, "str4", 2) - ) - } - - test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = jsonRDD(jsonArray) - jsonDF.registerTempTable("jsonTable") - - checkAnswer( - sql( - """ - |select a, b, c - |from jsonTable - """.stripMargin), - Row("str_a_1", null, null) :: - Row("str_a_2", null, null) :: - Row(null, "str_b_3", null) :: - Row("str_a_4", "str_b_4", "str_c_4") :: Nil - ) - } - - test("Corrupt records") { - // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val jsonDF = jsonRDD(corruptRecords) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("_unparsed", StringType, true) :: - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) - - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, "") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) - } - - test("SPARK-4068: nulls in arrays") { - val jsonDF = jsonRDD(nullsInArrays) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("field1", - ArrayType(ArrayType(ArrayType(ArrayType(StringType, false), false), true), false), true) :: - StructField("field2", - ArrayType(ArrayType( - StructType(StructField("Test", IntegerType, true) :: Nil), false), true), true) :: - StructField("field3", - ArrayType(ArrayType( - StructType(StructField("Test", StringType, true) :: Nil), true), false), true) :: - StructField("field4", - ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) - - assert(schema === jsonDF.schema) - - checkAnswer( - sql( - """ - |SELECT field1, field2, field3, field4 - |FROM jsonTable - """.stripMargin), - Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: - Row(null, Seq(null, Seq(Row(1))), null, null) :: - Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) :: - Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil - ) - } - - test("SPARK-4228 DataFrame to JSON") - { - val schema1 = StructType( - StructField("f1", IntegerType, false) :: - StructField("f2", StringType, false) :: - StructField("f3", BooleanType, false) :: - StructField("f4", ArrayType(StringType), nullable = true) :: - StructField("f5", IntegerType, true) :: Nil) - - val rowRDD1 = unparsedStrings.map { r => - val values = r.split(",").map(_.trim) - val v5 = try values(3).toInt catch { - case _: NumberFormatException => null - } - Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) - } - - val df1 = applySchema(rowRDD1, schema1) - df1.registerTempTable("applySchema1") - val df2 = df1.toDataFrame - val result = df2.toJSON.collect() - assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") - assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") - - val schema2 = StructType( - StructField("f1", StructType( - StructField("f11", IntegerType, false) :: - StructField("f12", BooleanType, false) :: Nil), false) :: - StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) - - val rowRDD2 = unparsedStrings.map { r => - val values = r.split(",").map(_.trim) - val v4 = try values(3).toInt catch { - case _: NumberFormatException => null - } - Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) - } - - val df3 = applySchema(rowRDD2, schema2) - df3.registerTempTable("applySchema2") - val df4 = df3.toDataFrame - val result2 = df4.toJSON.collect() - - assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") - assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - - val jsonDF = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonDF.toJSON) - primTable.registerTempTable("primativeTable") - checkAnswer( - sql("select * from primativeTable"), - Row(new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - "this is a simple string.") - ) - - val complexJsonDF = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonDF.toJSON) - compTable.registerTempTable("complexTable") - // Access elements of a primitive array. - checkAnswer( - sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), - Row("str1", "str2", null) - ) - - // Access an array of null values. - checkAnswer( - sql("select arrayOfNull from complexTable"), - Row(Seq(null, null, null, null)) - ) - - // Access elements of a BigInteger array (we use DecimalType internally). - checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), - Row(new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) - ) - - // Access elements of an array of arrays. - checkAnswer( - sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), - Row(Seq("1", "2", "3"), Seq("str1", "str2")) - ) - - // Access elements of an array of arrays. - checkAnswer( - sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), - Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) - ) - - // Access elements of an array inside a filed with the type of ArrayType(ArrayType). - checkAnswer( - sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), - Row("str2", 2.1) - ) - - // Access a struct and fields inside of it. - checkAnswer( - sql("select struct, struct.field1, struct.field2 from complexTable"), - Row( - Row(true, new java.math.BigDecimal("92233720368547758070")), - true, - new java.math.BigDecimal("92233720368547758070")) :: Nil - ) - - // Access an array field of a struct. - checkAnswer( - sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), - Row(Seq(4, 5, 6), Seq("str1", "str2")) - ) - - // Access elements of an array field of a struct. - checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), - Row(5, null) - ) - - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala deleted file mode 100644 index 3370b3c98b4b..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import org.apache.spark.sql.test.TestSQLContext - -object TestJsonData { - - val primitiveFieldAndType = - TestSQLContext.sparkContext.parallelize( - """{"string":"this is a simple string.", - "integer":10, - "long":21474836470, - "bigInteger":92233720368547758070, - "double":1.7976931348623157E308, - "boolean":true, - "null":null - }""" :: Nil) - - val primitiveFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( - """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, - "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: - """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, - "num_bool":12, "num_str":null, "str_bool":true}""" :: - """{"num_num_1":21474836470, "num_num_2":92233720368547758070, "num_num_3": 100, - "num_bool":false, "num_str":"str1", "str_bool":false}""" :: - """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, - "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) - - val jsonNullStruct = - TestSQLContext.sparkContext.parallelize( - """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: - """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: - """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: - """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) - - val complexFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( - """{"num_struct":11, "str_array":[1, 2, 3], - "array":[], "struct_array":[], "struct": {}}""" :: - """{"num_struct":{"field":false}, "str_array":null, - "array":null, "struct_array":{}, "struct": null}""" :: - """{"num_struct":null, "str_array":"str", - "array":[4, 5, 6], "struct_array":[7, 8, 9], "struct": {"field":null}}""" :: - """{"num_struct":{}, "str_array":["str1", "str2", 33], - "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) - - val arrayElementTypeConflict = - TestSQLContext.sparkContext.parallelize( - """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], - "array2": [{"field":214748364700}, {"field":1}]}""" :: - """{"array3": [{"field":"str"}, {"field":1}]}""" :: - """{"array3": [1, 2, 3]}""" :: Nil) - - val missingFields = - TestSQLContext.sparkContext.parallelize( - """{"a":true}""" :: - """{"b":21474836470}""" :: - """{"c":[33, 44]}""" :: - """{"d":{"field":true}}""" :: - """{"e":"str"}""" :: Nil) - - val complexFieldAndType1 = - TestSQLContext.sparkContext.parallelize( - """{"struct":{"field1": true, "field2": 92233720368547758070}, - "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, - "arrayOfString":["str1", "str2"], - "arrayOfInteger":[1, 2147483647, -2147483648], - "arrayOfLong":[21474836470, 9223372036854775807, -9223372036854775808], - "arrayOfBigInteger":[922337203685477580700, -922337203685477580800], - "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], - "arrayOfBoolean":[true, false, true], - "arrayOfNull":[null, null, null, null], - "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], - "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], - "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] - }""" :: Nil) - - val complexFieldAndType2 = - TestSQLContext.sparkContext.parallelize( - """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], - "complexArrayOfStruct": [ - { - "field1": [ - { - "inner1": "str1" - }, - { - "inner2": ["str2", "str22"] - }], - "field2": [[1, 2], [3, 4]] - }, - { - "field1": [ - { - "inner2": ["str3", "str33"] - }, - { - "inner1": "str4" - }], - "field2": [[5, 6], [7, 8]] - }], - "arrayOfArray1": [ - [ - [5] - ], - [ - [6, 7], - [8] - ]], - "arrayOfArray2": [ - [ - [ - { - "inner1": "str1" - } - ] - ], - [ - [], - [ - {"inner2": ["str3", "str33"]}, - {"inner2": ["str4"], "inner1": "str11"} - ] - ], - [ - [ - {"inner3": [[{"inner4": 2}]]} - ] - ]] - }""" :: Nil) - - val nullsInArrays = - TestSQLContext.sparkContext.parallelize( - """{"field1":[[null], [[["Test"]]]]}""" :: - """{"field2":[null, [{"Test":1}]]}""" :: - """{"field3":[[null], [{"Test":"2"}]]}""" :: - """{"field4":[[null, [1,2,3]]]}""" :: Nil) - - val jsonArray = - TestSQLContext.sparkContext.parallelize( - """[{"a":"str_a_1"}]""" :: - """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: - """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: - """[]""" :: Nil) - - val corruptRecords = - TestSQLContext.sparkContext.parallelize( - """{""" :: - """""" :: - """{"a":1, b:2}""" :: - """{"a":{, b:3}""" :: - """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: - """]""" :: Nil) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala deleted file mode 100644 index ff91a0eb4204..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import parquet.filter2.predicate.Operators._ -import parquet.filter2.predicate.{FilterPredicate, Operators} - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} - -/** - * A test suite that tests Parquet filter2 API based filter pushdown optimization. - * - * NOTE: - * - * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the - * `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)` - * results in a `GtEq` filter predicate rather than a `Not`. - * - * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred - * data type is nullable. - */ -class ParquetFilterSuite extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - private def checkFilterPredicate( - rdd: DataFrame, - predicate: Predicate, - filterClass: Class[_ <: FilterPredicate], - checker: (DataFrame, Seq[Row]) => Unit, - expected: Seq[Row]): Unit = { - val output = predicate.collect { case a: Attribute => a }.distinct - - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) - - assert(maybeAnalyzedPredicate.isDefined) - maybeAnalyzedPredicate.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(pred) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach { f => - // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) - assert(f.getClass === filterClass) - } - } - - checker(query, expected) - } - } - - private def checkFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: DataFrame): Unit = { - checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) - } - - private def checkFilterPredicate[T] - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: DataFrame): Unit = { - checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) - } - - test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - - checkFilterPredicate('_1 === true, classOf[Eq [_]], true) - checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) - } - } - - test("filter pushdown - short") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, - classOf[Operators.And], 3) - checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], Seq(Row(1), Row(4))) - } - } - - test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } - } - - test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } - } - - test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } - } - - test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } - } - - test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) - } - } - - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: DataFrame): Unit = { - def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { - rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted - } - } - - checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) - } - - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: DataFrame): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) - } - - test("filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes("UTF-8") - } - - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate( - '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) - checkBinaryFilterPredicate( - '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala deleted file mode 100644 index d9ab16baf9a6..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import scala.collection.JavaConversions._ -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import parquet.example.data.simple.SimpleGroup -import parquet.example.data.{Group, GroupWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.{ParquetFileWriter, ParquetWriter} -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} -import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types.DecimalType - -// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport -// with an empty configuration (it is after all not intended to be used in this way?) -// and members are private so we need to make our own in order to pass the schema -// to the writer. -private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { - var groupWriter: GroupWriter = null - - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - groupWriter = new GroupWriter(recordConsumer, schema) - } - - override def init(configuration: Configuration): WriteContext = { - new WriteContext(schema, new java.util.HashMap[String, String]()) - } - - override def write(record: Group) { - groupWriter.write(record) - } -} - -/** - * A test suite that tests basic Parquet I/O. - */ -class ParquetIOSuite extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - /** - * Writes `data` to a Parquet file, reads it back and check file contents. - */ - protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) - } - - test("basic data types (without binary)") { - val data = (1 to 4).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - } - checkParquetFile(data) - } - - test("raw binary") { - val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => - assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted - } - } - } - - test("string") { - val data = (1 to 4).map(i => Tuple1(i.toString)) - // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL - // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) - } - - test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .select($"_1" cast decimal as "abcd") - - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - withTempPath { dir => - val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) - } - } - - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() - } - } - - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() - } - } - } - - test("map") { - val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) - checkParquetFile(data) - } - - test("array") { - val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) - checkParquetFile(data) - } - - test("struct") { - val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) - } - } - - test("nested struct with array of array as field") { - val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) - } - } - - test("nested map with struct as value type") { - val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => - Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) - }) - } - } - - test("nulls") { - val allNulls = ( - null.asInstanceOf[java.lang.Boolean], - null.asInstanceOf[Integer], - null.asInstanceOf[java.lang.Long], - null.asInstanceOf[java.lang.Float], - null.asInstanceOf[java.lang.Double]) - - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(5)(null): _*)) - } - } - - test("nones") { - val allNones = ( - None.asInstanceOf[Option[Int]], - None.asInstanceOf[Option[Long]], - None.asInstanceOf[Option[String]]) - - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(3)(null): _*)) - } - } - - test("compression codec") { - def compressionCodecFor(path: String) = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) - codecs.head - } - - val data = (0 until 10).map(i => (i, i.toString)) - - def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { - withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) - } - } - } - } - - // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) - - checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) - checkCompressionCodec(CompressionCodecName.GZIP) - checkCompressionCodec(CompressionCodecName.SNAPPY) - } - - test("read raw Parquet file") { - def makeRawParquetFile(path: Path): Unit = { - val schema = MessageTypeParser.parseMessageType( - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - |} - """.stripMargin) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - (0 until 10).foreach { i => - val record = new SimpleGroup(schema) - record.add(0, i % 2 == 0) - record.add(1, i) - record.add(2, i.toLong) - record.add(3, i.toFloat) - record.add(4, i.toDouble) - writer.write(record) - } - - writer.close() - } - - withTempDir { dir => - val path = new Path(dir.toURI.toString, "part-r-0.parquet") - makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) - } - } - - test("write metadata") { - withTempPath { file => - val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) - - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) - - actualSchema.checkContains(expectedSchema) - expectedSchema.checkContains(actualSchema) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala deleted file mode 100644 index 5ec7a156d935..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * A test suite that tests various Parquet queries. - */ -class ParquetQuerySuite extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - test("simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) - } - } - - // This test case will trigger the NPE mentioned in - // https://issues.apache.org/jira/browse/PARQUET-151. - ignore("overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM t") - checkAnswer(table("t"), data.map(Row.fromTuple)) - } - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } - } - - test("SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) - - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala deleted file mode 100644 index 64274950b868..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import org.scalatest.FunSuite -import parquet.schema.MessageTypeParser - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext - -class ParquetSchemaSuite extends FunSuite with ParquetTest { - val sqlContext = TestSQLContext - - /** - * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. - */ - private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String): Unit = { - test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes(ScalaReflection.attributesFor[T]) - val expected = MessageTypeParser.parseMessageType(messageType) - actual.checkContains(expected) - expected.checkContains(actual) - } - } - - testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])]( - "basic types", - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - | optional binary _6; - |} - """.stripMargin) - - testSchema[(Byte, Short, Int, Long)]( - "logical integral types", - """ - |message root { - | required int32 _1 (INT_8); - | required int32 _2 (INT_16); - | required int32 _3 (INT_32); - | required int64 _4 (INT_64); - |} - """.stripMargin) - - // Currently String is the only supported logical binary type. - testSchema[Tuple1[String]]( - "binary logical types", - """ - |message root { - | optional binary _1 (UTF8); - |} - """.stripMargin) - - testSchema[Tuple1[Seq[Int]]]( - "array", - """ - |message root { - | optional group _1 (LIST) { - | repeated int32 array; - | } - |} - """.stripMargin) - - testSchema[Tuple1[Map[Int, String]]]( - "map", - """ - |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required int32 key; - | optional binary value (UTF8); - | } - | } - |} - """.stripMargin) - - testSchema[Tuple1[Pair[Int, String]]]( - "struct", - """ - |message root { - | optional group _1 { - | required int32 _1; - | optional binary _2 (UTF8); - | } - |} - """.stripMargin) - - testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( - "deeply nested type", - """ - |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required int32 key; - | optional group value { - | optional binary _1 (UTF8); - | optional group _2 (LIST) { - | repeated group bag { - | optional group array { - | required int32 _1; - | required double _2; - | } - | } - | } - | } - | } - | } - |} - """.stripMargin) - - testSchema[(Option[Int], Map[Int, Option[Double]])]( - "optional types", - """ - |message root { - | optional int32 _1; - | optional group _2 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required int32 key; - | optional double value; - | } - | } - |} - """.stripMargin) - - test("DataType string parser compatibility") { - // This is the generated string from previous versions of the Spark SQL, using the following: - // val schema = StructType(List( - // StructField("c1", IntegerType, false), - // StructField("c2", BinaryType, true))) - val caseClassString = - "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" - - val jsonString = - """ - |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} - """.stripMargin - - val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) - val fromJson = ParquetTypesConverter.convertFromString(jsonString) - - (fromCaseClassString, fromJson).zipped.foreach { (a, b) => - assert(a.name == b.name) - assert(a.dataType === b.dataType) - assert(a.nullable === b.nullable) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index b02389978b62..9bc3f6bcf6fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -17,38 +17,45 @@ package org.apache.spark.sql.sources -import java.io.File +import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensisitiveContext._ - - var path: File = null +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ + private lazy val sparkContext = caseInsensitiveContext.sparkContext + private var path: File = null override def beforeAll(): Unit = { - path = util.getTempFilePath("jsonCTAS").getCanonicalFile + super.beforeAll() + path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - dropTempTable("jt") + try { + caseInsensitiveContext.dropTempTable("jt") + } finally { + super.afterAll() + } } after { - if (path.exists()) Utils.deleteRecursively(path) + Utils.deleteRecursively(path) } test("CREATE TEMPORARY TABLE AS SELECT") { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -59,14 +66,37 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") + } + + test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { + val childPath = new File(path.toString, "child") + path.mkdir() + childPath.createNewFile() + path.setWritable(false) + + val e = intercept[IOException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + sql("SELECT a, b FROM jsonTable").collect() + } + assert(e.getMessage().contains("Unable to clear output directory")) + + path.setWritable(true) } test("create a table, drop it and create another one with the same name") { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -77,13 +107,11 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") - - val message = intercept[RuntimeException]{ + val message = intercept[DDLException]{ sql( s""" - |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -91,27 +119,42 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) }.getMessage assert( - message.contains(s"path ${path.toString} already exists."), + message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") - // Explicitly delete it. + // Overwrite the temporary table. + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT a * 4 FROM jt").collect()) + + caseInsensitiveContext.dropTempTable("jsonTable") + // Explicitly delete the data. if (path.exists()) Utils.deleteRecursively(path) sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS - |SELECT a * 4 FROM jt + |SELECT b FROM jt """.stripMargin) checkAnswer( sql("SELECT * FROM jsonTable"), - sql("SELECT a * 4 FROM jt").collect()) + sql("SELECT b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { @@ -119,7 +162,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -136,7 +179,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -144,4 +187,31 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) } } + + test("it is not allowed to write to a table while querying it.") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jsonTable + """.stripMargin) + }.getMessage + assert( + message.contains("Cannot overwrite table "), + "Writing to a table while querying it should not be allowed.") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 000000000000..853707c036c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,89 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + + +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def shortName(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala new file mode 100644 index 000000000000..5f8514e1a241 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) + } +} + +case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan { + + override def schema: StructType = + StructType(Seq( + StructField("intType", IntegerType, nullable = false, + new MetadataBuilder().putString("comment", s"test comment $table").build()), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil + ) + ) + )) + + override def needConversion: Boolean = false + + override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] + sqlContext.sparkContext.parallelize(from to to).map { e => + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] + } +} + +class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ + + override def beforeAll(): Unit = { + super.beforeAll() + sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + } + + sqlTest( + "describe ddlPeople", + Seq( + Row("intType", "int", "test comment test1"), + Row("stringType", "string", ""), + Row("dateType", "date", ""), + Row("timestampType", "timestamp", ""), + Row("doubleType", "double", ""), + Row("bigintType", "bigint", ""), + Row("tinyintType", "tinyint", ""), + Row("decimalType", "decimal(10,0)", ""), + Row("fixedDecimalType", "decimal(5,1)", ""), + Row("binaryType", "binary", ""), + Row("booleanType", "boolean", ""), + Row("smallIntType", "smallint", ""), + Row("floatType", "float", ""), + Row("mapType", "map", ""), + Row("arrayType", "array", ""), + Row("structType", "struct", "") + )) + + test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { + val attributes = sql("describe ddlPeople") + .queryExecution.executedPlan.output + assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) + assert(attributes.map(_.dataType).toSet === Set(StringType)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 9626252e742e..d74d29fb0beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -18,17 +18,22 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Analyzer -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfter -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { - // Case sensitivity is not configurable yet, but we want to test some edge cases. - // TODO: Remove when it is configurable - implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { - @transient - override protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) + +private[sql] abstract class DataSourceTest extends QueryTest { + protected def _sqlContext: SQLContext + + // We want to test some edge cases. + protected lazy val caseInsensitiveContext: SQLContext = { + val ctx = new SQLContext(_sqlContext.sparkContext) + ctx.setConf(SQLConf.CASE_SENSITIVE, false) + ctx } -} + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 390538d35a34..68ce37c00077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -32,31 +35,61 @@ class FilteredScanSource extends RelationProvider { } case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends PrunedFilteredScan { + extends BaseRelation + with PrunedFilteredScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: Nil) + StructField("b", IntegerType, nullable = false) :: + StructField("c", StringType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) + case "c" => (i: Int) => + val c = (i - 1 + 'a').toChar.toString + Seq(c * 5 + c.toUpperCase() * 5) } FiltersPushed.list = filters - val filterFunctions = filters.collect { + // Predicate test on integer column + def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) + case IsNull("a") => (a: Int) => false // Int can't be null + case IsNotNull("a") => (a: Int) => true + case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a) + case And(left, right) => (a: Int) => + translateFilterOnA(left)(a) && translateFilterOnA(right)(a) + case Or(left, right) => (a: Int) => + translateFilterOnA(left)(a) || translateFilterOnA(right)(a) + case _ => (a: Int) => true } - def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + // Predicate test on string column + def translateFilterOnC(filter: Filter): String => Boolean = filter match { + case StringStartsWith("c", v) => _.startsWith(v) + case StringEndsWith("c", v) => _.endsWith(v) + case StringContains("c", v) => _.contains(v) + case EqualTo("c", v: String) => _.equals(v) + case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters") + case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s) + case _ => (c: String) => true + } + + def eval(a: Int) = { + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 + !filters.map(translateFilterOnA(_)(a)).contains(false) && + !filters.map(translateFilterOnC(_)(c)).contains(false) + } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -68,11 +101,11 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { - - import caseInsensisitiveContext._ +class FilteredScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered @@ -86,7 +119,8 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 + + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -121,20 +155,52 @@ class FilteredScanSuite extends DataSourceTest { (2 to 10 by 2).map(i => Row(i, i)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1, 3, 5).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IS NULL", + Seq.empty[Row]) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IS NOT NULL", + (1 to 10).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE A = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 5 AND a > 1", + (2 to 4).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE b = 2", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE NOT (a < 6)", + (6 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", + Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", + Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", + Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) @@ -162,6 +228,22 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) + testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala deleted file mode 100644 index f91cea6a3706..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import java.io.File - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util -import org.apache.spark.util.Utils - -class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensisitiveContext._ - - var path: File = null - - override def beforeAll: Unit = { - path = util.getTempFilePath("jsonCTAS").getCanonicalFile - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${path.toString}' - |) - """.stripMargin) - } - - override def afterAll: Unit = { - dropTempTable("jsonTable") - dropTempTable("jt") - if (path.exists()) Utils.deleteRecursively(path) - } - - test("Simple INSERT OVERWRITE a JSONRelation") { - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) - ) - } - - test("INSERT OVERWRITE a JSONRelation multiple times") { - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) - ) - } - - test("INSERT INTO not supported for JSONRelation for now") { - intercept[RuntimeException]{ - sql( - s""" - |INSERT INTO TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala new file mode 100644 index 000000000000..78bd3e558296 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class InsertSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ + private lazy val sparkContext = caseInsensitiveContext.sparkContext + private var path: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + path = Utils.createTempDir() + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable (a int, b string) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) + """.stripMargin) + } + + override def afterAll(): Unit = { + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } + } + + test("Simple INSERT OVERWRITE a JSONRelation") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("PreInsert casting and renaming") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 2, s"${i * 4}")) + ) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 4, s"${i * 6}")) + ) + } + + test("SELECT clause generating a different number of columns is not allowed.") { + val message = intercept[RuntimeException] { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("generates the same number of columns as its schema"), + "SELECT clause generating a different number of columns should not be not allowed." + ) + } + + test("INSERT OVERWRITE a JSONRelation multiple times") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + + // Writing the table to less part files. + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) + caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + + // Writing the table to more part files. + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) + caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1 + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 10, s"str$i")) + ) + + caseInsensitiveContext.dropTempTable("jt1") + caseInsensitiveContext.dropTempTable("jt2") + } + + test("INSERT INTO JSONRelation for now") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect() + ) + + sql( + s""" + |INSERT INTO TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt UNION ALL SELECT a, b FROM jt").collect() + ) + } + + test("save directly to the path of a JSON table") { + caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") + .write.mode(SaveMode.Overwrite).json(path.toString) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 5, s"str$i")) + ) + + caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("it is not allowed to write to a table while querying it.") { + val message = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable + """.stripMargin) + }.getMessage + assert( + message.contains("Cannot insert overwrite into table that is also being read from."), + "INSERT OVERWRITE to a table while querying it should not be allowed.") + } + + test("Caching") { + // write something to the jsonTable + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + // Cached Query Execution + caseInsensitiveContext.cacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable")) + checkAnswer( + sql("SELECT * FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i"))) + + assertCached(sql("SELECT a FROM jsonTable")) + checkAnswer( + sql("SELECT a FROM jsonTable"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT a FROM jsonTable WHERE a < 5")) + checkAnswer( + sql("SELECT a FROM jsonTable WHERE a < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT a * 2 FROM jsonTable")) + checkAnswer( + sql("SELECT a * 2 FROM jsonTable"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Insert overwrite and keep the same schema. + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt + """.stripMargin) + // jsonTable should be recached. + assertCached(sql("SELECT * FROM jsonTable")) + // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation +// // The cached data is the new data. +// checkAnswer( +// sql("SELECT a, b FROM jsonTable"), +// sql("SELECT a * 2, b FROM jt").collect()) +// +// // Verify uncaching +// caseInsensitiveContext.uncacheTable("jsonTable") +// assertCached(sql("SELECT * FROM jsonTable"), 0) + } + + test("it's not allowed to insert into a relation that is not an InsertableRelation") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq + ) + + val message = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("does not allow insertion."), + "It is not allowed to insert into a table that is not an InsertableRelation." + ) + + caseInsensitiveContext.dropTempTable("oneToTen") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala new file mode 100644 index 000000000000..79b6e9b45c00 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("write many partitions") { + val path = Utils.createTempDir() + path.delete() + + val df = ctx.range(100).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + ctx.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } + + test("write many partitions with repeats") { + val path = Utils.createTempDir() + path.delete() + + val base = ctx.range(100) + val df = base.unionAll(base).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + ctx.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index a33cf1172cac..a89c5f8007e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -31,14 +33,15 @@ class PrunedScanSource extends RelationProvider { } case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends PrunedScan { + extends BaseRelation + with PrunedScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String]) = { + override def buildScan(requiredColumns: Array[String]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) @@ -49,10 +52,11 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { - import caseInsensisitiveContext._ +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenPruned @@ -130,7 +134,7 @@ class PrunedScanSuite extends DataSourceTest { queryExecution) } - if (rawOutput.size != expectedColumns.size) { + if (rawOutput.numFields != expectedColumns.size) { fail(s"Wrong output row. Got $rawOutput\n$queryExecution") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala new file mode 100644 index 000000000000..27d1cd92fca1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -0,0 +1,60 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.ResolvedDataSource + +class ResolvedDataSourceSuite extends SparkFunSuite { + + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } + + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } + + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index fe2f76cc397f..f18546b4c2d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,61 +19,93 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.util - -class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensisitiveContext._ - - var originalDefaultSource: String = null - - var path: File = null - - var df: DataFrame = null +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ + private lazy val sparkContext = caseInsensitiveContext.sparkContext + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { - originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") + super.beforeAll() + originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName - path = util.getTempFilePath("datasource").getCanonicalFile + path = Utils.createTempDir() + path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = jsonRDD(rdd) + df = caseInsensitiveContext.read.json(rdd) + df.registerTempTable("jsonTable") } override def afterAll(): Unit = { - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } finally { + super.afterAll() + } } after { - if (path.exists()) Utils.deleteRecursively(path) + Utils.deleteRecursively(path) } - def checkLoad(): Unit = { - checkAnswer(load(path.toString), df.collect()) - checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect()) + def checkLoad(expectedDF: DataFrame = df, tbl: String = "jsonTable"): Unit = { + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + checkAnswer(caseInsensitiveContext.read.load(path.toString), expectedDF.collect()) + + // Test if we can pick up the data source name passed in load. + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + expectedDF.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + expectedDF.collect()) + val schema = StructType(StructField("b", StringType, true) :: Nil) + checkAnswer( + caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), + sql(s"SELECT b FROM $tbl").collect()) } - test("save with overwrite and load") { - df.save(path.toString) - checkLoad + test("save with path and load") { + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + df.write.save(path.toString) + checkLoad() + } + + test("save with string mode and path, and load") { + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + path.createNewFile() + df.write.mode("overwrite").save(path.toString) + checkLoad() + } + + test("save with path and datasource, and load") { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.write.json(path.toString) + checkLoad() } test("save with data source and options, and load") { - df.save("org.apache.spark.sql.json", ("path", path.toString)) - checkLoad + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.write.mode(SaveMode.ErrorIfExists).json(path.toString) + checkLoad() } test("save and save again") { - df.save(path.toString) + df.write.json(path.toString) - val message = intercept[RuntimeException] { - df.save(path.toString) + val message = intercept[AnalysisException] { + df.write.json(path.toString) }.getMessage assert( @@ -82,7 +114,17 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { if (path.exists()) Utils.deleteRecursively(path) - df.save(path.toString) - checkLoad + df.write.json(path.toString) + checkLoad() + + df.write.mode(SaveMode.Overwrite).json(path.toString) + checkLoad() + + // verify the append mode + df.write.mode(SaveMode.Append).json(path.toString) + val df2 = df.unionAll(df) + df2.registerTempTable("jsonTable2") + + checkLoad(df2, "jsonTable2") } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 0a4d4b6342d4..12af8068c398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.sources -import java.sql.{Timestamp, Date} +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource @@ -33,12 +36,12 @@ class SimpleScanSource extends RelationProvider { } case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends TableScan { + extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(StructField("i", IntegerType, nullable = false) :: Nil) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) + override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) } class AllDataTypesScanSource extends SchemaRelationProvider { @@ -46,23 +49,30 @@ class AllDataTypesScanSource extends SchemaRelationProvider { sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { + // Check that weird parameters are passed correctly. + parameters("option_with_underscores") + parameters("option.with.dots") + AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext) } } case class AllDataTypesScan( - from: Int, - to: Int, - userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) - extends TableScan { + from: Int, + to: Int, + userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) + extends BaseRelation + with TableScan { + + override def schema: StructType = userSpecifiedSchema - override def schema = userSpecifiedSchema + override def needConversion: Boolean = true - override def buildScan() = { + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", - s"str_$i".getBytes(), + s"str_$i".getBytes(StandardCharsets.UTF_8), i % 2 == 0, i.toByte, i.toShort, @@ -72,7 +82,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -80,15 +90,16 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), + Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) } } } -class TableScanSuite extends DataSourceTest { - import caseInsensisitiveContext._ +class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - var tableWithSchemaExpected = (1 to 10).map { i => + private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", s"str_$i", @@ -101,7 +112,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -109,17 +120,20 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTen |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) @@ -150,7 +164,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) } @@ -186,7 +202,7 @@ class TableScanSuite extends DataSourceTest { StructField("longField_:,<>=+/~^", LongType, true) :: StructField("floatField", FloatType, true) :: StructField("doubleField", DoubleType, true) :: - StructField("decimalField1", DecimalType.Unlimited, true) :: + StructField("decimalField1", DecimalType.USER_DEFAULT, true) :: StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: @@ -213,7 +229,7 @@ class TableScanSuite extends DataSourceTest { Nil ) - assert(expectedSchema == table("tableWithSchema").schema) + assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) checkAnswer( sql( @@ -264,11 +280,11 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))).toSeq) test("Caching") { // Cached Query Execution - cacheTable("oneToTen") + caseInsensitiveContext.cacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen")) checkAnswer( sql("SELECT * FROM oneToTen"), @@ -289,13 +305,14 @@ class TableScanSuite extends DataSourceTest { sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching - uncacheTable("oneToTen") + caseInsensitiveContext.uncacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen"), 0) } @@ -352,7 +369,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | from '1', - | to '10' + | to '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala new file mode 100644 index 000000000000..152c9c8459de --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.{IOException, InputStream} + +import scala.sys.process.BasicIO + +object ProcessTestUtils { + class ProcessOutputCapturer(stream: InputStream, capture: String => Unit) extends Thread { + this.setDaemon(true) + + override def run(): Unit = { + try { + BasicIO.processFully(capture)(stream) + } catch { case _: IOException => + // Ignores the IOException thrown when the process termination, which closes the input + // stream abruptly. + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala new file mode 100644 index 000000000000..1374a97476ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def _sqlContext: SQLContext + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. + + protected lazy val testData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = _sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = _sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + _sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + protected lazy val salary: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + } +} + +/** + * Case classes used in test data. + */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 000000000000..cdd691e03589 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File +import java.util.UUID + +import scala.util.Try +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.util.Utils + +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def _sqlContext: SQLContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = _sqlContext.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def configuration: Configuration = { + _sqlContext.sparkContext.hadoopConfiguration + } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(_sqlContext.conf.setConfString) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) + case (key, None) => _sqlContext.conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(_sqlContext.dropTempTable) + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + _sqlContext.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + _sqlContext.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. + */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + _sqlContext.sql(s"USE $db") + try f finally _sqlContext.sql(s"USE default") + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(_sqlContext, plan) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 000000000000..8a061b6bc690 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.{ColumnName, SQLContext} + + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + */ +private[sql] trait SharedSQLContext extends SQLTestUtils { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = null + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def ctx: TestSQLContext = _ctx + protected def sqlContext: TestSQLContext = _ctx + protected override def _sqlContext: SQLContext = _ctx + + /** + * Initialize the [[TestSQLContext]]. + */ + protected override def beforeAll(): Unit = { + if (_ctx == null) { + _ctx = new TestSQLContext + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() + } + } + + /** + * Converts $"col name" into an [[Column]]. + * @since 1.3.0 + */ + // This must be duplicated here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala new file mode 100644 index 000000000000..92ef2f7d74ba --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{SQLConf, SQLContext} + + +/** + * A special [[SQLContext]] prepared for testing. + */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) + } + + // Use fewer partitions to speed up testing + protected[sql] override def createSession(): SQLSession = new this.SQLSession() + + /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ + protected[sql] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) + } + } + + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def _sqlContext: SQLContext = self + } +} diff --git a/sql/core/src/test/scripts/gen-avro.sh b/sql/core/src/test/scripts/gen-avro.sh new file mode 100755 index 000000000000..48174b287fd7 --- /dev/null +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 000000000000..ada432c68ab9 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift new file mode 100644 index 000000000000..98bf778aec5d --- /dev/null +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift + +enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS +} + +struct Nested { + 1: required list nestedIntsColumn; + 2: required string nestedStringColumn; +} + +/** + * This is a test struct for testing parquet-thrift compatibility. + */ +struct ParquetThriftCompat { + 1: required bool boolColumn; + 2: required byte byteColumn; + 3: required i16 shortColumn; + 4: required i32 intColumn; + 5: required i64 longColumn; + 6: required double doubleColumn; + 7: required binary binaryColumn; + 8: required string stringColumn; + 9: required Suit enumColumn + + 10: optional bool maybeBoolColumn; + 11: optional byte maybeByteColumn; + 12: optional i16 maybeShortColumn; + 13: optional i32 maybeIntColumn; + 14: optional i64 maybeLongColumn; + 15: optional double maybeDoubleColumn; + 16: optional binary maybeBinaryColumn; + 17: optional string maybeStringColumn; + 18: optional Suit maybeEnumColumn; + + 19: required list stringsColumn; + 20: required set intSetColumn; + 21: required map intToStringColumn; + 22: required map> complexColumn; +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 123a1f629ab1..3566c87dd248 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,17 @@ spark-hive_${scala.binary.version} ${project.version}
    + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + com.google.guava + guava + ${hive.group} hive-cli @@ -49,10 +60,39 @@ ${hive.group} hive-jdbc + + ${hive.group} + hive-service + ${hive.group} hive-beeline + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + com.sun.jersey + jersey-server + + + + org.seleniumhq.selenium + selenium-java + test + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test +
    target/scala-${scala.binary.version}/classes diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala new file mode 100644 index 000000000000..2228f651e238 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hive.service.server + +import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} + +/** + * Class to upgrade a package-private class to public, and + * implement a `process()` operation consistent with + * the behavior of older Hive versions + * @param serverName name of the hive server + */ +private[apache] class HiveServerServerOptionsProcessor(serverName: String) + extends ServerOptionsProcessor(serverName) { + + def process(args: Array[String]): Boolean = { + // A parse failure automatically triggers a system exit + val response = super.parse(args) + val executor = response.getServerOptionsExecutor() + // return true if the parsed option was to start the service + executor.isInstanceOf[StartOptionExecutor] + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala deleted file mode 100644 index 59f3a7576808..000000000000 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import scala.collection.JavaConversions._ - -import org.apache.commons.lang.exception.ExceptionUtils -import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} -import org.apache.hadoop.hive.ql.Driver -import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse - -import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} - -private[hive] abstract class AbstractSparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { - - private[hive] var tableSchema: Schema = _ - private[hive] var hiveResponse: Seq[String] = _ - - override def init(): Unit = { - } - - private def getResultSetSchema(query: context.QueryExecution): Schema = { - val analyzed = query.analyzed - logDebug(s"Result Schema: ${analyzed.output}") - if (analyzed.output.size == 0) { - new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) - } else { - val fieldSchemas = analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - - new Schema(fieldSchemas, null) - } - } - - override def run(command: String): CommandProcessorResponse = { - // TODO unify the error code - try { - context.sparkContext.setJobDescription(command) - val execution = context.executePlan(context.sql(command).logicalPlan) - hiveResponse = execution.stringResult() - tableSchema = getResultSetSchema(execution) - new CommandProcessorResponse(0) - } catch { - case cause: Throwable => - logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) - } - } - - override def close(): Int = { - hiveResponse = null - tableSchema = null - 0 - } - - override def getSchema: Schema = tableSchema - - override def destroy() { - super.destroy() - hiveResponse = null - tableSchema = null - } -} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6e07df18b0e1..dd9fef9206d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,17 +17,27 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.Locale +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} +import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} -import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListener} +import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.{Logging, SparkContext} + /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -35,6 +45,8 @@ import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListener} */ object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) + var uiTab: Option[ThriftServerTab] = _ + var listener: HiveThriftServer2Listener = _ /** * :: DeveloperApi :: @@ -43,13 +55,20 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) + sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() - sqlContext.sparkContext.addSparkListener(new HiveThriftServer2Listener(server)) + listener = new HiveThriftServer2Listener(server, sqlContext.conf) + sqlContext.sparkContext.addSparkListener(listener) + uiTab = if (sqlContext.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) { + Some(new ThriftServerTab(sqlContext.sparkContext)) + } else { + None + } } def main(args: Array[String]) { - val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) } @@ -57,20 +76,23 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + ShutdownHookManager.addShutdownHook { () => + SparkSQLEnv.stop() + uiTab.foreach(_.detach()) + } try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(SparkSQLEnv.hiveContext.hiveconf) server.start() logInfo("HiveThriftServer2 started") - SparkSQLEnv.sparkContext.addSparkListener(new HiveThriftServer2Listener(server)) + listener = new HiveThriftServer2Listener(server, SparkSQLEnv.hiveContext.conf) + SparkSQLEnv.sparkContext.addSparkListener(listener) + uiTab = if (SparkSQLEnv.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) { + Some(new ThriftServerTab(SparkSQLEnv.sparkContext)) + } else { + None + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) @@ -78,42 +100,198 @@ object HiveThriftServer2 extends Logging { } } + private[thriftserver] class SessionInfo( + val sessionId: String, + val startTimestamp: Long, + val ip: String, + val userName: String) { + var finishTimestamp: Long = 0L + var totalExecution: Int = 0 + def totalTime: Long = { + if (finishTimestamp == 0L) { + System.currentTimeMillis - startTimestamp + } else { + finishTimestamp - startTimestamp + } + } + } + + private[thriftserver] object ExecutionState extends Enumeration { + val STARTED, COMPILED, FAILED, FINISHED = Value + type ExecutionState = Value + } + + private[thriftserver] class ExecutionInfo( + val statement: String, + val sessionId: String, + val startTimestamp: Long, + val userName: String) { + var finishTimestamp: Long = 0L + var executePlan: String = "" + var detail: String = "" + var state: ExecutionState.Value = ExecutionState.STARTED + val jobId: ArrayBuffer[String] = ArrayBuffer[String]() + var groupId: String = "" + def totalTime: Long = { + if (finishTimestamp == 0L) { + System.currentTimeMillis - startTimestamp + } else { + finishTimestamp - startTimestamp + } + } + } + + /** * A inner sparkListener called in sc.stop to clean up the HiveThriftServer2 */ - class HiveThriftServer2Listener(val server: HiveServer2) extends SparkListener { + private[thriftserver] class HiveThriftServer2Listener( + val server: HiveServer2, + val conf: SQLConf) extends SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - } + private var onlineSessionNum: Int = 0 + private val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + private val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) + private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) + private var totalRunning = 0 + + def getOnlineSessionNum: Int = synchronized { onlineSessionNum } + def getTotalRunning: Int = synchronized { totalRunning } + + def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } + + def getSession(sessionId: String): Option[SessionInfo] = synchronized { + sessionList.get(sessionId) + } + + def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { + for { + props <- Option(jobStart.properties) + groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + (_, info) <- executionList if info.groupId == groupId + } { + info.jobId += jobStart.jobId.toString + info.groupId = groupId + } + } + + def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { + synchronized { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + onlineSessionNum += 1 + trimSessionIfNecessary() + } + } + + def onSessionClosed(sessionId: String): Unit = synchronized { + sessionList(sessionId).finishTimestamp = System.currentTimeMillis + onlineSessionNum -= 1 + trimSessionIfNecessary() + } + + def onStatementStart( + id: String, + sessionId: String, + statement: String, + groupId: String, + userName: String = "UNKNOWN"): Unit = synchronized { + val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) + info.state = ExecutionState.STARTED + executionList.put(id, info) + trimExecutionIfNecessary() + sessionList(sessionId).totalExecution += 1 + executionList(id).groupId = groupId + totalRunning += 1 + } + + def onStatementParsed(id: String, executionPlan: String): Unit = synchronized { + executionList(id).executePlan = executionPlan + executionList(id).state = ExecutionState.COMPILED + } + + def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { + synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + trimExecutionIfNecessary() + } + } + + def onStatementFinish(id: String): Unit = synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).state = ExecutionState.FINISHED + totalRunning -= 1 + trimExecutionIfNecessary() + } + + private def trimExecutionIfNecessary() = { + if (executionList.size > retainedStatements) { + val toRemove = math.max(retainedStatements / 10, 1) + executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => + executionList.remove(s._1) + } + } + } + + private def trimSessionIfNecessary() = { + if (sessionList.size > retainedSessions) { + val toRemove = math.max(retainedSessions / 10, 1) + sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => + sessionList.remove(s._1) + } + } + + } + } } private[hive] class HiveThriftServer2(hiveContext: HiveContext) extends HiveServer2 with ReflectedCompositeService { + // state is tracked internally so that the server only attempts to shut down if it successfully + // started, and then once only. + private val started = new AtomicBoolean(false) override def init(hiveConf: HiveConf) { - val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + val sparkSqlCliService = new SparkSQLCLIService(this, hiveContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) - if (isHTTPTransportMode(hiveConf)) { - val thriftCliService = new ThriftHttpCLIService(sparkSqlCliService) - setSuperField(this, "thriftCLIService", thriftCliService) - addService(thriftCliService) + val thriftCliService = if (isHTTPTransportMode(hiveConf)) { + new ThriftHttpCLIService(sparkSqlCliService) } else { - val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) - setSuperField(this, "thriftCLIService", thriftCliService) - addService(thriftCliService) + new ThriftBinaryCLIService(sparkSqlCliService) } + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) initCompositeService(hiveConf) } private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { - val transportMode: String = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.equalsIgnoreCase("http") + val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) + transportMode.toLowerCase(Locale.ENGLISH).equals("http") } + + override def start(): Unit = { + super.start() + started.set(true) + } + + override def stop(): Unit = { + if (started.getAndSet(false)) { + super.stop() + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala new file mode 100644 index 000000000000..306f98bcb534 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.security.PrivilegedExceptionAction +import java.sql.{Date, Timestamp} +import java.util.concurrent.RejectedExecutionException +import java.util.{Arrays, Map => JMap, UUID} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.Utils +import org.apache.hive.service.cli.operation.ExecuteStatementOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.execution.SetCommand +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} + + +private[hive] class SparkExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + runInBackground: Boolean = true) + (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) + extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground) + with Logging { + + private var result: DataFrame = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + private var statementId: String = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + hiveContext.sparkContext.clearJobGroup() + logDebug(s"CLOSING $statementId") + cleanup(OperationState.CLOSED) + } + + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to += from.getString(ordinal) + case IntegerType => + to += from.getInt(ordinal) + case BooleanType => + to += from.getBoolean(ordinal) + case DoubleType => + to += from.getDouble(ordinal) + case FloatType => + to += from.getFloat(ordinal) + case DecimalType() => + to += from.getDecimal(ordinal) + case LongType => + to += from.getLong(ordinal) + case ByteType => + to += from.getByte(ordinal) + case ShortType => + to += from.getShort(ordinal) + case DateType => + to += from.getAs[Date](ordinal) + case TimestampType => + to += from.getAs[Timestamp](ordinal) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) + to += hiveString + } + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + validateDefaultFetchOrientation(order) + assertState(OperationState.FINISHED) + setHasResultSet(true) + val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + if (!iter.hasNext) { + resultRowSet + } else { + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + var curRow = 0 + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = ArrayBuffer[Any]() + var curCol = 0 + while (curCol < sparkRow.length) { + if (sparkRow.isNullAt(curCol)) { + row += null + } else { + addNonNullColumnValue(sparkRow, row, curCol) + } + curCol += 1 + } + resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + curRow += 1 + } + resultRowSet + } + } + + def getResultSetSchema: TableSchema = { + if (result == null || result.queryExecution.analyzed.output.size == 0) { + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) + } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema.asJava) + } + } + + override def run(): Unit = { + setState(OperationState.PENDING) + setHasResultSet(true) // avoid no resultset for async run + + if (!runInBackground) { + runInternal() + } else { + val parentSessionState = SessionState.get() + val hiveConf = getConfigForOperation() + val sparkServiceUGI = Utils.getUGI() + val sessionHive = getCurrentHive() + val currentSqlSession = hiveContext.currentSession + + // Runnable impl to call runInternal asynchronously, + // from a different thread + val backgroundOperation = new Runnable() { + + override def run(): Unit = { + val doAsAction = new PrivilegedExceptionAction[Object]() { + override def run(): Object = { + + // User information is part of the metastore client member in Hive + hiveContext.setSession(currentSqlSession) + // Always use the latest class loader provided by executionHive's state. + val executionHiveClassLoader = + hiveContext.executionHive.state.getConf.getClassLoader + sessionHive.getConf.setClassLoader(executionHiveClassLoader) + parentSessionState.getConf.setClassLoader(executionHiveClassLoader) + + Hive.set(sessionHive) + SessionState.setCurrentSessionState(parentSessionState) + try { + runInternal() + } catch { + case e: HiveSQLException => + setOperationException(e) + log.error("Error running hive query: ", e) + } + return null + } + } + + try { + sparkServiceUGI.doAs(doAsAction) + } catch { + case e: Exception => + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + + sparkServiceUGI.getShortUserName(), e) + } + } + } + try { + // This submit blocks if no background threads are available to run this operation + val backgroundHandle = + getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + case rejected: RejectedExecutionException => + setState(OperationState.ERROR) + throw new HiveSQLException("The background threadpool cannot accept" + + " new task for execution, please retry the operation", rejected) + case NonFatal(e) => + logError(s"Error executing query in background", e) + setState(OperationState.ERROR) + throw e + } + } + } + + override def runInternal(): Unit = { + statementId = UUID.randomUUID().toString + logInfo(s"Running query '$statement' with $statementId") + setState(OperationState.RUNNING) + HiveThriftServer2.listener.onStatementStart( + statementId, + parentSession.getSessionHandle.getSessionId.toString, + statement, + statementId, + parentSession.getUsername) + hiveContext.sparkContext.setJobGroup(statementId, statement) + sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } + try { + result = hiveContext.sql(statement) + logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL.key, Some(value)))) => + sessionToActivePool(parentSession.getSessionHandle) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) + iter = { + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + result.rdd.toLocalIterator + } else { + result.collect().iterator + } + } + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + } catch { + case e: HiveSQLException => + if (getStatus().getState() == OperationState.CANCELED) { + return + } else { + setState(OperationState.ERROR); + throw e + } + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + val currentState = getStatus().getState() + logError(s"Error executing query, currentState $currentState, ", e) + setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, e.getStackTraceString) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + HiveThriftServer2.listener.onStatementFinish(statementId) + } + + override def cancel(): Unit = { + logInfo(s"Cancel '$statement' with $statementId") + if (statementId != null) { + hiveContext.sparkContext.cancelJobGroup(statementId) + } + cleanup(OperationState.CANCELED) + } + + private def cleanup(state: OperationState) { + setState(state) + if (runInBackground) { + val backgroundHandle = getBackgroundHandle() + if (backgroundHandle != null) { + backgroundHandle.cancel(true) + } + } + } + + /** + * If there are query specific settings to overlay, then create a copy of config + * There are two cases we need to clone the session config that's being passed to hive driver + * 1. Async query - + * If the client changes a config setting, that shouldn't reflect in the execution + * already underway + * 2. confOverlay - + * The query specific settings should only be applied to the query config and not session + * @return new configuration + * @throws HiveSQLException + */ + private def getConfigForOperation(): HiveConf = { + var sqlOperationConf = getParentSession().getHiveConf() + if (!getConfOverlay().isEmpty() || runInBackground) { + // clone the partent session config for this query + sqlOperationConf = new HiveConf(sqlOperationConf) + + // apply overlay query specific settings, if any + getConfOverlay().asScala.foreach { case (k, v) => + try { + sqlOperationConf.verifyAndSet(k, v) + } catch { + case e: IllegalArgumentException => + throw new HiveSQLException("Error applying statement specific settings", e) + } + } + } + return sqlOperationConf + } + + private def getCurrentHive(): Hive = { + try { + return Hive.get() + } catch { + case e: HiveException => + throw new HiveSQLException("Failed to get current Hive object", e); + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala old mode 100755 new mode 100644 index bb19ac232fcb..a29df567983b --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -17,34 +17,38 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io._ -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, Locale} + +import scala.collection.JavaConverters._ -import jline.{ConsoleReader, History} +import jline.console.ConsoleReader +import jline.console.history.FileHistory -import org.apache.commons.lang.StringUtils +import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException -import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.util.{ShutdownHookManager, Utils} -private[hive] object SparkSQLCLIDriver { +/** + * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver + * has dropped its support. + */ +private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') - private var transport:TSocket = _ + private var transport: TSocket = _ installSignalHandler() @@ -75,7 +79,12 @@ private[hive] object SparkSQLCLIDriver { System.exit(1) } - val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + val cliConf = new HiveConf(classOf[SessionState]) + // Override the location of the metastore since this is only used for local execution. + HiveContext.newTemporaryConfiguration().foreach { + case (key, value) => cliConf.set(key, value) + } + val sessionState = new CliSessionState(cliConf) sessionState.in = System.in try { @@ -92,33 +101,24 @@ private[hive] object SparkSQLCLIDriver { // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf - sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => - conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) - sessionState.getOverriddenConfigurations.put( - item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.cmdProperties.entrySet().asScala.foreach { item => + val key = item.getKey.toString + val value = item.getValue.toString + // We do not propagate metastore options to the execution copy of hive. + if (key != "javax.jdo.option.ConnectionURL") { + conf.set(key, value) + sessionState.getOverriddenConfigurations.put(key, value) + } } SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } + val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. - if (sessionState.getHost != null) { - sessionState.connect() - if (sessionState.isRemoteMode) { - prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt - continuedPrompt = "".padTo(prompt.length, ' ') - } - } - - if (!sessionState.isRemoteMode) { + if (!remoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -129,6 +129,9 @@ private[hive] object SparkSQLCLIDriver { } conf.setClassLoader(loader) Thread.currentThread().setContextClassLoader(loader) + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } val cli = new SparkSQLCLIDriver @@ -145,6 +148,10 @@ private[hive] object SparkSQLCLIDriver { case e: UnsupportedEncodingException => System.exit(3) } + if (sessionState.database != null) { + SparkSQLEnv.hiveContext.runSqlHive(s"USE ${sessionState.database}") + } + // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) @@ -158,64 +165,69 @@ private[hive] object SparkSQLCLIDriver { } } catch { case e: FileNotFoundException => - System.err.println(s"Could not open input file for reading. (${e.getMessage})") + logError(s"Could not open input file for reading. (${e.getMessage})") System.exit(3) } val reader = new ConsoleReader() reader.setBellEnabled(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) - CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) val historyDirectory = System.getProperty("user.home") try { if (new File(historyDirectory).exists()) { val historyFile = historyDirectory + File.separator + ".hivehistory" - reader.setHistory(new History(new File(historyFile))) + reader.setHistory(new FileHistory(new File(historyFile))) } else { - System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. History will not be available during this session.") } } catch { case e: Exception => - System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + logWarning("WARNING: Encountered an error while trying to initialize Hive's " + "history file. History will not be available during this session.") - System.err.println(e.getMessage) + logWarning(e.getMessage) } + // TODO: missing +/* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") clientTransportTSocketField.setAccessible(true) transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] +*/ + transport = null var ret = 0 var prefix = "" val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) - def promptWithCurrentDB = s"$prompt$currentDB" - def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + def promptWithCurrentDB: String = s"$prompt$currentDB" + def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) var currentPrompt = promptWithCurrentDB var line = reader.readLine(currentPrompt + "> ") while (line != null) { - if (prefix.nonEmpty) { - prefix += '\n' - } + if (!line.startsWith("--")) { + if (prefix.nonEmpty) { + prefix += '\n' + } - if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { - line = prefix + line - ret = cli.processLine(line, true) - prefix = "" - currentPrompt = promptWithCurrentDB - } else { - prefix = prefix + line - currentPrompt = continuedPromptWithDBSpaces + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } } - line = reader.readLine(currentPrompt + "> ") } @@ -223,6 +235,13 @@ private[hive] object SparkSQLCLIDriver { System.exit(ret) } + + + def isRemoteMode(state: CliSessionState): Boolean = { + // sessionState.isRemoteMode + state.isHiveServerQuery + } + } private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { @@ -232,25 +251,33 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + private val isRemoteMode = { + SparkSQLCLIDriver.isRemoteMode(sessionState) + } + private val conf: Configuration = if (sessionState != null) sessionState.getConf else new Configuration() // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver // because the Hive unit tests do not go through the main() code path. - if (!sessionState.isRemoteMode) { + if (!isRemoteMode) { SparkSQLEnv.init() + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - if (cmd_trimmed.toLowerCase.equals("quit") || - cmd_trimmed.toLowerCase.equals("exit") || - tokens(0).equalsIgnoreCase("source") || + if (cmd_lower.equals("quit") || + cmd_lower.equals("exit") || + tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || - sessionState.isRemoteMode) { + isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) val end = System.currentTimeMillis() @@ -260,21 +287,23 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { + // scalastyle:off println + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || + proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver driver.init() val out = sessionState.out - val start:Long = System.currentTimeMillis() + val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) } val rc = driver.run(cmd) val end = System.currentTimeMillis() - val timeTaken:Double = (end - start) / 1000.0 + val timeTaken: Double = (end - start) / 1000.0 ret = rc.getResponseCode if (ret != 0) { @@ -287,18 +316,22 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { // Print the column names. - Option(driver.getSchema.getFieldSchemas).map { fields => - out.println(fields.map(_.getName).mkString("\t")) + Option(driver.getSchema.getFieldSchemas).foreach { fields => + out.println(fields.asScala.map(_.getName).mkString("\t")) } } + var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach(out.println) + res.asScala.foreach { l => + counter += 1 + out.println(l) + } res.clear() } } catch { - case e:IOException => + case e: IOException => console.printError( s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} |${org.apache.hadoop.util.StringUtils.stringifyException(e)} @@ -311,7 +344,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret = cret } - console.printInfo(s"Time taken: $timeTaken seconds", null) + var responseMsg = s"Time taken: $timeTaken seconds" + if (counter != 0) { + responseMsg += s", Fetched $counter row(s)" + } + console.printInfo(responseMsg , null) // Destroy the driver to release all the locks. driver.destroy() } else { @@ -320,6 +357,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } ret = proc.run(cmd_1).getResponseCode } + // scalastyle:on println } ret } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d7294..5ad8c54f296d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,38 +21,38 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ +import org.apache.hive.service.server.HiveServer2 import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.util.Utils -private[hive] class SparkSQLCLIService(hiveContext: HiveContext) - extends CLIService +private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) + extends CLIService(hiveServer) with ReflectedCompositeService { override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) - val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims.isSecurityEnabled) { + if (UserGroupInformation.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + sparkServiceUGI = Utils.getUGI() + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) @@ -76,7 +76,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => def initCompositeService(hiveConf: HiveConf) { // Emulating `CompositeService.init(hiveConf)` val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") - serviceList.foreach(_.init(hiveConf)) + serviceList.asScala.foreach(_.init(hiveConf)) // Emulating `AbstractService.init(hiveConf)` invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 000000000000..2619286afc14 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.{Arrays, ArrayList => JArrayList, List => JList} + +import scala.collection.JavaConverters._ + +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver + with Logging { + + private[hive] var tableSchema: Schema = _ + private[hive] var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logDebug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.isEmpty) { + new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas.asJava, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + // TODO unify the error code + try { + context.sparkContext.setJobDescription(command) + val execution = context.executePlan(context.sql(command).logicalPlan) + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logError(s"Failed in [$command]", cause) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava) + hiveResponse = null + true + } + } + + override def getSchema: Schema = tableSchema + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 158c22515972..bacf6cc458fd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ +import java.io.PrintStream + +import scala.collection.JavaConverters._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -35,10 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") - .set("spark.sql.hive.version", HiveShim.version) + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) @@ -50,8 +57,14 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext.addSparkListener(new StatsReportListener()) hiveContext = new HiveContext(sparkContext) + hiveContext.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) + hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) + hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + + hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) + if (log.isDebugEnabled) { - hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + hiveContext.hiveconf.getAllProperties.asScala.toSeq.sorted.foreach { case (k, v) => logDebug(s"HiveConf var: $k=$v") } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 89e9ede7261c..92ac0ec3fca2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -22,15 +22,18 @@ import java.util.concurrent.Executors import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.SessionHandle import org.apache.hive.service.cli.session.SessionManager +import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.hive.service.server.HiveServer2 import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager -import org.apache.hive.service.cli.SessionHandle -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager + +private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveContext) + extends SessionManager(hiveServer) with ReflectedCompositeService { private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) @@ -49,8 +52,29 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) initCompositeService(hiveConf) } + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + ipAddress: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = + super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, + delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9c0bf02391e0..c8031ed0f343 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -44,9 +44,12 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( - hiveContext, sessionToActivePool) + val runInBackground = async && hiveContext.hiveThriftServerAsync + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, + runInBackground)(hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created Operation for $statement with session=$parentSession, " + + s"runInBackground=$runInBackground") operation } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala new file mode 100644 index 000000000000..e990bd06011f --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +import java.util.Calendar +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.Logging +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{SessionInfo, ExecutionState, ExecutionInfo} +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui._ + + +/** Page for Spark Web UI that shows statistics of a thrift server */ +private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging { + + private val listener = parent.listener + private val startTime = Calendar.getInstance().getTime() + private val emptyCell = "-" + + /** Render the page */ + def render(request: HttpServletRequest): Seq[Node] = { + val content = + listener.synchronized { // make sure all parts in this page are consistent + generateBasicStats() ++ +
    ++ +

    + {listener.getOnlineSessionNum} session(s) are online, + running {listener.getTotalRunning} SQL statement(s) +

    ++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + } + UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + } + + /** Generate basic stats of the thrift server program */ + private def generateBasicStats(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime.getTime +
      +
    • + Started at: {formatDate(startTime)} +
    • +
    • + Time since start: {formatDurationVerbose(timeSinceStart)} +
    • +
    + } + + /** Generate stats of batch statements of the thrift server program */ + private def generateSQLStatsTable(): Seq[Node] = { + val numStatement = listener.getExecutionList.size + val table = if (numStatement > 0) { + val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", + "Statement", "State", "Detail") + val dataRows = listener.getExecutionList + + def generateDataRow(info: ExecutionInfo): Seq[Node] = { + val jobLink = info.jobId.map { id: String => + + [{id}] + + } + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan +
    + + + + + + + + + {errorMessageCell(detail)} + + } + + Some(UIUtils.listingTable(headerRow, generateDataRow, + dataRows, false, None, Seq(null), false)) + } else { + None + } + + val content = +
    SQL Statistics
    ++ +
    +
      + {table.getOrElse("No statistics have been generated yet.")} +
    +
    + + content + } + + private def errorMessageCell(errorMessage: String): Seq[Node] = { + val isMultiline = errorMessage.indexOf('\n') >= 0 + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else { + errorMessage + }) + val details = if (isMultiline) { + // scalastyle:off + + + details + ++ + + // scalastyle:on + } else { + "" + } + + } + + /** Generate stats of batch sessions of the thrift server program */ + private def generateSessionStatsTable(): Seq[Node] = { + val sessionList = listener.getSessionList + val numBatches = sessionList.size + val table = if (numBatches > 0) { + val dataRows = sessionList + val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", + "Total Execute") + def generateDataRow(session: SessionInfo): Seq[Node] = { + val sessionLink = "%s/%s/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + + + + + + + + + + } + Some(UIUtils.listingTable(headerRow, generateDataRow, dataRows, true, None, Seq(null), false)) + } else { + None + } + + val content = +
    Session Statistics
    ++ +
    +
      + {table.getOrElse("No statistics have been generated yet.")} +
    +
    + + content + } + + + /** + * Returns a human-readable string representing a duration such as "5 second 35 ms" + */ + private def formatDurationOption(msOption: Option[Long]): String = { + msOption.map(formatDurationVerbose).getOrElse(emptyCell) + } + + /** Generate HTML table from string data */ + private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { + def generateDataRow(data: Seq[String]): Seq[Node] = { + {data.map(d => )} + } + UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala new file mode 100644 index 000000000000..af16cb31df18 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +import java.util.Calendar +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.Logging +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui._ + +/** Page for Spark Web UI that shows statistics of a streaming job */ +private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) + extends WebUIPage("session") with Logging { + + private val listener = parent.listener + private val startTime = Calendar.getInstance().getTime() + private val emptyCell = "-" + + /** Render the page */ + def render(request: HttpServletRequest): Seq[Node] = { + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val content = + listener.synchronized { // make sure all parts in this page are consistent + val sessionStat = listener.getSession(parameterId).getOrElse(null) + require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") + + generateBasicStats() ++ +
    ++ +

    + User {sessionStat.userName}, + IP {sessionStat.ip}, + Session created at {formatDate(sessionStat.startTimestamp)}, + Total run {sessionStat.totalExecution} SQL +

    ++ + generateSQLStatsTable(sessionStat.sessionId) + } + UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + } + + /** Generate basic stats of the streaming program */ + private def generateBasicStats(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime.getTime +
      +
    • + Started at: {startTime.toString} +
    • +
    • + Time since start: {formatDurationVerbose(timeSinceStart)} +
    • +
    + } + + /** Generate stats of batch statements of the thrift server program */ + private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + val executionList = listener.getExecutionList + .filter(_.sessionId == sessionID) + val numStatement = executionList.size + val table = if (numStatement > 0) { + val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", + "Statement", "State", "Detail") + val dataRows = executionList.sortBy(_.startTimestamp).reverse + + def generateDataRow(info: ExecutionInfo): Seq[Node] = { + val jobLink = info.jobId.map { id: String => + + [{id}] + + } + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan + + + + + + + + + + {errorMessageCell(detail)} + + } + + Some(UIUtils.listingTable(headerRow, generateDataRow, + dataRows, false, None, Seq(null), false)) + } else { + None + } + + val content = +
    SQL Statistics
    ++ +
    +
      + {table.getOrElse("No statistics have been generated yet.")} +
    +
    + + content + } + + private def errorMessageCell(errorMessage: String): Seq[Node] = { + val isMultiline = errorMessage.indexOf('\n') >= 0 + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else { + errorMessage + }) + val details = if (isMultiline) { + // scalastyle:off + + + details + ++ + + // scalastyle:on + } else { + "" + } + + } + + /** Generate stats of batch sessions of the thrift server program */ + private def generateSessionStatsTable(): Seq[Node] = { + val sessionList = listener.getSessionList + val numBatches = sessionList.size + val table = if (numBatches > 0) { + val dataRows = + sessionList.sortBy(_.startTimestamp).reverse.map ( session => + Seq( + session.userName, + session.ip, + session.sessionId, + formatDate(session.startTimestamp), + formatDate(session.finishTimestamp), + formatDurationOption(Some(session.totalTime)), + session.totalExecution.toString + ) + ).toSeq + val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", + "Total Execute") + Some(listingTable(headerRow, dataRows)) + } else { + None + } + + val content = +
    Session Statistics
    ++ +
    +
      + {table.getOrElse("No statistics have been generated yet.")} +
    +
    + + content + } + + + /** + * Returns a human-readable string representing a duration such as "5 second 35 ms" + */ + private def formatDurationOption(msOption: Option[Long]): String = { + msOption.map(formatDurationVerbose).getOrElse(emptyCell) + } + + /** Generate HTML table from string data */ + private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { + def generateDataRow(data: Seq[String]): Seq[Node] = { + {data.map(d => )} + } + UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala new file mode 100644 index 000000000000..4eabeaa6735e --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +import org.apache.spark.sql.hive.thriftserver.{HiveThriftServer2, SparkSQLEnv} +import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab._ +import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.{SparkContext, Logging, SparkException} + +/** + * Spark Web UI tab that shows statistics of a streaming job. + * This assumes the given SparkContext has enabled its SparkUI. + */ +private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) + extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { + + override val name = "JDBC/ODBC Server" + + val parent = getSparkUI(sparkContext) + val listener = HiveThriftServer2.listener + + attachPage(new ThriftServerPage(this)) + attachPage(new ThriftServerSessionPage(this)) + parent.attachTab(this) + + def detach() { + getSparkUI(sparkContext).detachTab(this) + } +} + +private[thriftserver] object ThriftServerTab { + def getSparkUI(sparkContext: SparkContext): SparkUI = { + sparkContext.ui.getOrElse { + throw new SparkException("Parent SparkUI to attach this tab to not found!") + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 60953576d0e3..e59a14ec00d5 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -19,47 +18,83 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ +import java.sql.Timestamp +import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} -class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { +/** + * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary + * Hive metastore and warehouse. + */ +class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { + val warehousePath = Utils.createTempDir() + val metastorePath = Utils.createTempDir() + val scratchDirPath = Utils.createTempDir() + + before { + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() + } + + after { + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() + } + + /** + * Run a CLI operation and expect all the queries and expected answers to be returned. + * @param timeout maximum time for the commands to complete + * @param extraArgs any extra arguments + * @param errorResponses a sequence of strings whose presence in the stdout of the forked process + * is taken as an immediate error condition. That is: if a line beginning + * with one of these strings is found, fail the test immediately. + * The default value is `Seq("Error:")` + * + * @param queriesAndExpectedAnswers one or more tupes of query + answer + */ def runCliWithin( timeout: FiniteDuration, - extraArgs: Seq[String] = Seq.empty)( - queriesAndExpectedAnswers: (String, String)*) { + extraArgs: Seq[String] = Seq.empty, + errorResponses: Seq[String] = Seq("Error:"))( + queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") - val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. + val queriesString = queries.map(_ + "\n").mkString val command = { + val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --driver-class-path ${sys.props("java.class.path")} + | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object def captureOutput(source: String)(line: String): Unit = lock.synchronized { - buffer += s"$source> $line" + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. + buffer += s"${new Timestamp(new Date().getTime)} - $source> $line" + // If we haven't found all expected answers and another expected answer comes up... if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { next += 1 @@ -67,23 +102,36 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { if (next == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } + } else { + errorResponses.foreach { r => + if (line.startsWith(r)) { + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with error line '$line'")) + } + } } } - // Searching expected output line from both stdout and stderr of the CLI process - val process = (Process(command, None) #< queryStream).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + val process = new ProcessBuilder(command: _*).start() + + val stdinWriter = new OutputStreamWriter(process.getOutputStream) + stdinWriter.write(queriesString) + stdinWriter.flush() + stdinWriter.close() + + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => - logError( + val message = s""" |======================= |CliSuite failure output |======================= |Spark SQL CLI command line: ${command.mkString(" ")} - | + |Exception: $cause |Executed query $next "${queries(next)}", |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. | @@ -91,11 +139,10 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |=========================== |End CliSuite failure output |=========================== - """.stripMargin, cause) - throw cause + """.stripMargin + logError(message, cause) + fail(message, cause) } finally { - warehousePath.delete() - metastorePath.delete() process.destroy() } } @@ -116,11 +163,60 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "Time taken: " + -> "OK" ) } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + } + + test("Single command with --database") { + runCliWithin(2.minute)( + "CREATE DATABASE hive_test_db;" + -> "OK", + "USE hive_test_db;" + -> "OK", + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "Time taken: " + ) + + runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + "" + -> "OK", + "" + -> "hive_test" + ) + } + + test("Commands using SerDe provided in --jars") { + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( + """CREATE TABLE t1(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "OK", + "CREATE TABLE sourceTable (key INT, val STRING);" + -> "OK", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" + -> "OK", + "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" + -> "Time taken:", + "SELECT count(key) FROM t1;" + -> "5", + "DROP TABLE t1;" + -> "OK", + "DROP TABLE sourceTable;" + -> "OK" + ) } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala deleted file mode 100644 index b52a51d11e4a..000000000000 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.io.File -import java.net.ServerSocket -import java.sql.{Date, DriverManager, Statement} -import java.util.concurrent.TimeoutException - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} -import scala.util.Try - -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.jdbc.HiveDriver -import org.apache.hive.service.auth.PlainSaslHelper -import org.apache.hive.service.cli.GetInfoType -import org.apache.hive.service.cli.thrift.TCLIService.Client -import org.apache.hive.service.cli.thrift._ -import org.apache.thrift.protocol.TBinaryProtocol -import org.apache.thrift.transport.TSocket -import org.scalatest.FunSuite - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.hive.HiveShim - -/** - * Tests for the HiveThriftServer2 using JDBC. - * - * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be - * rebuilt after changing HiveThriftServer2 related code. - */ -class HiveThriftServer2Suite extends FunSuite with Logging { - Class.forName(classOf[HiveDriver].getCanonicalName) - - object TestData { - def getTestDataFilePath(name: String) = { - Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") - } - - val smallKv = getTestDataFilePath("small_kv.txt") - val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") - } - - def randomListeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - def withJdbcStatement( - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: Statement => Unit) { - val port = randomListeningPort - - startThriftServer(port, serverStartTimeout, httpMode) { - val jdbcUri = if (httpMode) { - s"jdbc:hive2://${"localhost"}:$port/" + - "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice" - } else { - s"jdbc:hive2://${"localhost"}:$port/" - } - - val user = System.getProperty("user.name") - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } - } - } - - def withCLIServiceClient( - serverStartTimeout: FiniteDuration = 1.minute)( - f: ThriftCLIServiceClient => Unit) { - val port = randomListeningPort - - startThriftServer(port) { - // Transport creation logics below mimics HiveConnection.createBinaryTransport - val rawTransport = new TSocket("localhost", port) - val user = System.getProperty("user.name") - val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) - val protocol = new TBinaryProtocol(transport) - val client = new ThriftCLIServiceClient(new Client(protocol)) - - transport.open() - - try { - f(client) - } finally { - transport.close() - } - } - } - - def startThriftServer( - port: Int, - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: => Unit) { - val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) - - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") - val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - val command = - if (httpMode) { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } else { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } - - val serverRunning = Promise[Unit]() - val buffer = new ArrayBuffer[String]() - val LOGGING_MARK = - s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " - var logTailingProcess: Process = null - var logFilePath: String = null - - def captureLogOutput(line: String): Unit = { - buffer += line - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { - serverRunning.success(()) - } - } - - def captureThriftServerOutput(source: String)(line: String): Unit = { - if (line.startsWith(LOGGING_MARK)) { - logFilePath = line.drop(LOGGING_MARK.length).trim - // Ensure that the log file is created so that the `tail' command won't fail - Try(new File(logFilePath).createNewFile()) - logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") - .run(ProcessLogger(captureLogOutput, _ => ())) - } - } - - val env = Seq( - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - "SPARK_TESTING" -> "0") - - Process(command, None, env: _*).run(ProcessLogger( - captureThriftServerOutput("stdout"), - captureThriftServerOutput("stderr"))) - - try { - Await.result(serverRunning.future, serverStartTimeout) - f - } catch { - case cause: Exception => - cause match { - case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) - case _ => - } - logError( - s""" - |===================================== - |HiveThriftServer2Suite failure output - |===================================== - |HiveThriftServer2 command line: ${command.mkString(" ")} - |Binding port: $port - |System user: ${System.getProperty("user.name")} - | - |${buffer.mkString("\n")} - |========================================= - |End HiveThriftServer2Suite failure output - |========================================= - """.stripMargin, cause) - throw cause - } finally { - warehousePath.delete() - metastorePath.delete() - Process(stopScript, None, env: _*).run().exitValue() - // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. - Thread.sleep(3.seconds.toMillis) - Option(logTailingProcess).map(_.destroy()) - Option(logFilePath).map(new File(_).delete()) - } - } - - test("Test JDBC query execution") { - withJdbcStatement() { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("Test JDBC query execution in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("SPARK-3004 regression: result set containing NULL") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_null", - "CREATE TABLE test_null(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") - - (0 until 5).foreach { _ => - resultSet.next() - assert(resultSet.getInt(1) === 0) - assert(resultSet.wasNull()) - } - - assert(!resultSet.next()) - } - } - - test("GetInfo Thrift API") { - withCLIServiceClient() { client => - val user = System.getProperty("user.name") - val sessionHandle = client.openSession(user, "") - - assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue - } - - assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue - } - - assertResult(true, "Spark version shouldn't be \"Unknown\"") { - val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue - logInfo(s"Spark version: $version") - version != "Unknown" - } - } - } - - test("Checks Hive version") { - withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("Checks Hive version in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("SPARK-4292 regression: result set iterator issue") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_4292", - "CREATE TABLE test_4292(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT key FROM test_4292") - - Seq(238, 86, 311, 27, 165).foreach { key => - resultSet.next() - assert(resultSet.getInt(1) === key) - } - - statement.executeQuery("DROP TABLE IF EXISTS test_4292") - } - } - - test("SPARK-4309 regression: Date type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_date", - "CREATE TABLE test_date(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") - - queries.foreach(statement.execute) - - assertResult(Date.valueOf("2011-01-01")) { - val resultSet = statement.executeQuery( - "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") - resultSet.next() - resultSet.getDate(1) - } - } - } - - test("SPARK-4407 regression: Complex type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_map", - "CREATE TABLE test_map(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") - - queries.foreach(statement.execute) - - assertResult("""{238:"val_238"}""") { - val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - - assertResult("""["238","val_238"]""") { - val resultSet = statement.executeQuery( - "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - } - } -} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala new file mode 100644 index 000000000000..b72249b3bf8c --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -0,0 +1,704 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.File +import java.net.URL +import java.sql.{Date, DriverManager, SQLException, Statement} + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise, future} +import scala.util.{Random, Try} + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} + +object TestData { + def getTestDataFilePath(name: String): URL = { + Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") + } + + val smallKv = getTestDataFilePath("small_kv.txt") + val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") +} + +class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.binary + + private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { + // Transport creation logic below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", serverPort) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + try f(client) finally transport.close() + } + + test("GetInfo Thrift API") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) + } + } + + test("SPARK-3004 regression: result set containing NULL") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_null", + "CREATE TABLE test_null(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + + (0 until 5).foreach { _ => + resultSet.next() + assert(resultSet.getInt(1) === 0) + assert(resultSet.wasNull()) + } + + assert(!resultSet.next()) + } + } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) === key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } + + test("SPARK-4309 regression: Date type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_date", + "CREATE TABLE test_date(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") + + queries.foreach(statement.execute) + + assertResult(Date.valueOf("2011-01-01")) { + val resultSet = statement.executeQuery( + "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") + resultSet.next() + resultSet.getDate(1) + } + } + } + + test("SPARK-4407 regression: Complex type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + assertResult("""{238:"val_238"}""") { + val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + + assertResult("""["238","val_238"]""") { + val resultSet = statement.executeQuery( + "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + } + } + + test("test multiple session") { + import org.apache.spark.sql.SQLConf + var defaultV1: String = null + var defaultV2: String = null + + withMultipleConnectionJdbcStatement( + // create table + { statement => + + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map", + "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC") + + queries.foreach(statement.execute) + + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + }, + + // first session, we get the default value of the session status + { statement => + + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") + rs1.next() + defaultV1 = rs1.getString(1) + assert(defaultV1 != "200") + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + + defaultV2 = rs2.getString(1) + assert(defaultV1 != "true") + rs2.close() + }, + + // second session, we update the session status + { statement => + + val queries = Seq( + s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291", + "SET hive.cli.print.header=true" + ) + + queries.map(statement.execute) + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") + rs1.next() + assert("spark.sql.shuffle.partitions" === rs1.getString(1)) + assert("291" === rs1.getString(2)) + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + assert("hive.cli.print.header" === rs2.getString(1)) + assert("true" === rs2.getString(2)) + rs2.close() + }, + + // third session, we get the latest session status, supposed to be the + // default value + { statement => + + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") + rs1.next() + assert(defaultV1 === rs1.getString(1)) + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + assert(defaultV2 === rs2.getString(1)) + rs2.close() + }, + + // accessing the cached data in another session + { statement => + + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + statement.executeQuery("UNCACHE TABLE test_table") + + // TODO need to figure out how to determine if the data loaded from cache + val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf3 = new collection.mutable.ArrayBuffer[Int]() + while (rs3.next()) { + buf3 += rs3.getInt(1) + } + rs3.close() + + assert(buf1 === buf3) + }, + + // accessing the uncached table + { statement => + + // TODO need to figure out how to determine if the data loaded from cache + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + } + ) + } + + test("test jdbc cancel") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + val largeJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ") + val f = future { Thread.sleep(100); statement.cancel(); } + val e = intercept[SQLException] { + statement.executeQuery(largeJoin) + } + assert(e.getMessage contains "cancelled") + Await.result(f, 3.minute) + + // cancel is a noop + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + val sf = future { Thread.sleep(100); statement.cancel(); } + val smallJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + val rs1 = statement.executeQuery(smallJoin) + Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } + } + + test("test add jar") { + withMultipleConnectionJdbcStatement( + { + statement => + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + statement.executeQuery(s"ADD JAR $jarFile") + }, + + { + statement => + val queries = Seq( + "DROP TABLE IF EXISTS smallKV", + "CREATE TABLE smallKV(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE smallKV", + "DROP TABLE IF EXISTS addJar", + """CREATE TABLE addJar(key string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + + queries.foreach(statement.execute) + + statement.executeQuery( + """ + |INSERT INTO TABLE addJar SELECT 'k1' as key FROM smallKV limit 1 + """.stripMargin) + + val actualResult = + statement.executeQuery("SELECT key FROM addJar") + val actualResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (actualResult.next()) { + actualResultBuffer += actualResult.getString(1) + } + actualResult.close() + + val expectedResult = + statement.executeQuery("SELECT 'k1'") + val expectedResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (expectedResult.next()) { + expectedResultBuffer += expectedResult.getString(1) + } + expectedResult.close() + + assert(expectedResultBuffer === actualResultBuffer) + + statement.executeQuery("DROP TABLE IF EXISTS addJar") + statement.executeQuery("DROP TABLE IF EXISTS smallKV") + } + ) + } +} + +class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.http + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) + } + } +} + +object ServerMode extends Enumeration { + val binary, http = Value +} + +abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { + Utils.classForName(classOf[HiveDriver].getCanonicalName) + + private def jdbcUri = if (mode == ServerMode.http) { + s"""jdbc:hive2://localhost:$serverPort/ + |default? + |hive.server2.transport.mode=http; + |hive.server2.thrift.http.path=cliservice + """.stripMargin.split("\n").mkString.trim + } else { + s"jdbc:hive2://localhost:$serverPort/" + } + + def withMultipleConnectionJdbcStatement(fs: (Statement => Unit)*) { + val user = System.getProperty("user.name") + val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } + val statements = connections.map(_.createStatement()) + + try { + statements.zip(fs).foreach { case (s, f) => f(s) } + } finally { + statements.foreach(_.close()) + connections.foreach(_.close()) + } + } + + def withJdbcStatement(f: Statement => Unit) { + withMultipleConnectionJdbcStatement(f) + } +} + +abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging { + def mode: ServerMode.Value + + private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") + private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to " + + protected val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + protected val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + private var listeningPort: Int = _ + protected def serverPort: Int = listeningPort + + protected def user = System.getProperty("user.name") + + protected var warehousePath: File = _ + protected var metastorePath: File = _ + protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + + private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private var logPath: File = _ + private var logTailingProcess: Process = _ + private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + + protected def serverStartCommand(port: Int) = { + val portConf = if (mode == ServerMode.binary) { + ConfVars.HIVE_SERVER2_THRIFT_PORT + } else { + ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT + } + + val driverClassPath = { + // Writes a temporary log4j.properties and prepend it to driver classpath, so that it + // overrides all other potential log4j configurations contained in other dependency jar files. + val tempLog4jConf = Utils.createTempDir().getCanonicalPath + + Files.write( + """log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) + + tempLog4jConf + } + + s"""$startScript + | --master local + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost + | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf $portConf=$port + | --driver-class-path $driverClassPath + | --driver-java-options -Dlog4j.debug + | --conf spark.ui.enabled=false + """.stripMargin.split("\\s+").toSeq + } + + /** + * String to scan for when looking for the the thrift binary endpoint running. + * This can change across Hive versions. + */ + val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" + + /** + * String to scan for when looking for the the thrift HTTP endpoint running. + * This can change across Hive versions. + */ + val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" + + val SERVER_STARTUP_TIMEOUT = 3.minutes + + private def startThriftServer(port: Int, attempt: Int) = { + warehousePath = Utils.createTempDir() + warehousePath.delete() + metastorePath = Utils.createTempDir() + metastorePath.delete() + logPath = null + logTailingProcess = null + + val command = serverStartCommand(port) + + diagnosisBuffer ++= + s""" + |### Attempt $attempt ### + |HiveThriftServer2 command line: $command + |Listening port: $port + |System user: $user + """.stripMargin.split("\n") + + logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") + + logPath = { + val lines = Utils.executeAndGetOutput( + command = command, + extraEnvironment = Map( + // Disables SPARK_TESTING to exclude log4j.properties in test directories. + "SPARK_TESTING" -> "0", + // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be + // started at a time, which is not Jenkins friendly. + "SPARK_PID_DIR" -> pidDir.getCanonicalPath), + redirectStderr = true) + + lines.split("\n").collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } + } + + val serverStarted = Promise[Unit]() + + // Ensures that the following "tail" command won't fail. + logPath.createNewFile() + val successLines = Seq(THRIFT_BINARY_SERVICE_LIVE, THRIFT_HTTP_SERVICE_LIVE) + + logTailingProcess = { + val command = s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}".split(" ") + // Using "-n +0" to make sure all lines in the log file are checked. + val builder = new ProcessBuilder(command: _*) + val captureOutput = (line: String) => diagnosisBuffer.synchronized { + diagnosisBuffer += line + + successLines.foreach { r => + if (line.contains(r)) { + serverStarted.trySuccess(()) + } + } + } + + val process = builder.start() + + new ProcessOutputCapturer(process.getInputStream, captureOutput).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput).start() + process + } + + Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) + } + + private def stopThriftServer(): Unit = { + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Utils.executeAndGetOutput( + command = Seq(stopScript), + extraEnvironment = Map("SPARK_PID_DIR" -> pidDir.getCanonicalPath)) + Thread.sleep(3.seconds.toMillis) + + warehousePath.delete() + warehousePath = null + + metastorePath.delete() + metastorePath = null + + Option(logPath).foreach(_.delete()) + logPath = null + + Option(logTailingProcess).foreach(_.destroy()) + logTailingProcess = null + } + + private def dumpLogs(): Unit = { + logError( + s""" + |===================================== + |HiveThriftServer2Suite failure output + |===================================== + |${diagnosisBuffer.mkString("\n")} + |========================================= + |End HiveThriftServer2Suite failure output + |========================================= + """.stripMargin) + } + + override protected def beforeAll(): Unit = { + // Chooses a random port between 10000 and 19999 + listeningPort = 10000 + Random.nextInt(10000) + diagnosisBuffer.clear() + + // Retries up to 3 times with different port numbers if the server fails to start + (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) => + started.orElse { + listeningPort += 1 + stopThriftServer() + Try(startThriftServer(listeningPort, attempt)) + } + }.recover { + case cause: Throwable => + dumpLogs() + throw cause + }.get + + logInfo(s"HiveThriftServer2 started successfully") + } + + override protected def afterAll(): Unit = { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala new file mode 100644 index 000000000000..bf431cd6b026 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.util.Random + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.selenium.WebBrowser +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.ui.SparkUICssErrorHandler + +class UISeleniumSuite + extends HiveThriftJdbcTest + with WebBrowser with Matchers with BeforeAndAfterAll { + + implicit var webDriver: WebDriver = _ + var server: HiveThriftServer2 = _ + val uiPort = 20000 + Random.nextInt(10000) + override def mode: ServerMode.Value = ServerMode.binary + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } + super.beforeAll() + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + super.afterAll() + } + + override protected def serverStartCommand(port: Int) = { + val portConf = if (mode == ServerMode.binary) { + ConfVars.HIVE_SERVER2_THRIFT_PORT + } else { + ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT + } + + s"""$startScript + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost + | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf $portConf=$port + | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=true + | --conf spark.ui.port=$uiPort + """.stripMargin.split("\\s+").toSeq + } + + ignore("thrift server ui test") { + withJdbcStatement { statement => + val baseURL = s"http://localhost:$uiPort" + + val queries = Seq( + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to baseURL + find(cssSelector("""ul li a[href*="sql"]""")) should not be None + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (baseURL + "/sql") + find(id("sessionstat")) should not be None + find(id("sqlstat")) should not be None + + // check whether statements exists + queries.foreach { line => + findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line) + } + } + } + } +} diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala deleted file mode 100644 index ea9d61d8d0f5..000000000000 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.sql.{Date, Timestamp} -import java.util.{ArrayList => JArrayList, Map => JMap} - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.HiveSession - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} -import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.types._ - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.12.0" - - def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { - val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) - setSuperField(sparkCliService, "serverUserName", serverUserName) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } -} - -private[hive] class SparkExecuteStatementOperation( - parentSession: HiveSession, - statement: String, - confOverlay: JMap[String, String])( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - - private var result: DataFrame = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType() => - val hiveDecimal = from.getDecimal(ordinal) - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case DateType => - to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType() => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case DateType => - to.addColumnValue(ColumnValue.dateValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => - sessionToActivePool(parentSession.getSessionHandle) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - hiveContext.sparkContext.setJobDescription(statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - iter = { - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - result.rdd.toLocalIterator - } else { - result.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - setState(OperationState.ERROR) - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - } -} diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala deleted file mode 100644 index 71e3954b2c7a..000000000000 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.sql.{Date, Timestamp} -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} - -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.HiveSession - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} -import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.types._ - -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.13.1" - - def setServerUserName( - sparkServiceUGI: UserGroupInformation, - sparkCliService:SparkSQLCLIService) = { - setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JList[_]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) - hiveResponse = null - true - } - } -} - -private[hive] class SparkExecuteStatementOperation( - parentSession: HiveSession, - statement: String, - confOverlay: JMap[String, String], - runInBackground: Boolean = true)( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution - extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { - - private var result: DataFrame = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") - } - - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to += from.getString(ordinal) - case IntegerType => - to += from.getInt(ordinal) - case BooleanType => - to += from.getBoolean(ordinal) - case DoubleType => - to += from.getDouble(ordinal) - case FloatType => - to += from.getFloat(ordinal) - case DecimalType() => - to += from.getDecimal(ordinal) - case LongType => - to += from.getLong(ordinal) - case ByteType => - to += from.getByte(ordinal) - case ShortType => - to += from.getShort(ordinal) - case DateType => - to += from.getAs[Date](ordinal) - case TimestampType => - to += from.getAs[Timestamp](ordinal) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to += hiveString - } - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - validateDefaultFetchOrientation(order) - assertState(OperationState.FINISHED) - setHasResultSet(true) - val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) - if (!iter.hasNext) { - resultRowSet - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = ArrayBuffer[Any]() - var curCol = 0 - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - row += null - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) - curRow += 1 - } - resultRowSet - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => - sessionToActivePool(parentSession.getSessionHandle) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - hiveContext.sparkContext.setJobDescription(statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - iter = { - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - result.rdd.toLocalIterator - } else { - result.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - setState(OperationState.ERROR) - logError("Error executing query:", e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala new file mode 100644 index 000000000000..1a5ba20404c4 --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.File + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with hash joins. + */ +class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) + } + + override def afterAll() { + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) + + // Only run those query tests in the realWhileList (do not try other ignored query files). + override def testCases: Seq[(String, File)] = super.testCases.filter { + case (name, _) => realWhiteList.contains(name) + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0d934620aca0..ab309e0a1d36 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive /** @@ -48,17 +48,21 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Add Locale setting Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + RuleExecutor.resetTime() } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -116,6 +120,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // This test is totally fine except that it includes wrong queries and expects errors, but error // message format in Hive and Spark SQL differ. Should workaround this later. "udf_to_unix_timestamp", + // we can cast dates likes '2015-03-18' to a timestamp and extract the seconds. + // Hive returns null for second('2015-03-18') + "udf_second", + // we can cast dates likes '2015-03-18' to a timestamp and extract the minutes. + // Hive returns null for minute('2015-03-18') + "udf_minute", + // Cant run without local map/reduce. "index_auto_update", @@ -185,7 +196,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive does not support buckets. ".*bucket.*", - // No window support yet + // We have our own tests based on these query files. ".*window.*", // Fails in hive with authorization errors. @@ -222,8 +233,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", + // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive + // is src(key STRING, value STRING), and in the reflect.q, it failed in + // Integer.valueOf, which expect the first argument passed as STRING type not INT. + "udf_reflect", // Sort with Limit clause causes failure. "ctas", @@ -231,8 +244,66 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // timestamp in array, the output format of Hive contains double quotes, while // Spark SQL doesn't - "udf_sort_array" - ) ++ HiveShim.compatibilityBlackList + "udf_sort_array", + + // It has a bug and it has been fixed by + // https://issues.apache.org/jira/browse/HIVE-7673 (in Hive 0.14 and trunk). + "input46", + + // These tests were broken by the hive client isolation PR. + "part_inherit_tbl_props", + "part_inherit_tbl_props_with_star", + + "nullformatCTAS", // SPARK-7411: need to finish CTAS parser + + // The isolated classloader seemed to make some of our test reset mechanisms less robust. + "combine1", // This test changes compression settings in a way that breaks all subsequent tests. + "load_dyn_part14.*", // These work alone but fail when run with other tests... + + // the answer is sensitive for jdk version + "udf_java_method", + + // Spark SQL use Long for TimestampType, lose the precision under 1us + "timestamp_1", + "timestamp_2", + "timestamp_udf", + + // Hive returns string from UTC formatted timestamp, spark returns timestamp type + "date_udf", + + // Can't compare the result that have newline in it + "udf_get_json_object", + + // Unlike Hive, we do support log base in (0, 1.0], therefore disable this + "udf7", + + // Trivial changes to DDL output + "compute_stats_empty_table", + "compute_stats_long", + "create_view_translate", + "show_create_table_serde", + "show_tblproperties", + + // Odd changes to output + "merge4", + + // Thift is broken... + "inputddl8", + + // Hive changed ordering of ddl: + "varchar_union1", + + // Parser changes in Hive 1.2 + "input25", + "input26", + + // Uses invalid table name + "innerjoin", + + // classpath problems + "compute_stats.*", + "udf_bitmap_.*" + ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or @@ -357,13 +428,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "database_drop", "database_location", "database_properties", + "date_1", "date_2", "date_3", "date_4", "date_comparison", "date_join1", "date_serde", - "date_udf", "decimal_1", "decimal_4", "decimal_join", @@ -517,10 +588,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl2", "inputddl3", "inputddl4", + "inputddl5", "inputddl6", "inputddl7", "inputddl8", "insert1", + "insert1_overwrite_partitions", "insert2_overwrite_partitions", "insert_compressed", "join0", @@ -625,6 +698,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "mapreduce8", "merge1", "merge2", + "merge4", "mergejoins", "multiMapJoin1", "multiMapJoin2", @@ -638,6 +712,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nonblock_op_deduplicate", "notable_alias1", "notable_alias2", + "nullformatCTAS", "nullgroup", "nullgroup2", "nullgroup3", @@ -717,6 +792,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "select_unquote_and", "select_unquote_not", "select_unquote_or", + "semicolon", "semijoin", "serde_regex", "serde_reported_schema", @@ -768,13 +844,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", - "timestamp_1", - "timestamp_2", "timestamp_3", "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "timestamp_udf", "touch", "transform_ppr1", "transform_ppr2", @@ -789,7 +862,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - "udf7", "udf8", "udf9", "udf_10_trims", @@ -851,7 +923,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_int", "udf_isnotnull", "udf_isnull", - "udf_java_method", "udf_lcase", "udf_length", "udf_lessthan", @@ -866,7 +937,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_lpad", "udf_ltrim", "udf_map", - "udf_minute", "udf_modulo", "udf_month", "udf_named_struct", @@ -883,6 +953,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_power", "udf_radians", "udf_rand", + "udf_reflect2", "udf_regexp", "udf_regexp_extract", "udf_regexp_replace", @@ -892,7 +963,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_round_3", "udf_rpad", "udf_rtrim", - "udf_second", "udf_sign", "udf_sin", "udf_smallint", @@ -919,6 +989,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_trim", "udf_ucase", "udf_unix_timestamp", + "udf_unhex", "udf_upper", "udf_var_pop", "udf_var_samp", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala new file mode 100644 index 000000000000..92bb9e6d73af --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -0,0 +1,826 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.File +import java.util.{Locale, TimeZone} + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils + +/** + * The test suite for window functions. To actually compare results with Hive, + * every test should be created by `createQueryTest`. Because we are reusing tables + * for different tests and there are a few properties needed to let Hive generate golden + * files, every `createQueryTest` calls should explicitly set `reset` to `false`. + */ +class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + private val testTempDir = Utils.createTempDir() + + override def beforeAll() { + TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + + // Create the table used in windowing.q + sql("DROP TABLE IF EXISTS part") + sql( + """ + |CREATE TABLE part( + | p_partkey INT, + | p_name STRING, + | p_mfgr STRING, + | p_brand STRING, + | p_type STRING, + | p_size INT, + | p_container STRING, + | p_retailprice DOUBLE, + | p_comment STRING) + """.stripMargin) + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part + """.stripMargin) + + sql("DROP TABLE IF EXISTS over1k") + sql( + """ + |create table over1k( + | t tinyint, + | si smallint, + | i int, + | b bigint, + | f float, + | d double, + | bo boolean, + | s string, + | ts timestamp, + | dec decimal(4,2), + | bin binary) + |row format delimited + |fields terminated by '|' + """.stripMargin) + val testData2 = TestHive.getHiveFile("data/files/over1k").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData2' overwrite into table over1k + """.stripMargin) + + // The following settings are used for generating golden files with Hive. + // We have to use kryo to correctly let Hive serialize plans with window functions. + // This is used to generate golden files. + sql("set hive.plan.serialization.format=kryo") + // Explicitly set fs to local fs. + sql(s"set fs.default.name=file://$testTempDir/") + // Ask Hive to run jobs in-process as a single map and reduce task. + sql("set mapred.job.tracker=local") + } + + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests based on windowing_multipartitioning.q + // Results of the original query file are not deterministic. + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing_multipartitioning.q (deterministic) 1", + s""" + |select s, + |rank() over (partition by s order by si) r, + |sum(b) over (partition by s order by si) sum + |from over1k + |order by s, r, sum; + """.stripMargin, reset = false) + + /* timestamp comparison issue with Hive? + createQueryTest("windowing_multipartitioning.q (deterministic) 2", + s""" + |select s, + |rank() over (partition by s order by dec desc) r, + |sum(b) over (partition by s order by ts desc) as sum + |from over1k + |where s = 'tom allen' or s = 'bob steinbeck' + |order by s, r, sum; + """.stripMargin, reset = false) + */ + + createQueryTest("windowing_multipartitioning.q (deterministic) 3", + s""" + |select s, sum(i) over (partition by s), sum(f) over (partition by si) + |from over1k where s = 'tom allen' or s = 'bob steinbeck'; + """.stripMargin, reset = false) + + createQueryTest("windowing_multipartitioning.q (deterministic) 4", + s""" + |select s, rank() over (partition by s order by bo), + |rank() over (partition by si order by bin desc) from over1k + |where s = 'tom allen' or s = 'bob steinbeck'; + """.stripMargin, reset = false) + + createQueryTest("windowing_multipartitioning.q (deterministic) 5", + s""" + |select s, sum(f) over (partition by i), row_number() over (order by f) + |from over1k where s = 'tom allen' or s = 'bob steinbeck'; + """.stripMargin, reset = false) + + createQueryTest("windowing_multipartitioning.q (deterministic) 6", + s""" + |select s, rank() over w1, + |rank() over w2 + |from over1k + |where s = 'tom allen' or s = 'bob steinbeck' + |window + |w1 as (partition by s order by dec), + |w2 as (partition by si order by f) ; + """.stripMargin, reset = false) + + ///////////////////////////////////////////////////////////////////////////// + // Tests based on windowing_navfn.q + // Results of the original query file are not deterministic. + // Also, the original query of + // select i, lead(s) over (partition by bin order by d,i desc) from over1k ; + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing_navfn.q (deterministic)", + s""" + |select s, row_number() over (partition by d order by dec) rn from over1k + |order by s, rn desc; + |select i, lead(s) over (partition by cast(bin as string) order by d,i desc) as l + |from over1k + |order by i desc, l; + |select i, lag(dec) over (partition by i order by s,i,dec) l from over1k + |order by i, l; + |select s, last_value(t) over (partition by d order by f) l from over1k + |order by s, l; + |select s, first_value(s) over (partition by bo order by s) f from over1k + |order by s, f; + |select t, s, i, last_value(i) over (partition by t order by s) + |from over1k where (s = 'oscar allen' or s = 'oscar carson') and t = 10; + """.stripMargin, reset = false) + + ///////////////////////////////////////////////////////////////////////////// + // Tests based on windowing_ntile.q + // Results of the original query file are not deterministic. + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing_ntile.q (deterministic)", + s""" + |select i, ntile(10) over (partition by s order by i) n from over1k + |order by i, n; + |select s, ntile(100) over (partition by i order by s) n from over1k + |order by s, n; + |select f, ntile(4) over (partition by d order by f) n from over1k + |order by f, n; + |select d, ntile(1000) over (partition by dec order by d) n from over1k + |order by d, n; + """.stripMargin, reset = false) + + ///////////////////////////////////////////////////////////////////////////// + // Tests based on windowing_udaf.q + // Results of the original query file are not deterministic. + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing_udaf.q (deterministic)", + s""" + |select s, min(i) over (partition by s) m from over1k + |order by s, m; + |select s, avg(f) over (partition by si order by s) a from over1k + |order by s, a; + |select s, avg(i) over (partition by t, b order by s) a from over1k + |order by s, a; + |select max(i) over w m from over1k + |order by m window w as (partition by f) ; + |select s, avg(d) over (partition by t order by f) a from over1k + |order by s, a; + """.stripMargin, reset = false) + + ///////////////////////////////////////////////////////////////////////////// + // Tests based on windowing_windowspec.q + // Results of the original query file are not deterministic. + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing_windowspec.q (deterministic)", + s""" + |select s, sum(b) over (partition by i order by s,b rows unbounded preceding) as sum + |from over1k order by s, sum; + |select s, sum(f) over (partition by d order by s,f rows unbounded preceding) as sum + |from over1k order by s, sum; + |select s, sum(f) over + |(partition by ts order by f range between current row and unbounded following) as sum + |from over1k order by s, sum; + |select s, avg(f) + |over (partition by ts order by s,f rows between current row and 5 following) avg + |from over1k order by s, avg; + |select s, avg(d) over + |(partition by t order by s,d desc rows between 5 preceding and 5 following) avg + |from over1k order by s, avg; + |select s, sum(i) over(partition by ts order by s) sum from over1k + |order by s, sum; + |select f, sum(f) over + |(partition by ts order by f range between unbounded preceding and current row) sum + |from over1k order by f, sum; + |select s, i, round(avg(d) over (partition by s order by i) / 10.0 , 2) avg + |from over1k order by s, i, avg; + |select s, i, round((avg(d) over w1 + 10.0) - (avg(d) over w1 - 10.0),2) avg + |from over1k + |order by s, i, avg window w1 as (partition by s order by i); + """.stripMargin, reset = false) + + ///////////////////////////////////////////////////////////////////////////// + // Tests based on windowing_rank.q + // Results of the original query file are not deterministic. + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing_rank.q (deterministic) 1", + s""" + |select s, rank() over (partition by f order by t) r from over1k order by s, r; + |select s, dense_rank() over (partition by ts order by i,s desc) as r from over1k + |order by s desc, r desc; + |select s, cume_dist() over (partition by bo order by b,s) cd from over1k + |order by s, cd; + |select s, percent_rank() over (partition by dec order by f) r from over1k + |order by s desc, r desc; + """.stripMargin, reset = false) + + createQueryTest("windowing_rank.q (deterministic) 2", + s""" + |select ts, dec, rnk + |from + | (select ts, dec, + | rank() over (partition by ts order by dec) as rnk + | from + | (select other.ts, other.dec + | from over1k other + | join over1k on (other.b = over1k.b) + | ) joined + | ) ranked + |where rnk = 1 + |order by ts, dec, rnk; + """.stripMargin, reset = false) + + createQueryTest("windowing_rank.q (deterministic) 3", + s""" + |select ts, dec, rnk + |from + | (select ts, dec, + | rank() over (partition by ts order by dec) as rnk + | from + | (select other.ts, other.dec + | from over1k other + | join over1k on (other.b = over1k.b) + | ) joined + | ) ranked + |where dec = 89.5 + |order by ts, dec, rnk; + """.stripMargin, reset = false) + + createQueryTest("windowing_rank.q (deterministic) 4", + s""" + |select ts, dec, rnk + |from + | (select ts, dec, + | rank() over (partition by ts order by dec) as rnk + | from + | (select other.ts, other.dec + | from over1k other + | join over1k on (other.b = over1k.b) + | where other.t < 10 + | ) joined + | ) ranked + |where rnk = 1 + |order by ts, dec, rnk; + """.stripMargin, reset = false) + + ///////////////////////////////////////////////////////////////////////////// + // Tests from windowing.q + // We port tests in windowing.q to here because this query file contains too + // many tests and the syntax of test "-- 7. testJoinWithWindowingAndPTF" + // is not supported right now. + ///////////////////////////////////////////////////////////////////////////// + createQueryTest("windowing.q -- 1. testWindowing", + s""" + |select p_mfgr, p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |sum(p_retailprice) over + |(distribute by p_mfgr sort by p_name rows between unbounded preceding and current row) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 2. testGroupByWithPartitioning", + s""" + |select p_mfgr, p_name, p_size, + |min(p_retailprice), + |rank() over(distribute by p_mfgr sort by p_name)as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + |group by p_mfgr, p_name, p_size + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 3. testGroupByHavingWithSWQ", + s""" + |select p_mfgr, p_name, p_size, min(p_retailprice), + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + |group by p_mfgr, p_name, p_size + |having p_size > 0 + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 4. testCount", + s""" + |select p_mfgr, p_name, + |count(p_size) over(distribute by p_mfgr sort by p_name) as cd + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 5. testCountWithWindowingUDAF", + s""" + |select p_mfgr, p_name, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |count(p_size) over(distribute by p_mfgr sort by p_name) as cd, + |p_retailprice, sum(p_retailprice) over (distribute by p_mfgr sort by p_name + | rows between unbounded preceding and current row) as s1, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 6. testCountInSubQ", + s""" + |select sub1.r, sub1.dr, sub1.cd, sub1.s1, sub1.deltaSz + |from (select p_mfgr, p_name, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |count(p_size) over(distribute by p_mfgr sort by p_name) as cd, + |p_retailprice, sum(p_retailprice) over (distribute by p_mfgr sort by p_name + | rows between unbounded preceding and current row) as s1, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + |) sub1 + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 8. testMixedCaseAlias", + s""" + |select p_mfgr, p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name, p_size desc) as R + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 9. testHavingWithWindowingNoGBY", + s""" + |select p_mfgr, p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |sum(p_retailprice) over (distribute by p_mfgr sort by p_name + | rows between unbounded preceding and current row) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 10. testHavingWithWindowingCondRankNoGBY", + s""" + |select p_mfgr, p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |sum(p_retailprice) over (distribute by p_mfgr sort by p_name + | rows between unbounded preceding and current row) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 11. testFirstLast", + s""" + |select p_mfgr,p_name, p_size, + |sum(p_size) over (distribute by p_mfgr sort by p_name + |rows between current row and current row) as s2, + |first_value(p_size) over w1 as f, + |last_value(p_size, false) over w1 as l + |from part + |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 12. testFirstLastWithWhere", + s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |sum(p_size) over (distribute by p_mfgr sort by p_name + |rows between current row and current row) as s2, + |first_value(p_size) over w1 as f, + |last_value(p_size, false) over w1 as l + |from part + |where p_mfgr = 'Manufacturer#3' + |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 13. testSumWindow", + s""" + |select p_mfgr,p_name, p_size, + |sum(p_size) over w1 as s1, + |sum(p_size) over (distribute by p_mfgr sort by p_name + |rows between current row and current row) as s2 + |from part + |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 14. testNoSortClause", + s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr + |from part + |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 15. testExpressions", + s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |percent_rank() over(distribute by p_mfgr sort by p_name) as pr, + |ntile(3) over(distribute by p_mfgr sort by p_name) as nt, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |avg(p_size) over(distribute by p_mfgr sort by p_name) as avg, + |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, + |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, + |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 16. testMultipleWindows", + s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |sum(p_size) over (distribute by p_mfgr sort by p_name + |range between unbounded preceding and current row) as s1, + |sum(p_size) over (distribute by p_mfgr sort by p_size + |range between 5 preceding and current row) as s2, + |first_value(p_size) over w1 as fv1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + + createQueryTest("windowing.q -- 17. testCountStar", + s""" + |select p_mfgr,p_name, p_size, + |count(*) over(distribute by p_mfgr sort by p_name ) as c, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 18. testUDAFs", + s""" + |select p_mfgr,p_name, p_size, + |sum(p_retailprice) over w1 as s, + |min(p_retailprice) over w1 as mi, + |max(p_retailprice) over w1 as ma, + |avg(p_retailprice) over w1 as ag + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 19. testUDAFsWithGBY", + """ + |select p_mfgr,p_name, p_size, p_retailprice, + |sum(p_retailprice) over w1 as s, + |min(p_retailprice) as mi , + |max(p_retailprice) as ma , + |avg(p_retailprice) over w1 as ag + |from part + |group by p_mfgr,p_name, p_size, p_retailprice + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following); + """.stripMargin, reset = false) + + // collect_set() output array in an arbitrary order, hence causes different result + // when running this test suite under Java 7 and 8. + // We change the original sql query a little bit for making the test suite passed + // under different JDK + createQueryTest("windowing.q -- 20. testSTATs", + """ + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( + |select p_mfgr,p_name, p_size, + |stddev(p_retailprice) over w1 as sdev, + |stddev_pop(p_retailprice) over w1 as sdev_pop, + |collect_set(p_size) over w1 as uniq_size, + |variance(p_retailprice) over w1 as var, + |corr(p_size, p_retailprice) over w1 as cor, + |covar_pop(p_size, p_retailprice) over w1 as covarp + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 21. testDISTs", + """ + |select p_mfgr,p_name, p_size, + |histogram_numeric(p_retailprice, 5) over w1 as hist, + |percentile(p_partkey, 0.5) over w1 as per, + |row_number() over(distribute by p_mfgr sort by p_name) as rn + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 24. testLateralViews", + """ + |select p_mfgr, p_name, + |lv_col, p_size, sum(p_size) over w1 as s + |from (select p_mfgr, p_name, p_size, array(1,2,3) arr from part) p + |lateral view explode(arr) part_lv as lv_col + |window w1 as (distribute by p_mfgr sort by p_size, lv_col + | rows between 2 preceding and current row) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 26. testGroupByHavingWithSWQAndAlias", + """ + |select p_mfgr, p_name, p_size, min(p_retailprice) as mi, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + |group by p_mfgr, p_name, p_size + |having p_size > 0 + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 27. testMultipleRangeWindows", + """ + |select p_mfgr,p_name, p_size, + |sum(p_size) over (distribute by p_mfgr sort by p_size + |range between 10 preceding and current row) as s2, + |sum(p_size) over (distribute by p_mfgr sort by p_size + |range between current row and 10 following ) as s1 + |from part + |window w1 as (rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 28. testPartOrderInUDAFInvoke", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over (partition by p_mfgr order by p_name + |rows between 2 preceding and 2 following) as s + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 29. testPartOrderInWdwDef", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s + |from part + |window w1 as (partition by p_mfgr order by p_name + | rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 30. testDefaultPartitioningSpecRules", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s, + |sum(p_size) over w2 as s2 + |from part + |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following), + | w2 as (partition by p_mfgr order by p_name) + """.stripMargin, reset = false) + + /* p_name is not a numeric column. What is Hive's semantic? + createQueryTest("windowing.q -- 31. testWindowCrossReference", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s1, + |sum(p_size) over w2 as s2 + |from part + |window w1 as (partition by p_mfgr order by p_name + | range between 2 preceding and 2 following), + | w2 as w1 + """.stripMargin, reset = false) + */ + /* + createQueryTest("windowing.q -- 32. testWindowInheritance", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s1, + |sum(p_size) over w2 as s2 + |from part + |window w1 as (partition by p_mfgr order by p_name + | range between 2 preceding and 2 following), + | w2 as (w1 rows between unbounded preceding and current row) + """.stripMargin, reset = false) + */ + + /* p_name is not a numeric column. What is Hive's semantic? + createQueryTest("windowing.q -- 33. testWindowForwardReference", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s1, + |sum(p_size) over w2 as s2, + |sum(p_size) over w3 as s3 + |from part + |window w1 as (distribute by p_mfgr sort by p_name + | range between 2 preceding and 2 following), + | w2 as w3, + | w3 as (distribute by p_mfgr sort by p_name + | range between unbounded preceding and current row) + """.stripMargin, reset = false) + */ + /* + createQueryTest("windowing.q -- 34. testWindowDefinitionPropagation", + """ + |select p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s1, + |sum(p_size) over w2 as s2, + |sum(p_size) over (w3 rows between 2 preceding and 2 following) as s3 + |from part + |window w1 as (distribute by p_mfgr sort by p_name + | range between 2 preceding and 2 following), + | w2 as w3, + | w3 as (distribute by p_mfgr sort by p_name + | range between unbounded preceding and current row) + """.stripMargin, reset = false) + */ + + /* Seems Hive evaluate SELECT DISTINCT before window functions? + createQueryTest("windowing.q -- 35. testDistinctWithWindowing", + """ + |select DISTINCT p_mfgr, p_name, p_size, + |sum(p_size) over w1 as s + |from part + |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) + """.stripMargin, reset = false) + */ + + createQueryTest("windowing.q -- 36. testRankWithPartitioning", + """ + |select p_mfgr, p_name, p_size, + |rank() over (partition by p_mfgr order by p_name ) as r + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 37. testPartitioningVariousForms", + """ + |select p_mfgr, + |round(sum(p_retailprice) over (partition by p_mfgr order by p_mfgr),2) as s1, + |min(p_retailprice) over (partition by p_mfgr) as s2, + |max(p_retailprice) over (distribute by p_mfgr sort by p_mfgr) as s3, + |round(avg(p_retailprice) over (distribute by p_mfgr),2) as s4, + |count(p_retailprice) over (cluster by p_mfgr ) as s5 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 38. testPartitioningVariousForms2", + """ + |select p_mfgr, p_name, p_size, + |sum(p_retailprice) over (partition by p_mfgr, p_name order by p_mfgr, p_name + |rows between unbounded preceding and current row) as s1, + |min(p_retailprice) over (distribute by p_mfgr, p_name sort by p_mfgr, p_name + |rows between unbounded preceding and current row) as s2, + |max(p_retailprice) over (partition by p_mfgr, p_name order by p_name) as s3 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 39. testUDFOnOrderCols", + """ + |select p_mfgr, p_type, substr(p_type, 2) as short_ptype, + |rank() over (partition by p_mfgr order by substr(p_type, 2)) as r + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 40. testNoBetweenForRows", + """ + |select p_mfgr, p_name, p_size, + |sum(p_retailprice) over (distribute by p_mfgr sort by p_name rows unbounded preceding) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 41. testNoBetweenForRange", + """ + |select p_mfgr, p_name, p_size, + |sum(p_retailprice) over (distribute by p_mfgr sort by p_size range unbounded preceding) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 42. testUnboundedFollowingForRows", + """ + |select p_mfgr, p_name, p_size, + |sum(p_retailprice) over (distribute by p_mfgr sort by p_name + |rows between current row and unbounded following) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 43. testUnboundedFollowingForRange", + """ + |select p_mfgr, p_name, p_size, + |sum(p_retailprice) over (distribute by p_mfgr sort by p_size + |range between current row and unbounded following) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 44. testOverNoPartitionSingleAggregate", + """ + |select p_name, p_retailprice, + |round(avg(p_retailprice) over(),2) + |from part + |order by p_name + """.stripMargin, reset = false) +} + +class HiveWindowFunctionQueryFileSuite + extends HiveCompatibilitySuite with BeforeAndAfter { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + private val testTempDir = Utils.createTempDir() + + override def beforeAll() { + TestHive.cacheTables = true + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + + // The following settings are used for generating golden files with Hive. + // We have to use kryo to correctly let Hive serialize plans with window functions. + // This is used to generate golden files. + // sql("set hive.plan.serialization.format=kryo") + // Explicitly set fs to local fs. + // sql(s"set fs.default.name=file://$testTempDir/") + // Ask Hive to run jobs in-process as a single map and reduce task. + // sql("set mapred.job.tracker=local") + } + + override def afterAll() { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() + } + + override def blackList: Seq[String] = Seq( + // Partitioned table functions are not supported. + "ptf*", + // tests of windowing.q are in HiveWindowFunctionQueryBaseSuite + "windowing.q", + + // This one failed on the expression of + // sum(lag(p_retailprice,1,0.0)) over w1 + // lag(p_retailprice,1,0.0) is a GenericUDF and the argument inspector of + // p_retailprice created by HiveInspectors is + // PrimitiveObjectInspectorFactory.javaDoubleObjectInspector. + // However, seems Hive assumes it is + // PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, which introduces an error. + "windowing_expressions", + + // Hive's results are not deterministic + "windowing_multipartitioning", + "windowing_navfn", + "windowing_ntile", + "windowing_udaf", + "windowing_windowspec", + "windowing_rank" + ) + + override def whiteList: Seq[String] = Seq( + "windowing_udaf2", + "windowing_columnPruning", + "windowing_adjust_rowcontainer_sz" + ) + + // Only run those query tests in the realWhileList (do not try other ignored query files). + override def testCases: Seq[(String, File)] = super.testCases.filter { + case (name, _) => realWhiteList.contains(name) + } +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 58b0722464be..be1607476e25 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../../pom.xml @@ -36,37 +36,64 @@ + + + com.twitter + parquet-hadoop-bundle + org.apache.spark spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} ${project.version} + - org.codehaus.jackson - jackson-mapper-asl + ${hive.group} + hive-exec + ${hive.group} - hive-serde + hive-metastore + org.apache.avro @@ -79,11 +106,79 @@ avro-mapred ${avro.mapred.classifier} + + commons-httpclient + commons-httpclient + + + org.apache.calcite + calcite-avatica + + + org.apache.calcite + calcite-core + + + org.apache.httpcomponents + httpclient + + + org.codehaus.jackson + jackson-mapper-asl + + + + commons-codec + commons-codec + + + joda-time + joda-time + + + org.jodd + jodd-core + + + com.google.code.findbugs + jsr305 + + + org.datanucleus + datanucleus-core + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + org.scalacheck scalacheck_${scala.binary.version} test + + junit + junit + test + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + ${project.version} + test + @@ -102,7 +197,6 @@ - src/test/scala compatibility/src/test/scala @@ -112,16 +206,6 @@ - - hive-0.12.0 - - - com.twitter - parquet-hive-bundle - 1.5.0 - - - @@ -159,7 +243,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..4a774fbf1fdf --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 3f20c6142e59..7f8449cdc282 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -29,10 +29,10 @@ import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") + protected val JAR = Keyword("JAR") protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f6d9027f90a9..c0a458fa9ab8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,158 +17,288 @@ package org.apache.spark.sql.hive -import java.io.{BufferedReader, InputStreamReader, PrintStream} -import java.sql.{Date, Timestamp} +import java.io.File +import java.net.{URL, URLClassLoader} +import java.sql.Timestamp +import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.HashMap import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.concurrent.duration._ import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.SQLConf.SQLConfEntry._ +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} -import org.apache.spark.sql.sources.{CreateTableUsing, DataSourceStrategy} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} +import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy} +import org.apache.spark.sql.hive.client._ +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +/** + * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext + */ +private[hive] class HiveQLDialect extends ParserDialect { + override def parse(sqlText: String): LogicalPlan = { + HiveQl.parseSql(sqlText) + } +} /** * An instance of the Spark SQL execution engine that integrates with data stored in Hive. * Configuration for Hive is read from hive-site.xml on the classpath. + * + * @since 1.0.0 */ -class HiveContext(sc: SparkContext) extends SQLContext(sc) { +class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { self => - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - } + import HiveContext._ + + logDebug("create HiveContext") /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive * SerDe. */ - protected[sql] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET) - override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = + getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - override def sql(sqlText: String): DataFrame = { - val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) - // TODO: Create a framework for registering parsers instead of just hardcoding if statements. - if (conf.dialect == "sql") { - super.sql(substituted) - } else if (conf.dialect == "hiveql") { - DataFrame(this, - ddlParser(sqlText, exceptionOnError = false).getOrElse(HiveQl.parseSql(substituted))) - } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") - } - } + /** + * When true, a table created by a Hive CTAS statement (no USING clause) will be + * converted to a data source table, using the data source set by spark.sql.sources.default. + * The table in CTAS statement will be converted when it meets any of the following conditions: + * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or + * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml + * is either TextFile or SequenceFile. + * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe + * is specified (no ROW FORMAT SERDE clause). + * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format + * and no SerDe is specified (no ROW FORMAT SERDE clause). + */ + protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) /** - * Creates a table using the schema of the given class. - * - * @param tableName The name of the table to create. - * @param allowExisting When false, an exception will be thrown if the table already exists. - * @tparam A A case class that is used to describe the schema of the table to be created. + * The version of the hive client that will be used to communicate with the metastore. Note that + * this does not necessarily need to be the same version of Hive that is used internally by + * Spark SQL for execution. */ - @Deprecated - def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) - } + protected[hive] def hiveMetastoreVersion: String = + getConf(HIVE_METASTORE_VERSION, hiveExecutionVersion) /** - * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, - * Spark SQL or the external data source library it uses might cache certain metadata about a - * table, such as the location of blocks. When those change outside of Spark SQL, users should - * call this function to invalidate the cache. + * The location of the jars that should be used to instantiate the HiveMetastoreClient. This + * property can be one of three options: + * - a classpath in the standard format for both hive and hadoop. + * - builtin - attempt to discover the jars that were used to load Spark SQL and use those. This + * option is only valid when using the execution version of Hive. + * - maven - download the correct version of hive on demand from maven. */ - def refreshTable(tableName: String): Unit = { - // TODO: Database support... - catalog.refreshTable("default", tableName) - } + protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS) - protected[hive] def invalidateTable(tableName: String): Unit = { - // TODO: Database support... - catalog.invalidateTable("default", tableName) + /** + * A comma separated list of class prefixes that should be loaded using the classloader that + * is shared between Spark SQL and a specific version of Hive. An example of classes that should + * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + * to be shared are those that interact with classes that are already shared. For example, + * custom appenders that are used by log4j. + */ + protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = + getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") + + /** + * A comma separated list of class prefixes that should explicitly be reloaded for each version + * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + * prefix that typically would be shared (i.e. org.apache.spark.*) + */ + protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = + getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") + + /* + * hive thrift server use background spark sql thread pool to execute sql queries + */ + protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) + + @transient + protected[sql] lazy val substitutor = new VariableSubstitution() + + /** + * The copy of the hive client that is used for execution. Currently this must always be + * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the + * client is used for execution related tasks like registering temporary functions or ensuring + * that the ThreadLocal SessionState is correctly populated. This copy of Hive is *not* used + * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. + */ + @transient + protected[hive] lazy val executionHive: ClientWrapper = { + logInfo(s"Initializing execution hive, version $hiveExecutionVersion") + new ClientWrapper( + version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + config = newTemporaryConfiguration(), + initClassLoader = Utils.getContextOrSparkClassLoader) } + SessionState.setCurrentSessionState(executionHive.state) - @Experimental - def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = { - val dataSourceName = conf.defaultDataSourceName - createTable(tableName, dataSourceName, allowExisting, ("path", path)) + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + private[sql] def defaultOverides() = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") } - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - allowExisting: Boolean, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = None, - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting) - executePlan(cmd).toRdd + defaultOverides() + + /** + * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. + * The version of the Hive client that is used here must match the metastore that is configured + * in the hive-site.xml file. + */ + @transient + protected[hive] lazy val metadataHive: ClientInterface = { + val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) + + // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options + // into the isolated client loader + val metadataConf = new HiveConf() + + val defaltWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("defalt warehouse location is " + defaltWarehouseLocation) + + // `configure` goes second to override other settings. + val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure + + val isolatedLoader = if (hiveMetastoreJars == "builtin") { + if (hiveExecutionVersion != hiveMetastoreVersion) { + throw new IllegalArgumentException( + "Builtin jars can only be used when hive execution version == hive metastore version. " + + s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. " + + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") + } + + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) + } + + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. " + + "Please set spark.sql.hive.metastore.jars.") + } + + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using Spark classes.") + new IsolatedClientLoader( + version = metaVersion, + execJars = jars.toSeq, + config = allConfig, + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } else if (hiveMetastoreJars == "maven") { + // TODO: Support for loading the jars from an already downloaded location. + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") + IsolatedClientLoader.forVersion( + version = hiveMetastoreVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } else { + // Convert to files and expand any directories. + val jars = + hiveMetastoreJars + .split(File.pathSeparator) + .flatMap { + case path if new File(path).getName() == "*" => + val files = new File(path).getParentFile().listFiles() + if (files == null) { + logWarning(s"Hive jar path '$path' does not exist.") + Nil + } else { + files.filter(_.getName().toLowerCase().endsWith(".jar")) + } + case path => + new File(path) :: Nil + } + .map(_.toURI.toURL) + + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using $jars") + new IsolatedClientLoader( + version = metaVersion, + execJars = jars.toSeq, + config = allConfig, + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } + isolatedLoader.client } - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - schema: StructType, - allowExisting: Boolean, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = Some(schema), - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting) - executePlan(cmd).toRdd + protected[sql] override def parseSql(sql: String): LogicalPlan = { + var state = SessionState.get() + if (state == null) { + SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState) + } + super.parseSql(substitutor.substitute(hiveconf, sql)) } - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - allowExisting: Boolean, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*) + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution(plan) + + /** + * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, + * Spark SQL or the external data source library it uses might cache certain metadata about a + * table, such as the location of blocks. When those change outside of Spark SQL, users should + * call this function to invalidate the cache. + * + * @since 1.3.0 + */ + def refreshTable(tableName: String): Unit = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.refreshTable(tableIdent) } - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - schema: StructType, - allowExisting: Boolean, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*) + protected[hive] def invalidateTable(tableName: String): Unit = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.invalidateTable(tableIdent) } /** @@ -177,10 +307,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * * Right now, it only supports Hive tables and it only updates the size of a Hive table * in the Hive metastore. + * + * @since 1.2.0 */ @Experimental def analyze(tableName: String) { - val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName))) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { case relation: MetastoreRelation => @@ -192,10 +325,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. + val stagingDir = metadataHive.getConf(HiveConf.ConfVars.STAGINGDIR.varname, + HiveConf.ConfVars.STAGINGDIR.defaultStrVal) + def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDir) { - fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + fs.listStatus(path) + .map { status => + if (!status.getPath().getName().startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + } + .sum } else { fileStatus.getLen } @@ -222,7 +366,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableParameters = relation.hiveQlTable.getParameters val oldTotalSize = - Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize)) + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(0L) val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) @@ -230,165 +374,156 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // recorded in the Hive metastore. // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). if (newTotalSize > 0 && newTotalSize != oldTotalSize) { - tableParameters.put(HiveShim.getStatsSetupConstTotalSize, newTotalSize.toString) - val hiveTTable = relation.hiveQlTable.getTTable - hiveTTable.setParameters(tableParameters) - val tableFullName = - relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.client.alterTable( + relation.table.copy( + properties = relation.table.properties + + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") - } - } - - // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - @transient - protected lazy val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size - } - - override def toString = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } - /** - * SQLConf and HiveConf contracts: - * - * 1. reuse existing started SessionState if any - * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the - * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be - * set in the SQLConf *as well as* in the HiveConf. - */ - @transient protected[hive] lazy val (hiveconf, sessionState) = - Option(SessionState.get()) - .orElse { - val newState = new SessionState(new HiveConf(classOf[SessionState])) - // Only starts newly created `SessionState` instance. Any existing `SessionState` instance - // returned by `SessionState.get()` must be the most recently started one. - SessionState.start(newState) - Some(newState) - } - .map { state => - setConf(state.getConf.getAllProperties) - if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8") - if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8") - (state.getConf, state) - } - .get + protected[hive] def hiveconf = tlSession.get().asInstanceOf[this.SQLSession].hiveconf override def setConf(key: String, value: String): Unit = { super.setConf(key, value) - runSqlHive(s"SET $key=$value") + executionHive.runSqlHive(s"SET $key=$value") + metadataHive.runSqlHive(s"SET $key=$value") + // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), + // this setConf will be called in the constructor of the SQLContext. + // Also, calling hiveconf will create a default session containing a HiveConf, which + // will interfer with the creation of executionHive (which is a lazy val). So, + // we put hiveconf.set at the end of this method. + hiveconf.set(key, value) + } + + override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + setConf(entry.key, entry.stringConverter(value)) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient - override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog + override protected[sql] lazy val catalog = + new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new HiveFunctionRegistry(FunctionRegistry.builtin) /* An analyzer that uses the Hive metastore. */ @transient - override protected[sql] lazy val analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) { - override val extendedRules = + override protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, functionRegistry, conf) { + override val extendedResolutionRules = + catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: + ResolveHiveWindowFunction :: + PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + PreWriteCheck(catalog) + ) } - /** - * Runs the specified SQL query using Hive. - */ - protected[sql] def runSqlHive(sql: String): Seq[String] = { - val maxResults = 100000 - val results = runHive(sql, maxResults) - // It is very confusing when you only get back some of the results... - if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") - results + override protected[sql] def createSession(): SQLSession = { + new this.SQLSession() } - /** - * Execute the command using Hive and return the results as a sequence. Each element - * in the sequence is one row. - */ - protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = synchronized { - try { - val cmd_trimmed: String = cmd.trim() - val tokens: Array[String] = cmd_trimmed.split("\\s+") - val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hiveconf) - - // Makes sure the session represented by the `sessionState` field is activated. This implies - // Spark SQL Hive support uses a single `SessionState` for all Hive operations and breaks - // session isolation under multi-user scenarios (i.e. HiveThriftServer2). - // TODO Fix session isolation - if (SessionState.get() != sessionState) { - SessionState.start(sessionState) - } + /** Overridden by child classes that need to set configuration before the client init. */ + protected def configure(): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connecting to a Hive metastore of lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units. + Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } - proc match { - case driver: Driver => - val results = HiveShim.createDriverResultsArray - val response: CommandProcessorResponse = driver.run(cmd) - // Throw an exception if there is an error in query processing. - if (response.getResponseCode != 0) { - driver.close() - throw new QueryExecutionException(response.getErrorMessage) - } - driver.setMaxRows(maxRows) - driver.getResults(results) - driver.close() - HiveShim.processResults(results) - case _ => - if (sessionState.out != null) { - sessionState.out.println(tokens(0) + " " + cmd_1) - } - Seq(proc.run(cmd_1).getResponseCode.toString) + protected[hive] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + } + + /** + * SQLConf and HiveConf contracts: + * + * 1. reuse existing started SessionState if any + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. + */ + protected[hive] lazy val sessionState: SessionState = { + var state = SessionState.get() + if (state == null) { + state = new SessionState(new HiveConf(classOf[SessionState])) + SessionState.start(state) } - } catch { - case e: Exception => - logError( - s""" - |====================== - |HIVE FAILURE OUTPUT - |====================== - |${outputBuffer.toString} - |====================== - |END HIVE FAILURE OUTPUT - |====================== - """.stripMargin) - throw e + state + } + + protected[hive] lazy val hiveconf: HiveConf = { + setConf(sessionState.getConf.getAllProperties) + sessionState.getConf } } + override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") { + classOf[HiveQLDialect].getCanonicalName + } else { + super.dialectClassName + } + @transient private val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self @@ -398,22 +533,32 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrdered, - ParquetOperations, + TakeOrderedAndProject, InMemoryScans, - ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, HashAggregation, + Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin ) } + protected[hive] def runSqlHive(sql: String): Seq[String] = { + if (sql.toLowerCase.contains("create temporary function")) { + executionHive.runSqlHive(sql) + } else if (sql.trim.toLowerCase.startsWith("set")) { + metadataHive.runSqlHive(sql) + executionHive.runSqlHive(sql) + } else { + metadataHive.runSqlHive(sql) + } + } + @transient override protected[sql] val planner = hivePlanner @@ -457,7 +602,80 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } -private object HiveContext { +private[hive] object HiveContext { + /** The version of hive used internally by Spark SQL. */ + val hiveExecutionVersion: String = "1.2.1" + + val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", + defaultValue = Some("builtin"), + doc = s""" + | Location of the jars that should be used to instantiate the HiveMetastoreClient. + | This property can be one of three options: " + | 1. "builtin" + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | -Phive is enabled. When this option is chosen, + | spark.sql.hive.metastore.version must be either + | ${hiveExecutionVersion} or not defined. + | 2. "maven" + | Use Hive jars of specified version downloaded from Maven repositories. + | 3. A classpath in the standard format for both Hive and Hadoop. + """.stripMargin) + val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", + defaultValue = Some(true), + doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + + "the built in support.") + + val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( + "spark.sql.hive.convertMetastoreParquet.mergeSchema", + defaultValue = Some(false), + doc = "TODO") + + val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", + defaultValue = Some(false), + doc = "TODO") + + val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", + defaultValue = Some(jdbcPrefixes), + doc = "A comma separated list of class prefixes that should be loaded using the classloader " + + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + + "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " + + "classes that need to be shared are those that interact with classes that are already " + + "shared. For example, custom appenders that are used by log4j.") + + private def jdbcPrefixes = Seq( + "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc") + + val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes", + defaultValue = Some(Seq()), + doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " + + "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " + + "declared in a prefix that typically would be shared (i.e. org.apache.spark.*).") + + val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", + defaultValue = Some(true), + doc = "TODO") + + /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ + def newTemporaryConfiguration(): Map[String, String] = { + val tempDir = Utils.createTempDir() + val localMetastore = new File(tempDir, "metastore") + val propMap: HashMap[String, String] = HashMap() + // We have to mask all properties in hive-site.xml that relates to metastore data source + // as we used a local metastore here. + HiveConf.ConfVars.values().foreach { confvar => + if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { + propMap.put(confvar.varname, confvar.getDefaultExpr()) + } + } + propMap.put(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, localMetastore.toURI.toString) + propMap.put(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, + s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") + propMap.put("datanucleus.rdbms.datastoreAdapterClassName", + "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + propMap.toMap + } + protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) @@ -469,18 +687,18 @@ private object HiveContext { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" - case (d: Date, DateType) => new DateWritable(d).toString + case (d: Int, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal).toString + HiveDecimal.create(decimal).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 82dba99900df..cfe2bb05ad89 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,24 +17,27 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ - -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import org.apache.spark.sql.{AnalysisException, types} +import org.apache.spark.unsafe.types.UTF8String /** * 1. The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -42,15 +45,14 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.types.Decimal + * [[org.apache.spark.sql.types.Decimal]] * Array[Byte] * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: scala.collection.immutable.Map - * List: scala.collection.immutable.Seq - * Struct: - * org.apache.spark.sql.catalyst.expression.Row + * Map: [[org.apache.spark.sql.types.MapData]] + * List: [[org.apache.spark.sql.types.ArrayData]] + * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. * @@ -121,7 +123,7 @@ import scala.collection.JavaConversions._ * even a normal java object (POJO) * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet) * - * 3) ConstantObjectInspector: + * 3) ConstantObjectInspector: * Constant object inspector can be either primitive type or Complex type, and it bundles a * constant value as its property, usually the value is created when the constant object inspector * constructed. @@ -132,7 +134,7 @@ import scala.collection.JavaConversions._ } }}} * Hive provides 3 built-in constant object inspectors: - * Primitive Object Inspectors: + * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector * WritableConstantHiveDecimalObjectInspector @@ -146,9 +148,9 @@ import scala.collection.JavaConversions._ * WritableConstantByteObjectInspector * WritableConstantBinaryObjectInspector * WritableConstantDateObjectInspector - * Map Object Inspector: + * Map Object Inspector: * StandardConstantMapObjectInspector - * List Object Inspector: + * List Object Inspector: * StandardConstantListObjectInspector]] * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union @@ -175,7 +177,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -191,8 +193,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -215,6 +217,20 @@ private[hive] trait HiveInspectors { // Hive seems to return this for struct types? case c: Class[_] if c == classOf[java.lang.Object] => NullType + + // java list type unsupported + case c: Class[_] if c == classOf[java.util.List[_]] => + throw new AnalysisException( + "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>") + + // java map type unsupported + case c: Class[_] if c == classOf[java.util.Map[_, _]] => + throw new AnalysisException( + "Map type in java is unsupported because " + + "JVM type erasure makes spark fail to catch key and value types in Map<>") + + case c => throw new AnalysisException(s"Unsupported java type $c") } /** @@ -239,18 +255,20 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => - poi.getWritableConstantValue.getTimestamp.clone() - case poi: WritableConstantIntObjectInspector => + val t = poi.getWritableConstantValue + t.getSeconds * 1000000L + t.getNanos / 1000L + case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() - case poi: WritableConstantDoubleObjectInspector => + case poi: WritableConstantDoubleObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantBooleanObjectInspector => poi.getWritableConstantValue.get() @@ -267,26 +285,33 @@ private[hive] trait HiveInspectors { val temp = new Array[Byte](writable.getLength) System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp - case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get() + case poi: WritableConstantDateObjectInspector => + DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data - mi.getWritableConstantValue.map { case (k, v) => - (unwrap(k, mi.getMapKeyObjectInspector), - unwrap(v, mi.getMapValueObjectInspector)) - }.toMap + val keyValues = mi.getWritableConstantValue.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray + ArrayBasedMapData(keys, values) case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq + val values = li.getWritableConstantValue.asScala + .map(unwrap(_, li.getListElementObjectInspector)) + .toArray + new GenericArrayData(values) // if the value is null, we don't care about the object inspector type case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String.fromString(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -300,35 +325,41 @@ private[hive] trait HiveInspectors { // In order to keep backward-compatible, we have to copy the // bytes with old apis val bw = x.getPrimitiveWritableObject(data) - val result = new Array[Byte](bw.getLength()) + val result = new Array[Byte](bw.getLength()) System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).get() - // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object - // if next timestamp is null, so Timestamp object is cloned + DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) case x: TimestampObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).getTimestamp.clone() - case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() + val t = x.getPrimitiveWritableObject(data) + t.getSeconds * 1000000L + t.getNanos / 1000L + case ti: TimestampObjectInspector => + DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) } case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) + .map { l => + val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray + new GenericArrayData(values) + } .orNull case mi: MapObjectInspector => - Option(mi.getMap(data)).map( - _.map { - case (k, v) => - (unwrap(k, mi.getMapKeyObjectInspector), - unwrap(v, mi.getMapValueObjectInspector)) - }.toMap).orNull + val map = mi.getMap(data) + if (map == null) { + null + } else { + val keyValues = map.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray + ArrayBasedMapData(keys, values) + } // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + InternalRow.fromSeq(allRefs.asScala.map( + r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -336,20 +367,52 @@ private[hive] trait HiveInspectors { * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. */ - protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) + } else { + null + } case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => + if (o != null) { + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + } else { + null + } + + case _: JavaDateObjectInspector => + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + } else { + null + } + + case _: JavaTimestampObjectInspector => + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + } else { + null + } case soi: StandardStructObjectInspector => - val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + val schema = dataType.asInstanceOf[StructType] + val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { + case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) + } (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { @@ -358,21 +421,34 @@ private[hive] trait HiveInspectors { } case loi: ListObjectInspector => - val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null + val elementType = dataType.asInstanceOf[ArrayType].elementType + val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType) + (o: Any) => { + if (o != null) { + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(elementType, (_, e) => { + values.add(wrapper(e)) + }) + values + } else { + null + } + } case moi: MapObjectInspector => - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map + val mt = dataType.asInstanceOf[MapType] + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType) - val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) - val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) (o: Any) => { if (o != null) { - mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => - keyWrapper(key) -> valueWrapper(value) + val map = o.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => { + jmap.put(keyWrapper(k), valueWrapper(v)) }) + jmap } else { null } @@ -382,6 +458,30 @@ private[hive] trait HiveInspectors { identity[Any] } + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ + def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + field.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + /** * Converts native catalyst types to the types expected by Hive * @param a the value to be wrapped @@ -398,101 +498,113 @@ private[hive] trait HiveInspectors { * * NOTICE: the complex data type requires recursive wrapping. */ - def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match { + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { case x: ConstantObjectInspector => x.getWritableConstantValue case _ if a == null => null case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. - case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] - case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) + case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() + case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) + case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) + case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) + case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) + case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) + case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) + case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] case _: HiveDecimalObjectInspector if x.preferWritable() => - HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) + getDecimalWritable(a.asInstanceOf[Decimal]) case _: HiveDecimalObjectInspector => - HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) + HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) + case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) - case _: DateObjectInspector => a.asInstanceOf[java.sql.Date] - case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) - case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] + case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) + case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int]) + case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) + case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long]) } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Row] + val structType = dataType.asInstanceOf[StructType] + val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { // 2. set the property for the pojo + val tpe = structType(i).dataType x.setStructFieldData( result, fieldRefs.get(i), - wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Row] - val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + val structType = dataType.asInstanceOf[StructType] + val row = a.asInstanceOf[InternalRow] + val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.length) { - result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + while (i < fieldRefs.size) { + val tpe = structType(i).dataType + result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: ListObjectInspector => val list = new java.util.ArrayList[Object] - a.asInstanceOf[Seq[_]].foreach { - v => list.add(wrap(v, x.getListElementObjectInspector)) - } + val tpe = dataType.asInstanceOf[ArrayType].elementType + a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => { + list.add(wrap(e, x.getListElementObjectInspector, tpe)) + }) list case x: MapObjectInspector => + val keyType = dataType.asInstanceOf[MapType].keyType + val valueType = dataType.asInstanceOf[MapType].valueType + val map = a.asInstanceOf[MapData] + // Some UDFs seem to assume we pass in a HashMap. - val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(a.asInstanceOf[Map[_, _]].map { - case (k, v) => - wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + val hashMap = new java.util.HashMap[Any, Any](map.numElements()) + + map.foreach(keyType, valueType, (k, v) => { + hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), + wrap(v, x.getMapValueObjectInspector, valueType)) }) hashMap } def wrap( - row: Row, + row: InternalRow, inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } cache } def wrap( - row: Seq[Any], - inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } cache @@ -525,8 +637,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) } /** @@ -538,38 +650,40 @@ private[hive] trait HiveInspectors { */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => - HiveShim.getStringWritableConstantObjectInspector(value) + getStringWritableConstantObjectInspector(value) case Literal(value, IntegerType) => - HiveShim.getIntWritableConstantObjectInspector(value) + getIntWritableConstantObjectInspector(value) case Literal(value, DoubleType) => - HiveShim.getDoubleWritableConstantObjectInspector(value) + getDoubleWritableConstantObjectInspector(value) case Literal(value, BooleanType) => - HiveShim.getBooleanWritableConstantObjectInspector(value) + getBooleanWritableConstantObjectInspector(value) case Literal(value, LongType) => - HiveShim.getLongWritableConstantObjectInspector(value) + getLongWritableConstantObjectInspector(value) case Literal(value, FloatType) => - HiveShim.getFloatWritableConstantObjectInspector(value) + getFloatWritableConstantObjectInspector(value) case Literal(value, ShortType) => - HiveShim.getShortWritableConstantObjectInspector(value) + getShortWritableConstantObjectInspector(value) case Literal(value, ByteType) => - HiveShim.getByteWritableConstantObjectInspector(value) + getByteWritableConstantObjectInspector(value) case Literal(value, BinaryType) => - HiveShim.getBinaryWritableConstantObjectInspector(value) + getBinaryWritableConstantObjectInspector(value) case Literal(value, DateType) => - HiveShim.getDateWritableConstantObjectInspector(value) + getDateWritableConstantObjectInspector(value) case Literal(value, TimestampType) => - HiveShim.getTimestampWritableConstantObjectInspector(value) + getTimestampWritableConstantObjectInspector(value) case Literal(value, DecimalType()) => - HiveShim.getDecimalWritableConstantObjectInspector(value) + getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => - HiveShim.getPrimitiveNullWritableConstantObjectInspector + getPrimitiveNullWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) if (value == null) { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + value.asInstanceOf[ArrayData].foreach(dt, (_, e) => { + list.add(wrap(e, listObjectInspector, dt)) + }) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -578,27 +692,30 @@ private[hive] trait HiveInspectors { if (value == null) { ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null) } else { - val map = new java.util.HashMap[Object, Object]() - value.asInstanceOf[Map[_, _]].foreach (entry => { - map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + val map = value.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + + map.foreach(keyType, valueType, (k, v) => { + jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)) }) - ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) + + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap) } // We will enumerate all of the possible constant expressions, throw exception if we missed case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. - case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType)) // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { + StructType(s.getAllStructFieldRefs.asScala.map(f => types.StructField( f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) + )) case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) case m: MapObjectInspector => MapType( @@ -622,8 +739,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) - case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) + case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -632,17 +749,140 @@ private[hive] trait HiveInspectors { case _: JavaVoidObjectInspector => NullType } + private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) + + private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, getIntWritable(value)) + + private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) + + private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) + + private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, getLongWritable(value)) + + private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) + + private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) + + private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) + + private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) + + private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) + + private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) + + private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) + + private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + + private def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes) + + private def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + private def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + private def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + private def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + private def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + private def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + private def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + private def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) + + private def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long])) + } + + private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal)) + } + implicit class typeInfoConversions(dt: DataType) { import org.apache.hadoop.hive.serde2.typeinfo._ import TypeInfoFactory._ + private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + } + def toTypeInfo: TypeInfo = dt match { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) :_*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) :_*)) + java.util.Arrays.asList(fields.map(_.name) : _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo @@ -654,7 +894,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case d: DecimalType => HiveShim.decimalTypeInfo(d) + case d: DecimalType => decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 48bea6c1bd68..b8da0840ae56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,46 +17,95 @@ package org.apache.spark.sql.hive -import java.io.IOException -import java.util.{List => JList} - -import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} - -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.hive.metastore.{Warehouse, TableType} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, AlreadyExistsException, FieldSchema} +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.google.common.base.Objects +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.Warehouse +import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ -import org.apache.hadoop.hive.ql.plan.CreateTableDesc -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.hive.client._ +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} -/* Implicit conversions */ -import scala.collection.JavaConversions._ +private[hive] case class HiveSerDe( + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None) -private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { - import org.apache.spark.sql.hive.HiveMetastoreTypes._ +private[hive] object HiveSerDe { + /** + * Get the Hive SerDe information from the data source abbreviation string or classname. + * + * @param source Currently the source abbreviation can be one of the following: + * SequenceFile, RCFile, ORC, PARQUET, and case insensitive. + * @param hiveConf Hive Conf + * @return HiveSerDe associated with the specified source + */ + def sourceToSerDe(source: String, hiveConf: HiveConf): Option[HiveSerDe] = { + val serdeMap = Map( + "sequencefile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + + "rcfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), + serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))), + + "orc" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), + + "parquet" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + + val key = source.toLowerCase match { + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s + } + + serdeMap.get(key) + } +} - /** Connection to hive metastore. Usages should lock on `this`. */ - protected[hive] val client = Hive.get(hive.hiveconf) +private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) + extends Catalog with Logging { + val conf = hive.conf + + /** Usages should lock on `this`. */ protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) // TODO: Use this everywhere instead of tuples or databaseName, tableName,. /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) { - def toLowerCase = QualifiedTableName(database.toLowerCase, name.toLowerCase) + def toLowerCase: QualifiedTableName = QualifiedTableName(database.toLowerCase, name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -65,22 +114,48 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") val table = client.getTable(in.database, in.name) - val schemaString = table.getProperty("spark.sql.sources.schema") - val userSpecifiedSchema = - if (schemaString == null) { - None - } else { - Some(DataType.fromJson(schemaString).asInstanceOf[StructType]) + + def schemaStringFromParts: Option[String] = { + table.properties.get("spark.sql.sources.schema.numParts").map { numParts => + val parts = (0 until numParts.toInt).map { index => + val part = table.properties.get(s"spark.sql.sources.schema.part.$index").orNull + if (part == null) { + throw new AnalysisException( + "Could not read schema from the metastore because it is corrupted " + + s"(missing part $index of the schema, $numParts parts are expected).") + } + + part + } + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + val schemaString = + table.properties.get("spark.sql.sources.schema").orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) + + // We only need names at here since userSpecifiedSchema we loaded from the metastore + // contains partition columns. We can always get datatypes of partitioning columns + // from userSpecifiedSchema. + val partitionColumns = table.partitionColumns.map(_.name) + // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... - val options = table.getTTable.getSd.getSerdeInfo.getParameters.toMap + val options = table.serdeProperties val resolvedRelation = ResolvedDataSource( hive, userSpecifiedSchema, - table.getProperty("spark.sql.sources.provider"), + partitionColumns.toArray, + table.properties("spark.sql.sources.provider"), options) LogicalRelation(resolvedRelation.relation) @@ -90,261 +165,363 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + override def refreshTable(tableIdent: TableIdentifier): Unit = { + // refreshTable does not eagerly reload the cache. It just invalidate the cache. + // Next time when we use the table, it will be populated in the cache. + // Since we also cache ParquetRelations converted from Hive Parquet tables and + // adding converted ParquetRelations into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better at here to invalidate the cache to avoid confusing waring logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source table.). + invalidateTable(tableIdent) } - def invalidateTable(databaseName: String, tableName: String): Unit = { + def invalidateTable(tableIdent: TableIdentifier): Unit = { + val databaseName = tableIdent.database.getOrElse(client.currentDatabase) + val tableName = tableIdent.table + cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) } val caseSensitive: Boolean = false - /** * - * Creates a data source table (a table created with USING clause) in Hive's metastore. - * Returns true when the table has been created. Otherwise, false. - * @param tableName - * @param userSpecifiedSchema - * @param provider - * @param options - * @param isExternal - * @return - */ + /** + * Creates a data source table (a table created with USING clause) in Hive's metastore. + * Returns true when the table has been created. Otherwise, false. + */ + // TODO: Remove this in SPARK-10104. def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName("default", tableName) - val tbl = new Table(dbName, tblName) + createDataSourceTable( + new SqlParser().parseTableIdentifier(tableName), + userSpecifiedSchema, + partitionColumns, + provider, + options, + isExternal) + } - tbl.setProperty("spark.sql.sources.provider", provider) - if (userSpecifiedSchema.isDefined) { - tbl.setProperty("spark.sql.sources.schema", userSpecifiedSchema.get.json) + def createDataSourceTable( + tableIdent: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + val (dbName, tblName) = { + val database = tableIdent.database.getOrElse(client.currentDatabase) + processDatabaseAndTableName(database, tableIdent.table) + } + + val tableProperties = new mutable.HashMap[String, String] + tableProperties.put("spark.sql.sources.provider", provider) + + // Saves optional user specified schema. Serialized JSON schema string may be too long to be + // stored into a single metastore SerDe property. In this case, we split the JSON string and + // store each part as a separate SerDe property. + userSpecifiedSchema.foreach { schema => + val threshold = conf.schemaStringLengthThreshold + val schemaJsonString = schema.json + // Split the JSON string. + val parts = schemaJsonString.grouped(threshold).toSeq + tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) + parts.zipWithIndex.foreach { case (part, index) => + tableProperties.put(s"spark.sql.sources.schema.part.$index", part) + } } - options.foreach { case (key, value) => tbl.setSerdeParam(key, value) } - if (isExternal) { - tbl.setProperty("EXTERNAL", "TRUE") - tbl.setTableType(TableType.EXTERNAL_TABLE) - } else { - tbl.setProperty("EXTERNAL", "FALSE") - tbl.setTableType(TableType.MANAGED_TABLE) + val metastorePartitionColumns = userSpecifiedSchema.map { schema => + val fields = partitionColumns.map(col => schema(col)) + fields.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + }.toSeq + }.getOrElse { + if (partitionColumns.length > 0) { + // The table does not have a specified schema, which means that the schema will be inferred + // when we load the table. So, we are not expecting partition columns and we will discover + // partitions when we load the table. However, if there are specified partition columns, + // we simply ignore them and provide a warning message. + logWarning( + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") + } + Seq.empty[HiveColumn] } - // create the table - synchronized { - client.createTable(tbl, false) + val tableType = if (isExternal) { + tableProperties.put("EXTERNAL", "TRUE") + ExternalTable + } else { + tableProperties.put("EXTERNAL", "FALSE") + ManagedTable + } + + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) + val dataSource = ResolvedDataSource( + hive, userSpecifiedSchema, partitionColumns, provider, options) + + def newSparkSQLSpecificMetastoreTable(): HiveTable = { + HiveTable( + specifiedDatabase = Option(dbName), + name = tblName, + schema = Seq.empty, + partitionColumns = metastorePartitionColumns, + tableType = tableType, + properties = tableProperties.toMap, + serdeProperties = options) + } + + def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { + def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { + schema.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + } + } + + val partitionColumns = schemaToHiveColumn(relation.partitionColumns) + val dataColumns = schemaToHiveColumn(relation.schema).filterNot(partitionColumns.contains) + + HiveTable( + specifiedDatabase = Option(dbName), + name = tblName, + schema = dataColumns, + partitionColumns = partitionColumns, + tableType = tableType, + properties = tableProperties.toMap, + serdeProperties = options, + location = Some(relation.paths.head), + viewText = None, // TODO We need to place the SQL string here. + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde) + } + + // TODO: Support persisting partitioned data source relations in Hive compatible format + val hiveTable = (maybeSerDe, dataSource.relation) match { + case (Some(serde), relation: HadoopFsRelation) + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. + val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) + } + + case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => + logWarning { + "Persisting partitioned data source relation into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + + relation.paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), relation: HadoopFsRelation) => + logWarning { + "Persisting data source relation with multiple input paths into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + + relation.paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), _) => + logWarning { + s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + + "Persisting it into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." + } + newSparkSQLSpecificMetastoreTable() + + case _ => + logWarning { + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + "Persisting data source relation into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." + } + newSparkSQLSpecificMetastoreTable() } + + client.createTable(hiveTable) } def hiveDefaultTableFilePath(tableName: String): String = { - hiveWarehouse.getTablePath(client.getDatabaseCurrent, tableName).toString + hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + } + + def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { + // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) + val database = tableIdent.database.getOrElse(client.currentDatabase) + + new Path( + new Path(client.getDatabase(database).location), + tableIdent.table.toLowerCase).toString } def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - hive.sessionState.getCurrentDatabase) + val databaseName = + tableIdent + .lift(tableIdent.size - 2) + .getOrElse(client.currentDatabase) val tblName = tableIdent.last - try { - client.getTable(databaseName, tblName) != null - } catch { - case ie: InvalidTableException => false - } + client.getTableOption(databaseName, tblName).isDefined } def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - hive.sessionState.getCurrentDatabase) + client.currentDatabase) val tblName = tableIdent.last val table = client.getTable(databaseName, tblName) - if (table.getProperty("spark.sql.sources.provider") != null) { - cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) - } else if (table.isView) { - // if the unresolved relation is from hive view - // parse the text into logic node. - HiveQl.createPlanForView(table, alias) + if (table.properties.get("spark.sql.sources.provider").isDefined) { + val dataSourceTable = + cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + // Then, if alias is specified, wrap the table with a Subquery using the alias. + // Otherwise, wrap the table with a Subquery using the table name. + val withAlias = + alias.map(a => Subquery(a, dataSourceTable)).getOrElse( + Subquery(tableIdent.last, dataSourceTable)) + + withAlias + } else if (table.tableType == VirtualView) { + val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) + alias match { + // because hive use things like `_c0` to build the expanded text + // currently we cannot support view from "create view v1(c1) as ..." + case None => Subquery(table.name, HiveQl.createPlan(viewText)) + case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText)) + } } else { - val partitions: Seq[Partition] = - if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq - } else { - Nil - } - - // Since HiveQL is case insensitive for table names we make them all lowercase. - MetastoreRelation( - databaseName, tblName, alias)( - table.getTTable, partitions.map(part => part.getTPartition))(hive) + MetastoreRelation(databaseName, tblName, alias)(table)(hive) } } - /** - * Create table with specified database, table name, table description and schema - * @param databaseName Database Name - * @param tableName Table Name - * @param schema Schema of the new table, if not specified, will use the schema - * specified in crtTbl - * @param allowExisting if true, ignore AlreadyExistsException - * @param desc CreateTableDesc object which contains the SerDe info. Currently - * we support most of the features except the bucket. - */ - def createTable( - databaseName: String, - tableName: String, - schema: Seq[Attribute], - allowExisting: Boolean = false, - desc: Option[CreateTableDesc] = None) { - val hconf = hive.hiveconf - - val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - val tbl = new Table(dbName, tblName) - - val crtTbl: CreateTableDesc = desc.getOrElse(null) - - // We should respect the passed in schema, unless it's not set - val hiveSchema: JList[FieldSchema] = if (schema == null || schema.isEmpty) { - crtTbl.getCols - } else { - schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) - } - tbl.setFields(hiveSchema) - - // Most of code are similar with the DDLTask.createTable() of Hive, - if (crtTbl != null && crtTbl.getTblProps() != null) { - tbl.getTTable().getParameters().putAll(crtTbl.getTblProps()) - } - - if (crtTbl != null && crtTbl.getPartCols() != null) { - tbl.setPartCols(crtTbl.getPartCols()) - } + private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to + // serialize the Metastore schema to JSON and pass it as a data source option because of the + // evil case insensitivity issue, which is reconciled within `ParquetRelation`. + val parquetOptions = Map( + ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths.toSet == pathsInMetastore.toSet && + logical.schema.sameType(metastoreSchema) && + parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { + PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) + } - if (crtTbl != null && crtTbl.getStorageHandler() != null) { - tbl.setProperty( - org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.META_TABLE_STORAGE, - crtTbl.getStorageHandler()) + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as Parquet. However, we are getting a $other from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } } - /* - * We use LazySimpleSerDe by default. - * - * If the user didn't specify a SerDe, and any of the columns are not simple - * types, we will have to use DynamicSerDe instead. - */ - if (crtTbl == null || crtTbl.getSerName() == null) { - val storageHandler = tbl.getStorageHandler() - if (storageHandler == null) { - logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName") - tbl.setSerializationLib(classOf[LazySimpleSerDe].getName()) - - import org.apache.hadoop.mapred.TextInputFormat - import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat - import org.apache.hadoop.io.Text - - tbl.setInputFormatClass(classOf[TextInputFormat]) - tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]]) - tbl.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") - } else { - val serDeClassName = storageHandler.getSerDeClass().getName() - logInfo(s"Use StorageHandler-supplied $serDeClassName for table $dbName.$tblName") - tbl.setSerializationLib(serDeClassName) + val result = if (metastoreRelation.hiveQlTable.isPartitioned) { + val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) + val partitionColumnDataTypes = partitionSchema.map(_.dataType) + // We're converting the entire table into ParquetRelation, so predicates to Hive metastore + // are empty. + val partitions = metastoreRelation.getHiveQlPartitions().map { p => + val location = p.getLocation + val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { + case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) + }) + ParquetPartition(values, location) } - } else { - // let's validate that the serde exists - val serdeName = crtTbl.getSerName() - try { - val d = ReflectionUtils.newInstance(hconf.getClassByName(serdeName), hconf) - if (d != null) { - logDebug("Found class for $serdeName") - } - } catch { - case e: SerDeException => throw new HiveException("Cannot validate serde: " + serdeName, e) + val partitionSpec = PartitionSpec(partitionSchema, partitions) + val paths = partitions.map(_.path) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = LogicalRelation( + new ParquetRelation( + paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created } - tbl.setSerializationLib(serdeName) - } - if (crtTbl != null && crtTbl.getFieldDelim() != null) { - tbl.setSerdeParam(serdeConstants.FIELD_DELIM, crtTbl.getFieldDelim()) - tbl.setSerdeParam(serdeConstants.SERIALIZATION_FORMAT, crtTbl.getFieldDelim()) - } - if (crtTbl != null && crtTbl.getFieldEscape() != null) { - tbl.setSerdeParam(serdeConstants.ESCAPE_CHAR, crtTbl.getFieldEscape()) - } - - if (crtTbl != null && crtTbl.getCollItemDelim() != null) { - tbl.setSerdeParam(serdeConstants.COLLECTION_DELIM, crtTbl.getCollItemDelim()) - } - if (crtTbl != null && crtTbl.getMapKeyDelim() != null) { - tbl.setSerdeParam(serdeConstants.MAPKEY_DELIM, crtTbl.getMapKeyDelim()) - } - if (crtTbl != null && crtTbl.getLineDelim() != null) { - tbl.setSerdeParam(serdeConstants.LINE_DELIM, crtTbl.getLineDelim()) - } - - if (crtTbl != null && crtTbl.getSerdeProps() != null) { - val iter = crtTbl.getSerdeProps().entrySet().iterator() - while (iter.hasNext()) { - val m = iter.next() - tbl.setSerdeParam(m.getKey(), m.getValue()) + parquetRelation + } else { + val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = LogicalRelation( + new ParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created } - } - - if (crtTbl != null && crtTbl.getComment() != null) { - tbl.setProperty("comment", crtTbl.getComment()) - } - - if (crtTbl != null && crtTbl.getLocation() != null) { - HiveShim.setLocation(tbl, crtTbl) - } - - if (crtTbl != null && crtTbl.getSkewedColNames() != null) { - tbl.setSkewedColNames(crtTbl.getSkewedColNames()) - } - if (crtTbl != null && crtTbl.getSkewedColValues() != null) { - tbl.setSkewedColValues(crtTbl.getSkewedColValues()) - } - - if (crtTbl != null) { - tbl.setStoredAsSubDirectories(crtTbl.isStoredAsSubDirectories()) - tbl.setInputFormatClass(crtTbl.getInputFormat()) - tbl.setOutputFormatClass(crtTbl.getOutputFormat()) - } - - tbl.getTTable().getSd().setInputFormat(tbl.getInputFormatClass().getName()) - tbl.getTTable().getSd().setOutputFormat(tbl.getOutputFormatClass().getName()) - - if (crtTbl != null && crtTbl.isExternal()) { - tbl.setProperty("EXTERNAL", "TRUE") - tbl.setTableType(TableType.EXTERNAL_TABLE) - } - // set owner - try { - tbl.setOwner(hive.hiveconf.getUser) - } catch { - case e: IOException => throw new HiveException("Unable to get current user", e) + parquetRelation } - // set create time - tbl.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) + result.newInstance() + } - // TODO add bucket support - // TODO set more info if Hive upgrade + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + val db = databaseName.getOrElse(client.currentDatabase) - // create the table - synchronized { - try client.createTable(tbl, allowExisting) catch { - case e: org.apache.hadoop.hive.metastore.api.AlreadyExistsException - if allowExisting => // Do nothing - case e: Throwable => throw e - } - } + client.listTables(db).map(tableName => (tableName, false)) } protected def processDatabaseAndTableName( @@ -367,61 +544,135 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + /** + * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet + * data source relations for better performance. + */ + object ParquetConversions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.resolved || plan.analyzed) { + return plan + } + + // Collects all `MetastoreRelation`s which should be replaced + val toBeReplaced = plan.collect { + // Write path + case InsertIntoTable(relation: MetastoreRelation, _, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + + // Write path + case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + + // Read path + case relation: MetastoreRelation if hive.convertMetastoreParquet && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + } + + val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap + val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) + + // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes + // attribute IDs referenced in other nodes. + plan.transformUp { + case r: MetastoreRelation if relationMap.contains(r) => + val parquetRelation = relationMap(r) + val alias = r.alias.getOrElse(r.tableName) + Subquery(alias, parquetRelation) + + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + + case other => other.transformExpressions { + case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) + } + } + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. */ object CreateTables extends Rule[LogicalPlan] { - import org.apache.hadoop.hive.ql.Context - import org.apache.hadoop.hive.ql.parse.{QB, ASTNode, SemanticAnalyzer} - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - - // TODO extra is in type of ASTNode which means the logical plan is not resolved - // Need to think about how to implement the CreateTableAsSelect.resolved - case CreateTableAsSelect(db, tableName, child, allowExisting, Some(extra: ASTNode)) => - val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - - // Get the CreateTableDesc from Hive SemanticAnalyzer - val desc: Option[CreateTableDesc] = if (tableExists(Seq(databaseName, tblName))) { - None + case p: LogicalPlan if p.resolved => p + case p @ CreateTableAsSelect(table, child, allowExisting) => + val schema = if (table.schema.nonEmpty) { + table.schema } else { - val sa = new SemanticAnalyzer(hive.hiveconf) { - override def analyzeInternal(ast: ASTNode) { - // A hack to intercept the SemanticAnalyzer.analyzeInternal, - // to ignore the SELECT clause of the CTAS - val method = classOf[SemanticAnalyzer].getDeclaredMethod( - "analyzeCreateTable", classOf[ASTNode], classOf[QB]) - method.setAccessible(true) - method.invoke(this, ast, this.getQB) - } + child.output.map { + attr => new HiveColumn( + attr.name, + HiveMetastoreTypes.toMetastoreType(attr.dataType), null) } - - sa.analyze(extra, new Context(hive.hiveconf)) - Some(sa.getQB().getTableDesc) } - execution.CreateTableAsSelect( - databaseName, - tableName, - child, - allowExisting, - desc) + val desc = table.copy(schema = schema) - case p: LogicalPlan if p.resolved => p + if (hive.convertCTAS && table.serde.isEmpty) { + // Do the conversion when spark.sql.hive.convertCTAS is true and the query + // does not specify any storage format (file format and storage handler). + if (table.specifiedDatabase.isDefined) { + throw new AnalysisException( + "Cannot specify database name in a CTAS statement " + + "when spark.sql.hive.convertCTAS is set to true.") + } + + val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists + CreateTableUsingAsSelect( + TableIdentifier(desc.name), + hive.conf.defaultDataSourceName, + temporary = false, + Array.empty[String], + mode, + options = Map.empty[String, String], + child + ) + } else { + val desc = if (table.serde.isEmpty) { + // add default serde + table.copy( + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } else { + table + } + + val (dbName, tblName) = + processDatabaseAndTableName( + desc.specifiedDatabase.getOrElse(client.currentDatabase), desc.name) - case p @ CreateTableAsSelect(db, tableName, child, allowExisting, None) => - val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - execution.CreateTableAsSelect( - databaseName, - tableName, - child, - allowExisting, - None) + execution.CreateTableAsSelect( + desc.copy( + specifiedDatabase = Some(dbName), + name = tblName), + child, + allowExisting) + } } } @@ -434,23 +685,26 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => castChildOutput(p, table, child) } - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { + def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) + : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) + val numDynamicPartitions = p.partition.values.count(_.isEmpty) val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) + (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) + .take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { - p + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else if (childOutputDataTypes.size == tableOutputDataTypes.size && childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) { + .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite) + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -468,15 +722,19 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ??? + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } /** * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + throw new UnsupportedOperationException + } - override def unregisterAllTables() = {} + override def unregisterAllTables(): Unit = {} } /** @@ -485,45 +743,84 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with * because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types. */ private[hive] case class InsertIntoHiveTable( - table: LogicalPlan, + table: MetastoreRelation, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = Seq.empty + + val numDynamicPartitions = partition.values.count(_.isEmpty) + + // This is the expected schema of the table prepared to be inserted into, + // including dynamic partition columns. + val tableOutput = table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions) - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType) + override lazy val resolved: Boolean = childrenResolved && child.output.zip(tableOutput).forall { + case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) - (val table: TTable, val partitions: Seq[TPartition]) + (val table: HiveTable) (@transient sqlContext: SQLContext) - extends LeafNode { + extends LeafNode with MultiInstanceRelation with FileRelation { + + override def equals(other: Any): Boolean = other match { + case relation: MetastoreRelation => + databaseName == relation.databaseName && + tableName == relation.tableName && + alias == relation.alias && + output == relation.output + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(databaseName, tableName, alias, output) + } + + @transient val hiveQlTable: Table = { + // We start by constructing an API table as Hive performs several important transformations + // internally when converting an API table to a QL table. + val tTable = new org.apache.hadoop.hive.metastore.api.Table() + tTable.setTableName(table.name) + tTable.setDbName(table.database) + + val tableParameters = new java.util.HashMap[String, String]() + tTable.setParameters(tableParameters) + table.properties.foreach { case (k, v) => tableParameters.put(k, v) } + + tTable.setTableType(table.tableType.name) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tTable.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + tTable.setPartitionKeys( + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - self: Product => + table.location.foreach(sd.setLocation) + table.inputFormat.foreach(sd.setInputFormat) + table.outputFormat.foreach(sd.setOutputFormat) - // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and - // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. - // Right now, using org.apache.hadoop.hive.ql.metadata.Table and - // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException - // which indicates the SerDe we used is not Serializable. + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + table.serde.foreach(serdeInfo.setSerializationLib) + sd.setSerdeInfo(serdeInfo) - @transient val hiveQlTable = new Table(table) + val serdeParameters = new java.util.HashMap[String, String]() + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) - @transient val hiveQlPartitions = partitions.map { p => - new Partition(hiveQlTable, p) + new Table(tTable) } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) - val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) + val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) + val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. An @@ -540,6 +837,44 @@ private[hive] case class MetastoreRelation } ) + // When metastore partition pruning is turned off, we cache the list of all partitions to + // mimic the behavior of Spark < 1.5 + lazy val allPartitions = table.getAllPartitions + + def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { + val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { + table.getPartitions(predicates) + } else { + allPartitions + } + + rawPartitions.map { p => + val tPartition = new org.apache.hadoop.hive.metastore.api.Partition + tPartition.setDbName(databaseName) + tPartition.setTableName(tableName) + tPartition.setValues(p.values.asJava) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tPartition.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + + sd.setLocation(p.storage.location) + sd.setInputFormat(p.storage.inputFormat) + sd.setOutputFormat(p.storage.outputFormat) + + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + sd.setSerdeInfo(serdeInfo) + serdeInfo.setSerializationLib(p.storage.serde) + + val serdeParameters = new java.util.HashMap[String, String]() + serdeInfo.setParameters(serdeParameters) + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + + new Partition(hiveQlTable, tPartition) + } + } + /** Only compare database and tablename, not alias. */ override def sameResult(plan: LogicalPlan): Boolean = { plan match { @@ -549,11 +884,7 @@ private[hive] case class MetastoreRelation } } - val tableDesc = HiveShim.getTableDesc( - Class.forName( - hiveQlTable.getSerializationLib, - true, - Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], + val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to @@ -563,35 +894,53 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( - f.getName, - sqlContext.ddlParser.parseType(f.getType), + implicit class SchemaAttribute(f: HiveColumn) { + def toAttribute: AttributeReference = AttributeReference( + f.name, + HiveMetastoreTypes.toDataType(f.hiveType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) } - // Must be a stable value since new attributes are born here. - val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + /** PartitionKey attributes */ + val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = table.schema.map(_.toAttribute) val output = attributes ++ partitionKeys /** An attribute map that can be used to lookup original attributes based on expression id. */ - val attributeMap = AttributeMap(output.map(o => (o,o))) + val attributeMap = AttributeMap(output.map(o => (o, o))) /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + + override def newInstance(): MetastoreRelation = { + MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) + } } -object HiveMetastoreTypes { - protected val ddlParser = new DDLParser - def toDataType(metastoreType: String): DataType = synchronized { - ddlParser.parseType(metastoreType) +private[hive] object HiveMetastoreTypes { + def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) + + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)" } def toMetastoreType(dt: DataType): String = dt match { @@ -610,7 +959,7 @@ object HiveMetastoreTypes { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) + case d: DecimalType => decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ab305e1f82a5..d5cd7e98b526 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,84 +18,72 @@ package org.apache.spark.sql.hive import java.sql.Date +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.ql.{ErrorMsg, Context} +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils -import org.apache.spark.sql.SparkSQLParser +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ - -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler /** * Used when we need to start parsing the AST before deciding that we are going to pass the command * back for Hive to execute natively. Will be replaced with a native command that contains the * cmd string. */ -private[hive] case object NativePlaceholder extends Command +private[hive] case object NativePlaceholder extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq.empty + override def output: Seq[Attribute] = Seq.empty +} -/** - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. - * @param table The table to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. - * It is effective only when the table is a Hive table. - */ -case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends Command { - override def output = Seq( - // Column names are based on Hive. - AttributeReference("col_name", StringType, nullable = false)(), - AttributeReference("data_type", StringType, nullable = false)(), - AttributeReference("comment", StringType, nullable = false)()) +private[hive] case class CreateTableAsSelect( + tableDesc: HiveTable, + child: LogicalPlan, + allowExisting: Boolean) extends UnaryNode with Command { + + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = + tableDesc.specifiedDatabase.isDefined && + tableDesc.schema.size > 0 && + tableDesc.serde.isDefined && + tableDesc.inputFormat.isDefined && + tableDesc.outputFormat.isDefined && + childrenResolved } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl { +private[hive] object HiveQl extends Logging { protected val nativeCommands = Seq( - "TOK_DESCFUNCTION", - "TOK_DESCDATABASE", - "TOK_SHOW_CREATETABLE", - "TOK_SHOWCOLUMNS", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWINDEXES", - "TOK_SHOWPARTITIONS", - "TOK_SHOWTABLES", - "TOK_SHOW_TBLPROPERTIES", - - "TOK_LOCKTABLE", - "TOK_SHOWLOCKS", - "TOK_UNLOCKTABLE", - - "TOK_CREATEROLE", - "TOK_DROPROLE", - "TOK_GRANT", - "TOK_GRANT_ROLE", - "TOK_REVOKE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - - "TOK_CREATEFUNCTION", - "TOK_DROPFUNCTION", - + "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE", "TOK_ALTERTABLE_ADDCOLS", "TOK_ALTERTABLE_ADDPARTS", "TOK_ALTERTABLE_ALTERPARTS", @@ -110,38 +98,75 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_DROPDATABASE", - "TOK_DROPINDEX", - "TOK_MSCK", - + "TOK_ALTERVIEW", "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", + + "TOK_CREATEDATABASE", + "TOK_CREATEFUNCTION", + "TOK_CREATEINDEX", + "TOK_CREATEROLE", "TOK_CREATEVIEW", + + "TOK_DESCDATABASE", + "TOK_DESCFUNCTION", + + "TOK_DROPDATABASE", + "TOK_DROPFUNCTION", + "TOK_DROPINDEX", + "TOK_DROPROLE", + "TOK_DROPTABLE_PROPERTIES", "TOK_DROPVIEW", + "TOK_DROPVIEW_PROPERTIES", "TOK_EXPORT", + + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_IMPORT", + "TOK_LOAD", - "TOK_SWITCHDATABASE" + "TOK_LOCKTABLE", + + "TOK_MSCK", + + "TOK_REVOKE", + + "TOK_SHOW_COMPACTIONS", + "TOK_SHOW_CREATETABLE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLE_PRINCIPALS", + "TOK_SHOW_ROLES", + "TOK_SHOW_SET_ROLE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOW_TRANSACTIONS", + "TOK_SHOWCOLUMNS", + "TOK_SHOWDATABASES", + "TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWLOCKS", + "TOK_SHOWPARTITIONS", + + "TOK_SWITCHDATABASE", + + "TOK_UNLOCKTABLE" ) // Commands that we do not need to explain. protected val noExplainCommands = Seq( "TOK_DESCTABLE", + "TOK_SHOWTABLES", "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. ) ++ nativeCommands - protected val hqlParser = { - val fallback = new ExtendedHiveQlParser - new SparkSQLParser(fallback(_)) - } + protected val hqlParser = new ExtendedHiveQlParser /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -167,7 +192,7 @@ private[hive] object HiveQl { .map(ast => Option(ast).map(_.transform(rule)).orNull)) } catch { case e: Exception => - println(dumpTree(n)) + logError(dumpTree(n).toString) throw e } } @@ -176,7 +201,7 @@ private[hive] object HiveQl { * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. */ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.toSeq).getOrElse(Nil) + Option(s).map(_.asScala).getOrElse(Nil) /** * Returns this ASTNode with the text changed to `newText`. @@ -191,7 +216,7 @@ private[hive] object HiveQl { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren) + n.addChildren(newChildren.asJava) n } @@ -201,8 +226,8 @@ private[hive] object HiveQl { * Right now this function only checks the name, type, text and children of the node * for equality. */ - def checkEquals(other: ASTNode) { - def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) { + def checkEquals(other: ASTNode): Unit = { + def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { sys.error(s"$field does not match for trees. " + s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") } @@ -214,17 +239,11 @@ private[hive] object HiveQl { val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] leftChildren zip rightChildren foreach { - case (l,r) => l checkEquals r + case (l, r) => l checkEquals r } } } - class ParseException(sql: String, cause: Throwable) - extends Exception(s"Failed to parse: $sql", cause) - - class SemanticException(msg: String) - extends Exception(s"Error in semantic analysis: $msg") - /** * Returns the AST for the given SQL string. */ @@ -234,18 +253,31 @@ private[hive] object HiveQl { * Otherwise, there will be Null pointer exception, * when retrieving properties form HiveConf. */ - val hContext = new Context(new HiveConf()) + val hContext = new Context(SessionState.get().getConf()) val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) hContext.clear() node } + /** + * Returns the HiveConf + */ + private[this] def hiveConf: HiveConf = { + val ss = SessionState.get() // SessionState is lazy initialization, it can be null here + if (ss == null) { + new HiveConf() + } else { + ss.getConf + } + } /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) + + val errorRegEx = "line (\\d+):(\\d+) (.*)".r /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String) = { + def createPlan(sql: String): LogicalPlan = { try { val tree = getAst(sql) if (nativeCommands contains tree.getText) { @@ -257,25 +289,27 @@ private[hive] object HiveQl { } } } catch { - case e: Exception => throw new ParseException(sql, e) - case e: NotImplementedError => sys.error( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) + case pe: org.apache.hadoop.hive.ql.parse.ParseException => + pe.getMessage match { + case errorRegEx(line, start, message) => + throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) + case otherMessage => + throw new AnalysisException(otherMessage) + } + case e: MatchError => throw e + case e: Exception => + throw new AnalysisException(e.getMessage) + case e: NotImplementedError => + throw new AnalysisException( + s""" + |Unsupported language features in query: $sql + |${dumpTree(getAst(sql))} + |$e + |${e.getStackTrace.head} + """.stripMargin) } } - /** Creates LogicalPlan for a given VIEW */ - def createPlanForView(view: Table, alias: Option[String]) = alias match { - // because hive use things like `_c0` to build the expanded text - // currently we cannot support view from "create view v1(c1) as ..." - case None => Subquery(view.getTableName, createPlan(view.getViewExpandedText)) - case Some(aliasText) => Subquery(aliasText, createPlan(view.getViewExpandedText)) - } - def parseDdl(ddl: String): Seq[Attribute] = { val tree = try { @@ -288,11 +322,11 @@ private[hive] object HiveQl { assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") val tableOps = tree.getChildren val colList = - tableOps + tableOps.asScala .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") .getOrElse(sys.error("No columnList!")).getChildren - colList.map(nodeToAttribute) + colList.asScala.map(nodeToAttribute) } /** Extractor for matching Hive's AST Tokens. */ @@ -300,8 +334,9 @@ private[hive] object HiveQl { /** @return matches of the form (tokenName, children). */ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { case t: ASTNode => + CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, - Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) case _ => None } } @@ -322,7 +357,7 @@ private[hive] object HiveQl { clauses } - def getClause(clauseName: String, nodeList: Seq[Node]) = + def getClause(clauseName: String, nodeList: Seq[Node]): Node = getClauseOption(clauseName, nodeList).getOrElse(sys.error( s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) @@ -347,7 +382,7 @@ private[hive] object HiveQl { DecimalType(precision.getText.toInt, scale.getText.toInt) case Token("TOK_DECIMAL", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -386,16 +421,11 @@ private[hive] object HiveQl { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => (None, tableOnly) case Seq(databaseName, table) => (Some(databaseName), table) } @@ -404,7 +434,9 @@ private[hive] object HiveQl { } protected def extractTableIdent(tableNameParts: Node): Seq[String] = { - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -413,16 +445,16 @@ private[hive] object HiveQl { } /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to + * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) + * is equivalent to * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 * Check the following link for details. - * + * https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup * * The bitmask denotes the grouping expressions validity for a grouping set, * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of + * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. */ protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { @@ -436,7 +468,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val bitmasks: Seq[Int] = setASTs.map(set => set match { case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => + case Token("TOK_GROUPING_SETS_EXPRESSION", children) => children.foldLeft(0)((bitmap, col) => { val colString = col.asInstanceOf[ASTNode].toStringTree() require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") @@ -448,6 +480,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (keys, bitmasks) } + protected def getProperties(node: Node): Seq[(String, String)] = node match { + case Token("TOK_TABLEPROPLIST", list) => + list.map { + case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => + (unquoteString(key) -> unquoteString(value)) + } + } + protected def nodeToPlan(node: Node): LogicalPlan = node match { // Special drop table that also uncaches. case Token("TOK_DROPTABLE", @@ -457,8 +497,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C DropTable(tableName, ifExists.nonEmpty) // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: - isNoscan) => + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: + isNoscan) => // Reference: // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables if (partitionSpec.nonEmpty) { @@ -474,23 +514,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)())) + ExplainCommand(OneRowRelation) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( nodeToPlan(crtTbl), - Seq(AttributeReference("plan", StringType,nullable = false)()), - extended != None) + extended = extended.isDefined) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( nodeToPlan(query), - Seq(AttributeReference("plan", StringType, nullable = false)()), - extended != None) + extended = extended.isDefined) case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -509,15 +547,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. val tableIdent = extractTableIdent(nameParts.head) DescribeCommand( - UnresolvedRelation(tableIdent, None), extended.isDefined) + UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => // It is describing a column with the format like "describe db.table column". NativePlaceholder case tableName => // It is describing a table with the format like "describe table". DescribeCommand( - UnresolvedRelation(Seq(tableName.getText), None), - extended.isDefined) + UnresolvedRelation(Seq(tableName.getText), None), isExtended = extended.isDefined) } } // All other cases. @@ -531,6 +568,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val ( Some(tableNameParts) :: _ /* likeTable */ :: + externalTable :: Some(query) :: allowExisting +: ignores) = @@ -538,6 +576,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Seq( "TOK_TABNAME", "TOK_LIKETABLE", + "EXTERNAL", "TOK_QUERY", "TOK_IFNOTEXISTS", "TOK_TABLECOMMENT", @@ -547,11 +586,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TABLESKEWED", // Skewed by "TOK_TABLEROWFORMAT", "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. - "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile - "TOK_TBLTEXTFILE", // Stored as TextFile - "TOK_TBLRCFILE", // Stored as RCFile - "TOK_TBLORCFILE", // Stored as ORC File + "TOK_FILEFORMAT_GENERIC", "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", @@ -559,18 +594,214 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C children) val (db, tableName) = extractDbNameTableName(tableNameParts) - CreateTableAsSelect(db, tableName, nodeToPlan(query), allowExisting != None, Some(node)) + // TODO add bucket support + var tableDesc: HiveTable = HiveTable( + specifiedDatabase = db, + name = tableName, + schema = Seq.empty[HiveColumn], + partitionColumns = Seq.empty[HiveColumn], + properties = Map[String, String](), + serdeProperties = Map[String, String](), + tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, + location = None, + inputFormat = None, + outputFormat = None, + serde = None, + viewText = None) + + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) + val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + } + + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) + + children.collect { + case list @ Token("TOK_TABCOLLIST", _) => + val cols = BaseSemanticAnalyzer.getColumns(list, true) + if (cols != null) { + tableDesc = tableDesc.copy( + schema = cols.asScala.map { field => + HiveColumn(field.getName, field.getType, field.getComment) + }) + } + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + // TODO support the sql text + tableDesc = tableDesc.copy(viewText = Option(comment)) + case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => + val cols = BaseSemanticAnalyzer.getColumns(list(0), false) + if (cols != null) { + tableDesc = tableDesc.copy( + partitionColumns = cols.asScala.map { field => + HiveColumn(field.getName, field.getType, field.getComment) + }) + } + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => + val serdeParams = new java.util.HashMap[String, String]() + child match { + case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => + val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) + serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) + serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) + if (rowChild2.length > 1) { + val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) + serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) + } + case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => + val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) + case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => + val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) + case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => + val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + if (!(lineDelim == "\n") && !(lineDelim == "10")) { + throw new AnalysisException( + SemanticAnalyzer.generateErrorMessage( + rowChild, + ErrorMsg.LINES_TERMINATED_BY_NON_NEWLINE.getMsg)) + } + serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) + case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => + val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) + // TODO support the nullFormat + case _ => assert(false) + } + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) + case Token("TOK_TABLELOCATION", child :: Nil) => + var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) + location = EximUtil.relativeToAbsolutePath(hiveConf, location) + tableDesc = tableDesc.copy(location = Option(location)) + case Token("TOK_TABLESERIALIZER", child :: Nil) => + tableDesc = tableDesc.copy( + serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) + if (child.getChildCount == 2) { + val serdeParams = new java.util.HashMap[String, String]() + BaseSemanticAnalyzer.readProps( + (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) + } + case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => + child.getText().toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } + + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } + + case _ => + throw new SemanticException( + s"Unrecognized file format in STORED AS clause: ${child.getText}") + } + + case Token("TOK_TABLESERIALIZER", + Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => + tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) + + otherProps match { + case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) + case Nil => + } - // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) + case list @ Token("TOK_TABLEFILEFORMAT", children) => + tableDesc = tableDesc.copy( + inputFormat = + Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), + outputFormat = + Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) + case Token("TOK_STORAGEHANDLER", _) => + throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) + case _ => // Unsupport features + } + + CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting != None) + + // If its not a "CTAS" like above then take it as a native command case Token("TOK_CREATETABLE", _) => NativePlaceholder // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder + Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder + + case Token("TOK_QUERY", queryArgs) + if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + + val (fromClause: Option[ASTNode], insertClauses, cteRelations) = + queryArgs match { + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + // check if has CTE + insertClauses.last match { + case Token("TOK_CTE", cteClauses) => + val cteRelations = cteClauses.map(node => { + val relation = nodeToRelation(node).asInstanceOf[Subquery] + (relation.alias, relation) + }).toMap + (Some(args.head), insertClauses.init, Some(cteRelations)) + + case _ => (Some(args.head), insertClauses, None) + } - case Token("TOK_QUERY", - Token("TOK_FROM", fromClause :: Nil) :: - insertClauses) => + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + } // Return one query for each insert clause. val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => @@ -590,7 +821,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C clusterByClause :: distributeByClause :: limitClause :: - lateralViewClause :: Nil) = { + lateralViewClause :: + windowClause :: Nil) = { getClauses( Seq( "TOK_INSERT_INTO", @@ -608,13 +840,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_CLUSTERBY", "TOK_DISTRIBUTEBY", "TOK_LIMIT", - "TOK_LATERAL_VIEW"), + "TOK_LATERAL_VIEW", + "WINDOW"), singleInsert) } - val relations = nodeToRelation(fromClause) + val relations = fromClause match { + case Some(f) => nodeToRelation(f) + case None => OneRowRelation + } + val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.toSeq + val Seq(whereExpr) = whereNode.getChildren.asScala Filter(nodeToExpr(whereExpr), relations) }.getOrElse(relations) @@ -623,7 +860,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Script transformations are expressed as a select clause with a single expression of type // TOK_TRANSFORM - val transformation = select.getChildren.head match { + val transformation = select.getChildren.iterator().next() match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -647,26 +884,28 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) = clause match { + def matchSerDe(clause: Seq[ASTNode]) + : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, "", Nil) + (rowFormat, None, Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, serdeClass, Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (name, value) - } - (Nil, serdeClass, serdeProps) + (BaseSemanticAnalyzer.unescapeSQLString(name), + BaseSemanticAnalyzer.unescapeSQLString(value)) + } + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, "", Nil) + case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) @@ -690,24 +929,26 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withLateralView = lateralViewClause.map { lv => val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head - - val alias = - getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - - Generate( - nodesToGenerator(clauses), - join = true, - outer = false, - Some(alias.toLowerCase), - withWhere) + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() + + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText + + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = false, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + withWhere) }.getOrElse(withWhere) // The projection of the query can either be a normal projection, an aggregation // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) + val selectExpressions = + select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -730,43 +971,77 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_CUBE_GROUPBY", children) => Cube(children.map(nodeToExpr), withLateralView, selectExpressions) case _ => sys.error("Expect WITH CUBE") - }), + }), Some(Project(selectExpressions, withLateralView))).flatten.head } - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withProject) else withProject - + // Handle HAVING clause. val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } // Note that we added a cast to boolean. If the expression itself is already boolean, // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withDistinct) - }.getOrElse(withDistinct) + Filter(Cast(havingExpr, BooleanType), withProject) + }.getOrElse(withProject) + // Handle SELECT DISTINCT + val withDistinct = + if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withHaving) + Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) case (None, Some(perPartitionOrdering), None, None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), + false, withDistinct) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, Some(clusterExprs)) => - Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) - case (None, None, None, None) => withHaving + Sort( + clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), + false, + RepartitionByExpression( + clusterExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) + case (None, None, None, None) => withDistinct case _ => sys.error("Unsupported set of ordering / distribution clauses.") } val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.head)) + limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) .map(Limit(_, withSort)) .getOrElse(withSort) + // Collect all window specifications defined in the WINDOW clause. + val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { + case Token("TOK_WINDOWDEF", + Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => + windowName -> nodesToWindowSpecification(spec) + }.toMap) + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val resolvedCrossReference = windowDefinitions.map { + windowDefMap => windowDefMap.map { + case (windowName, WindowSpecReference(other)) => + (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) + case o => o.asInstanceOf[(String, WindowSpecDefinition)] + } + } + + val withWindowDefinitions = + resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) + // TOK_INSERT_INTO means to add files to the table. // TOK_DESTINATION means to overwrite the table. val resultDestination = @@ -774,17 +1049,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val overwrite = intoClause.isEmpty nodeToDest( resultDestination, - withLimit, + withWindowDefinitions, overwrite) } // If there are multiple INSERTS just UNION them together into on query. - queries.reduceLeft(Union) + val query = queries.reduceLeft(Union) - case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + // return With plan if there is CTE + cteRelations.map(With(query, _)).getOrElse(query) + + // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT + case Token("TOK_UNIONALL", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") } val allJoinTokens = "(TOK_.*JOIN)".r @@ -798,14 +1077,17 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - nodeToRelation(relationClause)) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = isOuter.nonEmpty, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + nodeToRelation(relationClause)) /* All relations, possibly with aliases or sampling clauses. */ case Token("TOK_TABREF", clauses) => @@ -825,7 +1107,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val tableIdent = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -843,12 +1127,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLESPLITSAMPLE", Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => - Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation) + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + require( + fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) + && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), + s"Sampling fraction ($fraction) must be on interval [0, 100]") + Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, + relation) case Token("TOK_TABLEBUCKETSAMPLE", Token(numerator, Nil) :: Token(denominator, Nil) :: Nil) => val fraction = numerator.toDouble / denominator.toDouble - Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) case a: ASTNode => throw new NotImplementedError( s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : @@ -864,7 +1156,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) - val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr)) + val joinExpressions = + tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => @@ -879,7 +1172,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Seq(false, false) => Inner }.toBuffer - val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) // Must be transform down. val joinedResult = joinedTables transform { @@ -889,7 +1182,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C joinType = joinType.remove(joinType.length - 1)) } - val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) // Unique join is not really the same as an outer join so we must group together results where // the joinExpressions are the same, taking the First of each value is only okay because the @@ -899,7 +1192,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // worth the number of hacks that will be required to implement it. Namely, we need to add // some sort of mapped star expansion that would expand all child output row to be similarly // named output expressions where some aggregate expression has been applied (i.e. First). - ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + throw new UnsupportedOperationException case Token(allJoinTokens(joinToken), relation1 :: @@ -953,7 +1247,27 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { + // Parse partitions. We also make keys case insensitive. + case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: + Token("TOK_IFNOTEXISTS", + ifNotExists) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val tableIdent = extractTableIdent(tableNameParts) + + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -961,30 +1275,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite) + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName}:" + + s"\n ${dumpTree(a).toString} ") } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", - e :: Nil) => + case Token("TOK_SELEXPR", e :: Nil) => Some(nodeToExpr(e)) - case Token("TOK_SELEXPR", - e :: Token(alias, Nil) :: Nil) => + case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) + case Token("TOK_SELEXPR", e :: aliasChildren) => + var aliasNames = ArrayBuffer[String]() + aliasChildren.foreach { _ match { + case Token(name, Nil) => aliasNames += cleanIdentifier(name) + case _ => + } + } + Some(MultiAlias(nodeToExpr(e), aliasNames)) + /* Hints are ignored */ case Token("TOK_HINTLIST", _) => None case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName }:" + + s"\n ${dumpTree(a).toString } ") } - protected val escapedIdentifier = "`([^`]+)`".r + protected val doubleQuotedString = "\"([^\"]+)\"".r + protected val singleQuotedString = "'([^']+)'".r + + protected def unquoteString(str: String) = str match { + case singleQuotedString(s) => s + case doubleQuotedString(s) => s + case other => other + } + /** Strips backticks from ident if present */ protected def cleanIdentifier(ident: String): String = ident match { case escapedIdentifier(i) => i @@ -999,16 +1330,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveParser.DecimalLiteral) /* Case insensitive matches */ - val ARRAY = "(?i)ARRAY".r - val COALESCE = "(?i)COALESCE".r val COUNT = "(?i)COUNT".r - val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r - val MAX = "(?i)MAX".r - val MIN = "(?i)MIN".r - val UPPER = "(?i)UPPER".r - val LOWER = "(?i)LOWER".r - val RAND = "(?i)RAND".r val AND = "(?i)AND".r val OR = "(?i)OR".r val NOT = "(?i)NOT".r @@ -1022,19 +1345,17 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r - val SUBSTR = "(?i)SUBSTR(?:ING)?".r - val SQRT = "(?i)SQRT".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute(cleanIdentifier(name)) + UnresolvedAttribute.quoted(cleanIdentifier(name)) case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { - case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) - case other => GetField(other, attr) + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) + case other => UnresolvedExtractValue(other, Literal(attr)) } /* Stars (*) */ @@ -1045,18 +1366,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(name)) /* Aggregate Functions */ - case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg)) case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) - - /* System functions about string operations */ - case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg)) /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => @@ -1086,13 +1398,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.Unlimited) + Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DateType) /* Arithmetic */ + case Token("+", child :: Nil) => nodeToExpr(child) case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) @@ -1105,7 +1418,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) @@ -1153,41 +1465,45 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => CaseWhen(branches.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val transformed = branches.drop(1).sliding(2, 2).map { - case Seq(condVal, value) => - // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval(). - // Hence effectful / non-deterministic key expressions are *not* supported at the moment. - // We should consider adding new Expressions to get around this. - Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)), - nodeToExpr(value)) - case Seq(elseVal) => Seq(nodeToExpr(elseVal)) - }.toSeq.reduce(_ ++ _) - CaseWhen(transformed) + val keyExpr = nodeToExpr(branches.head) + CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) /* Complex datatype manipulation */ case Token("[", child :: ordinal :: Nil) => - GetItem(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Other functions */ - case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) => - CreateArray(children.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) - case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) + UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) + + /* Window Functions */ + case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) + nodesToWindowSpecification(spec) match { + case reference: WindowSpecReference => + UnresolvedWindowExpression(function, reference) + case definition: WindowSpecDefinition => + WindowExpression(function, definition) + } + case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => + // Safe to use Literal(1)? + val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) + nodesToWindowSpecification(spec) match { + case reference: WindowSpecReference => + UnresolvedWindowExpression(function, reference) + case definition: WindowSpecDefinition => + WindowExpression(function, definition) + } /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. + case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ - case Token("TOK_NULL", Nil) => Literal(null, NullType) - case Token(TRUE(), Nil) => Literal(true, BooleanType) - case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) @@ -1198,21 +1514,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C try { if (ast.getText.endsWith("L")) { // Literal bigint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) } else if (ast.getText.endsWith("S")) { // Literal smallint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) } else if (ast.getText.endsWith("Y")) { // Literal tinyint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") v = Literal(Decimal(strVal)) } else { - v = Literal(ast.getText.toDouble, DoubleType) - v = Literal(ast.getText.toLong, LongType) - v = Literal(ast.getText.toInt, IntegerType) + v = Literal.create(ast.getText.toDouble, DoubleType) + v = Literal.create(ast.getText.toLong, LongType) + v = Literal.create(ast.getText.toInt, IntegerType) } } catch { case nfe: NumberFormatException => // Do nothing @@ -1230,6 +1546,33 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) + case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => + Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : @@ -1237,9 +1580,104 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C """.stripMargin) } + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r + def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { + case Token(windowName, Nil) :: Nil => + // Refer to a window spec defined in the window clause. + WindowSpecReference(windowName) + case Nil => + // OVER() + WindowSpecDefinition( + partitionSpec = Nil, + orderSpec = Nil, + frameSpecification = UnspecifiedFrame) + case spec => + val (partitionClause :: rowFrame :: rangeFrame :: Nil) = + getClauses( + Seq( + "TOK_PARTITIONINGSPEC", + "TOK_WINDOWRANGE", + "TOK_WINDOWVALUES"), + spec) + + // Handle Partition By and Order By. + val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => + val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = + getClauses( + Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), + partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) + + (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { + case (Some(partitionByExpr), Some(orderByExpr), None) => + (partitionByExpr.getChildren.asScala.map(nodeToExpr), + orderByExpr.getChildren.asScala.map(nodeToSortOrder)) + case (Some(partitionByExpr), None, None) => + (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) + case (None, Some(orderByExpr), None) => + (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) + case (None, None, Some(clusterByExpr)) => + val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) + (expressions, expressions.map(SortOrder(_, Ascending))) + case _ => + throw new NotImplementedError( + s"""No parse rules for Node ${partitionAndOrdering.getName} + """.stripMargin) + } + }.getOrElse { + (Nil, Nil) + } + + // Handle Window Frame + val windowFrame = + if (rowFrame.isEmpty && rangeFrame.isEmpty) { + UnspecifiedFrame + } else { + val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) + def nodeToBoundary(node: Node): FrameBoundary = node match { + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow + case _ => + throw new NotImplementedError( + s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} + """.stripMargin) + } + + rowFrame.orElse(rangeFrame).map { frame => + frame.getChildren.asScala.toList match { + case precedingNode :: followingNode :: Nil => + SpecifiedWindowFrame( + frameType, + nodeToBoundary(precedingNode), + nodeToBoundary(followingNode)) + case precedingNode :: Nil => + SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) + case _ => + throw new NotImplementedError( + s"""No parse rules for the Window Frame based on Node ${frame.getName} + """.stripMargin) + } + }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) + } + + WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) + } val explode = "(?i)explode".r - def nodesToGenerator(nodes: Seq[Node]): Generator = { + def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head val attributes = nodes.flatMap { @@ -1249,13 +1687,17 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C function match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - Explode(attributes, nodeToExpr(child)) + (Explode(nodeToExpr(child)), attributes) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - HiveGenericUdtf( - new HiveFunctionWrapper(functionName), - attributes, - children.map(nodeToExpr)) + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + + (HiveGenericUDTF( + new HiveFunctionWrapper(functionClassName), + children.map(nodeToExpr)), attributes) case a: ASTNode => throw new NotImplementedError( @@ -1268,11 +1710,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) : StringBuilder = { node match { - case a: ASTNode => builder.append((" " * indent) + a.getText + "\n") + case a: ASTNode => builder.append( + (" " * indent) + a.getText + " " + + a.getLine + ", " + + a.getTokenStartIndex + "," + + a.getTokenStopIndex + ", " + + a.getCharPositionInLine + "\n") case other => sys.error(s"Non ASTNode encountered: $other") } - Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) + Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) builder } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 000000000000..004805f3aed0 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +import org.apache.avro.Schema + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null && ids.nonEmpty) { + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) + } + if (names != null && names.nonEmpty) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. + if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. + * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d89111094b9f..d38ad9127327 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,24 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row - -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsLogicalPlan, CreateTableUsingAsSelect, CreateTableUsing} -import org.apache.spark.sql.types.StringType private[hive] trait HiveStrategies { @@ -43,128 +33,6 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext - /** - * :: Experimental :: - * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet - * table scan operator. - * - * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring - * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. - * - * Other issues: - * - Much of this logic assumes case insensitive resolution. - */ - @Experimental - object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase = DataFrame(s.sqlContext, s.logicalPlan) - - def addPartitioningAttributes(attrs: Seq[Attribute]) = { - // Don't add the partitioning key if its already present in the data. - if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { - s - } else { - DataFrame( - s.sqlContext, - s.logicalPlan transform { - case p: ParquetRelation => p.copy(partitioningAttributes = attrs) - }) - } - } - } - - implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]) = - OutputFaker( - originalPlan.output.map(a => - newOutput.find(a.name.toLowerCase == _.name.toLowerCase) - .getOrElse( - sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), - originalPlan) - } - - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) - if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet => - - // Filter out all predicates that only deal with partition keys - val partitionsKeys = AttributeSet(relation.partitionKeys) - val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.subsetOf(partitionsKeys) - } - - // We are going to throw the predicates and projection back at the whole optimization - // sequence so lets unresolve all the attributes, allowing them to be rebound to the - // matching parquet attributes. - val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true))) - - val unresolvedProjection: Seq[Column] = projectList.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).map(Column(_)) - - try { - if (relation.hiveQlTable.isPartitioned) { - val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) - // Translate the predicate so that it automatically casts the input values to the - // correct data types during evaluation. - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) - val key = relation.partitionKeys(idx) - Cast(BoundReference(idx, StringType, nullable = true), key.dataType) - } - - val inputData = new GenericMutableRow(relation.partitionKeys.size) - val pruningCondition = - if (codegenEnabled) { - GeneratePredicate(castedPredicate) - } else { - InterpretedPredicate(castedPredicate) - } - - val partitions = relation.hiveQlPartitions.filter { part => - val partitionValues = part.getValues - var i = 0 - while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) - i += 1 - } - pruningCondition(inputData) - } - - hiveContext - .parquetFile(partitions.map(_.getLocation).mkString(",")) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } else { - hiveContext - .parquetFile(relation.hiveQlTable.getDataLocation.toString) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - } catch { - // parquetFile will throw an exception when there is no data. - // TODO: Remove this hack for Spark 1.3. - case iae: java.lang.IllegalArgumentException - if iae.getMessage.contains("Can not create a Path from an empty string") => - PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil - } - case _ => Nil - } - } - object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => @@ -175,12 +43,14 @@ private[hive] trait HiveStrategies { object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil - case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) => + table, partition, planLater(child), overwrite, ifNotExists) :: Nil + case hive.InsertIntoHiveTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil + table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil } } @@ -204,7 +74,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil case _ => Nil } @@ -212,20 +82,17 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) => - ExecutedCommand( - CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil - - case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) => - val logicalPlan = hiveContext.parseSql(query) + case CreateTableUsing( + tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan) + CreateMetastoreDataSource( + tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) => + case CreateTableUsingAsSelect( + tableIdent, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query) + CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil @@ -240,8 +107,11 @@ private[hive] trait HiveStrategies { case t: MetastoreRelation => ExecutedCommand( DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil + case o: LogicalPlan => - ExecutedCommand(RunnableDescribeCommand(planLater(o), describe.output)) :: Nil + val resultPlan = context.executePlan(o).executedPlan + ExecutedCommand(RunnableDescribeCommand( + resultPlan, describe.output, describe.isExtended)) :: Nil } case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c368715f7c6f..dc355690852b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ @@ -25,23 +24,27 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.SerializableWritable +import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A trait for subclasses that handle table scans. */ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[Row] + def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] } @@ -55,7 +58,7 @@ class HadoopTableReader( @transient relation: MetastoreRelation, @transient sc: HiveContext, @transient hiveExtraConf: HiveConf) - extends TableReader { + extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". // https://hadoop.apache.org/docs/r1.0.4/mapred-default.html @@ -70,12 +73,12 @@ class HadoopTableReader( // TODO: set aws s3 credentials. private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) - override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, - relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + Utils.classForName(relation.tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -90,7 +93,7 @@ class HadoopTableReader( def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[Row] = { + filterOpt: Option[PathFilter]): RDD[InternalRow] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") @@ -115,13 +118,13 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -140,10 +143,49 @@ class HadoopTableReader( def makeRDDForPartitionedTable( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[Row] = { - val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + filterOpt: Option[PathFilter]): RDD[InternalRow] = { + + // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists + def verifyPartitionPath( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]): + Map[HivePartition, Class[_ <: Deserializer]] = { + if (!sc.conf.verifyPartitionPath) { + partitionToDeserializer + } else { + var existPathSet = collection.mutable.Set[String]() + var pathPatternSet = collection.mutable.Set[String]() + partitionToDeserializer.filter { + case (partition, partDeserializer) => + def updateExistPathSetByPathPattern(pathPatternStr: String) { + val pathPattern = new Path(pathPatternStr) + val fs = pathPattern.getFileSystem(sc.hiveconf) + val matches = fs.globStatus(pathPattern) + matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) + } + // convert /demo/data/year/month/day to /demo/data/*/*/*/ + def getPathPatternByPath(parNum: Int, tempPath: Path): String = { + var path = tempPath + for (i <- (1 to parNum)) path = path.getParent + val tails = (1 to parNum).map(_ => "*").mkString("/", "/", "/") + path.toString + tails + } + + val partPath = partition.getDataLocation + val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); + var pathPatternStr = getPathPatternByPath(partNum, partPath) + if (!pathPatternSet.contains(pathPatternStr)) { + pathPatternSet += pathPatternStr + updateExistPathSetByPathPattern(pathPatternStr) + } + existPathSet.contains(partPath.toString) + } + } + } + + val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) + .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] @@ -174,7 +216,7 @@ class HadoopTableReader( relation.partitionKeys.contains(attr) } - def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = relation.partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) @@ -188,15 +230,19 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) + // get the table deserializer + val tableSerDe = tableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, tableDesc.getProperties) // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, + mutableRow, tableSerDe) } }.toSeq // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Row](sc.sparkContext) + new EmptyRDD[InternalRow](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -229,7 +275,7 @@ class HadoopTableReader( val rdd = new HadoopRDD( sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, classOf[Writable], @@ -241,13 +287,13 @@ class HadoopTableReader( } } -private[hive] object HadoopTableReader extends HiveInspectors { +private[hive] object HadoopTableReader extends HiveInspectors with Logging { /** * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to * instantiate a HadoopRDD. */ def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { - FileInputFormat.setInputPaths(jobConf, path) + FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*) if (tableDesc != null) { PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) @@ -260,25 +306,38 @@ private[hive] object HadoopTableReader extends HiveInspectors { * Transform all given raw `Writable`s into `Row`s. * * @param iterator Iterator of all `Writable`s to be transformed - * @param deserializer The `Deserializer` associated with the input `Writable` + * @param rawDeser The `Deserializer` associated with the input `Writable` * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding * positions in the output schema * @param mutableRow A reusable `MutableRow` that should be filled + * @param tableDeser Table Deserializer * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( iterator: Iterator[Writable], - deserializer: Deserializer, + rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow, + tableDeser: Deserializer): Iterator[InternalRow] = { + + val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { + rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] + } else { + ObjectInspectorConverters.getConvertedOI( + rawDeser.getObjectInspector, + tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] + } + + logDebug(soi.toString) - val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal }.unzip - // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern - // matching and branching costs per row. + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => @@ -297,16 +356,16 @@ private[hive] object HadoopTableReader extends HiveInspectors { (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value)) + row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) @@ -315,9 +374,11 @@ private[hive] object HadoopTableReader extends HiveInspectors { } } + val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi) + // Map each tuple to a row object iterator.map { value => - val raw = deserializer.deserialize(value) + val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 while (i < fieldRefs.length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) @@ -329,7 +390,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { i += 1 } - mutableRow: Row + mutableRow: InternalRow } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala deleted file mode 100644 index 7c1d1133c342..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ /dev/null @@ -1,455 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import java.io.File -import java.util.{Set => JavaSet} - -import scala.collection.mutable -import scala.language.implicitConversions - -import org.apache.hadoop.hive.ql.exec.FunctionRegistry -import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.serde2.RegexSerDe -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.hadoop.hive.serde2.avro.AvroSerDe - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.CacheTableCommand -import org.apache.spark.sql.hive._ -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.execution.HiveNativeCommand - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - -// SPARK-3729: Test key required to check for initialization errors with config. -object TestHive - extends TestHiveContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) - -/** - * A locally running test instance of Spark's Hive execution engine. - * - * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. - * Calling [[reset]] will delete all tables and other state in the database, leaving the database - * in a "clean" state. - * - * TestHive is singleton object version of this class because instantiating multiple copies of the - * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of - * test cases that rely on TestHive must be serialized. - */ -class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { - self => - - // By clearing the port we force Spark to pick a new one. This allows us to rerun tests - // without restarting the JVM. - System.clearProperty("spark.hostPort") - CommandProcessorFactory.clean(hiveconf) - - hiveconf.set("hive.plan.serialization.format", "javaXML") - - lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath - - /** Sets up the system initially or after a RESET command */ - protected def configure(): Unit = { - setConf("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$metastorePath;create=true") - setConf("hive.metastore.warehouse.dir", warehousePath) - Utils.registerShutdownDeleteDir(new File(warehousePath)) - Utils.registerShutdownDeleteDir(new File(metastorePath)) - } - - val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") - testTempDir.delete() - testTempDir.mkdir() - Utils.registerShutdownDeleteDir(testTempDir) - - // For some hive test case which contain ${system:test.tmp.dir} - System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) - - configure() // Must be called before initializing the catalog below. - - /** The location of the compiled hive distribution */ - lazy val hiveHome = envVarToFile("HIVE_HOME") - /** The location of the hive source code. */ - lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") - - // Override so we can intercept relative paths and rewrite them to point at hive. - override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) - - override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - - /** Fewer partitions to speed up testing. */ - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - } - - /** - * Returns the value of specified environmental variable as a [[java.io.File]] after checking - * to ensure it exists - */ - private def envVarToFile(envVar: String): Option[File] = { - Option(System.getenv(envVar)).map(new File(_)) - } - - /** - * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the - * hive test cases assume the system is set up. - */ - private def rewritePaths(cmd: String): String = - if (cmd.toUpperCase contains "LOAD DATA") { - val testDataLocation = - hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) - cmd.replaceAll("\\.\\./\\.\\./", testDataLocation + "/") - } else { - cmd - } - - val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") - hiveFilesTemp.delete() - hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) - - val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { - new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) - } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + - File.separator + "resources") - } - - def getHiveFile(path: String): File = { - val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) - hiveDevHome - .map(new File(_, stripped)) - .filter(_.exists) - .getOrElse(new File(inRepoTests, stripped)) - } - - val describedTable = "DESCRIBE (\\w+)".r - - protected[hive] class HiveQLQueryExecution(hql: String) - extends this.QueryExecution(HiveQl.parseSql(hql)) { - def hiveExec() = runSqlHive(hql) - override def toString = hql + "\n" + super.toString - } - - /** - * Override QueryExecution with special debug workflow. - */ - class QueryExecution(logicalPlan: LogicalPlan) - extends super.QueryExecution(logicalPlan) { - override lazy val analyzed = { - val describedTables = logical match { - case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil - case CacheTableCommand(tbl, _, _) => tbl :: Nil - case _ => Nil - } - - // Make sure any test tables referenced are loaded. - val referencedTables = - describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } - val referencedTestTables = referencedTables.filter(testTables.contains) - logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") - referencedTestTables.foreach(loadTestTable) - // Proceed with analysis. - analyzer(logical) - } - } - - case class TestTable(name: String, commands: (()=>Unit)*) - - protected[hive] implicit class SqlCmd(sql: String) { - def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit - } - - /** - * A list of test tables and the DDL required to initialize them. A test table is loaded on - * demand when a query are run against it. - */ - lazy val testTables = new mutable.HashMap[String, TestTable]() - def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) - - // The test tables that are defined in the Hive QTestUtil. - // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java - val hiveQTestUtilTables = Seq( - TestTable("src", - "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), - TestTable("src1", - "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), - TestTable("srcpart", () => { - runSqlHive( - "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") - for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - runSqlHive( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') - """.stripMargin) - } - }), - TestTable("srcpart1", () => { - runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") - for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - runSqlHive( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') - """.stripMargin) - } - }), - TestTable("src_thrift", () => { - import org.apache.thrift.protocol.TBinaryProtocol - import org.apache.hadoop.hive.serde2.thrift.test.Complex - import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.mapred.SequenceFileInputFormat - import org.apache.hadoop.mapred.SequenceFileOutputFormat - - val srcThrift = new Table("default", "src_thrift") - srcThrift.setFields(Nil) - srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName) - // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat. - srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName) - srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName) - srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName) - srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName) - catalog.client.createTable(srcThrift) - - - runSqlHive( - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") - }), - TestTable("serdeins", - s"""CREATE TABLE serdeins (key INT, value STRING) - |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ('field.delim'='\\t') - """.stripMargin.cmd, - "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), - TestTable("sales", - s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) - |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), - TestTable("episodes", - s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' - |TBLPROPERTIES ( - | 'avro.schema.literal'='{ - | "type": "record", - | "name": "episodes", - | "namespace": "testing.hive.avro.serde", - | "fields": [ - | { - | "name": "title", - | "type": "string", - | "doc": "episode title" - | }, - | { - | "name": "air_date", - | "type": "string", - | "doc": "initial date" - | }, - | { - | "name": "doctor", - | "type": "int", - | "doc": "main actor playing the Doctor in episode" - | } - | ] - | }' - |) - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd - ), - // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING - // IS NOT YET SUPPORTED - TestTable("episodes_part", - s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) - |PARTITIONED BY (doctor_pt INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' - |TBLPROPERTIES ( - | 'avro.schema.literal'='{ - | "type": "record", - | "name": "episodes", - | "namespace": "testing.hive.avro.serde", - | "fields": [ - | { - | "name": "title", - | "type": "string", - | "doc": "episode title" - | }, - | { - | "name": "air_date", - | "type": "string", - | "doc": "initial date" - | }, - | { - | "name": "doctor", - | "type": "int", - | "doc": "main actor playing the Doctor in episode" - | } - | ] - | }' - |) - """.stripMargin.cmd, - // WORKAROUND: Required to pass schema to SerDe for partitioned tables. - // TODO: Pass this automatically from the table to partitions. - s""" - |ALTER TABLE episodes_part SET SERDEPROPERTIES ( - | 'avro.schema.literal'='{ - | "type": "record", - | "name": "episodes", - | "namespace": "testing.hive.avro.serde", - | "fields": [ - | { - | "name": "title", - | "type": "string", - | "doc": "episode title" - | }, - | { - | "name": "air_date", - | "type": "string", - | "doc": "initial date" - | }, - | { - | "name": "doctor", - | "type": "int", - | "doc": "main actor playing the Doctor in episode" - | } - | ] - | }' - |) - """.stripMargin.cmd, - s""" - INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) - SELECT title, air_date, doctor FROM episodes - """.cmd - ) - ) - - hiveQTestUtilTables.foreach(registerTestTable) - - private val loadedTables = new collection.mutable.HashSet[String] - - var cacheTables: Boolean = false - def loadTestTable(name: String) { - if (!(loadedTables contains name)) { - // Marks the table as loaded first to prevent infinite mutually recursive table loading. - loadedTables += name - logInfo(s"Loading test table $name") - val createCmds = - testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) - createCmds.foreach(_()) - - if (cacheTables) { - cacheTable(name) - } - } - } - - /** - * Records the UDFs present when the server starts, so we can delete ones that are created by - * tests. - */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames - - // Database default may not exist in 0.13.1, create it if not exist - HiveShim.createDefaultDBIfNeeded(this) - - /** - * Resets the test instance by deleting any tables that have been created. - * TODO: also clear out UDFs, views, etc. - */ - def reset() { - try { - // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => - log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) - } - - cacheManager.clearCache() - loadedTables.clear() - catalog.cachedDataSourceTables.invalidateAll() - catalog.client.getAllTables("default").foreach { t => - logDebug(s"Deleting table $t") - val table = catalog.client.getTable("default", t) - - catalog.client.getIndexes("default", t, 255).foreach { index => - catalog.client.dropIndex("default", t, index.getIndexName, true) - } - - if (!table.isIndexTable) { - catalog.client.dropTable("default", t) - } - } - - catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - catalog.client.dropDatabase(db, true, false, true) - } - - catalog.unregisterAllTables() - - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => - FunctionRegistry.unregisterTemporaryUDF(udfName) - } - - // Some tests corrupt this value on purpose, which breaks the RESET call below. - hiveconf.set("fs.default.name", new File(".").toURI.toString) - // It is important that we RESET first as broken hooks that might have been set could break - // other sql exec here. - runSqlHive("RESET") - // For some reason, RESET does not reset the following variables... - // https://issues.apache.org/jira/browse/HIVE-9004 - runSqlHive("set hive.table.parameters.default=") - runSqlHive("set datanucleus.cache.collections=true") - runSqlHive("set datanucleus.cache.collections.lazy=true") - // Lots of tests fail if we do not change the partition whitelist from the default. - runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") - configure() - - runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. - loadTestTable("src") - loadTestTable("srcpart") - } catch { - case e: Exception => - logError("FATAL ERROR: Failed to reset TestDB state.", e) - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala new file mode 100644 index 000000000000..3811c152a7ae --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.PrintStream +import java.util.{Map => JMap} + +import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} +import org.apache.spark.sql.catalyst.expressions.Expression + +private[hive] case class HiveDatabase( + name: String, + location: String) + +private[hive] abstract class TableType { val name: String } +private[hive] case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" } +private[hive] case object IndexTable extends TableType { override val name = "INDEX_TABLE" } +private[hive] case object ManagedTable extends TableType { override val name = "MANAGED_TABLE" } +private[hive] case object VirtualView extends TableType { override val name = "VIRTUAL_VIEW" } + +// TODO: Use this for Tables and Partitions +private[hive] case class HiveStorageDescriptor( + location: String, + inputFormat: String, + outputFormat: String, + serde: String, + serdeProperties: Map[String, String]) + +private[hive] case class HivePartition( + values: Seq[String], + storage: HiveStorageDescriptor) + +private[hive] case class HiveColumn(name: String, hiveType: String, comment: String) +private[hive] case class HiveTable( + specifiedDatabase: Option[String], + name: String, + schema: Seq[HiveColumn], + partitionColumns: Seq[HiveColumn], + properties: Map[String, String], + serdeProperties: Map[String, String], + tableType: TableType, + location: Option[String] = None, + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None, + viewText: Option[String] = None) { + + @transient + private[client] var client: ClientInterface = _ + + private[client] def withClient(ci: ClientInterface): this.type = { + client = ci + this + } + + def database: String = specifiedDatabase.getOrElse(sys.error("database not resolved")) + + def isPartitioned: Boolean = partitionColumns.nonEmpty + + def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) + + def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = + client.getPartitionsByFilter(this, predicates) + + // Hive does not support backticks when passing names to the client. + def qualifiedName: String = s"$database.$name" +} + +/** + * An externally visible interface to the Hive client. This interface is shared across both the + * internal and external classloaders for a given version of Hive and thus must expose only + * shared classes. + */ +private[hive] trait ClientInterface { + + /** Returns the Hive Version of this client. */ + def version: HiveVersion + + /** Returns the configuration for the given key in the current session. */ + def getConf(key: String, defaultValue: String): String + + /** + * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will + * result in one string. + */ + def runSqlHive(sql: String): Seq[String] + + def setOut(stream: PrintStream): Unit + def setInfo(stream: PrintStream): Unit + def setError(stream: PrintStream): Unit + + /** Returns the names of all tables in the given database. */ + def listTables(dbName: String): Seq[String] + + /** Returns the name of the active database. */ + def currentDatabase: String + + /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ + def getDatabase(name: String): HiveDatabase = { + getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) + } + + /** Returns the metadata for a given database, or None if it doesn't exist. */ + def getDatabaseOption(name: String): Option[HiveDatabase] + + /** Returns the specified table, or throws [[NoSuchTableException]]. */ + def getTable(dbName: String, tableName: String): HiveTable = { + getTableOption(dbName, tableName).getOrElse(throw new NoSuchTableException) + } + + /** Returns the metadata for the specified table or None if it doens't exist. */ + def getTableOption(dbName: String, tableName: String): Option[HiveTable] + + /** Creates a table with the given metadata. */ + def createTable(table: HiveTable): Unit + + /** Updates the given table with new metadata. */ + def alterTable(table: HiveTable): Unit + + /** Creates a new database with the given name. */ + def createDatabase(database: HiveDatabase): Unit + + /** Returns the specified paritition or None if it does not exist. */ + def getPartitionOption( + hTable: HiveTable, + partitionSpec: JMap[String, String]): Option[HivePartition] + + /** Returns all partitions for the given table. */ + def getAllPartitions(hTable: HiveTable): Seq[HivePartition] + + /** Returns partitions filtered by predicates for the given table. */ + def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] + + /** Loads a static partition into an existing table. */ + def loadPartition( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit + + /** Loads data into an existing table. */ + def loadTable( + loadPath: String, // TODO URI + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit + + /** Loads new dynamic partitions into an existing table. */ + def loadDynamicPartitions( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit + + /** Used for testing only. Removes all metadata from this instance of Hive. */ + def reset(): Unit +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala new file mode 100644 index 000000000000..4d1e3ed9198e --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.{File, PrintStream} +import java.util.{Map => JMap} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.language.reflectiveCalls + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} +import org.apache.hadoop.hive.metastore.{TableType => HTableType} +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.util.{CircularBuffer, Utils} + +/** + * A class that wraps the HiveClient and converts its responses to externally visible classes. + * Note that this class is typically loaded with an internal classloader for each instantiation, + * allowing it to interact directly with a specific isolated version of Hive. Loading this class + * with the isolated classloader however will result in it only being visible as a ClientInterface, + * not a ClientWrapper. + * + * This class needs to interact with multiple versions of Hive, but will always be compiled with + * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility + * must use reflection after matching on `version`. + * + * @param version the version of hive used when pick function calls that are not compatible. + * @param config a collection of configuration options that will be added to the hive conf before + * opening the hive client. + * @param initClassLoader the classloader used when creating the `state` field of + * this ClientWrapper. + */ +private[hive] class ClientWrapper( + override val version: HiveVersion, + config: Map[String, String], + initClassLoader: ClassLoader) + extends ClientInterface + with Logging { + + overrideHadoopShims() + + // !! HACK ALERT !! + // + // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking + // major version number gathered from Hadoop jar files: + // + // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with + // security. + // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. + // + // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It + // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but + // `Hadoop23Shims` is chosen because the major version number here is 2. + // + // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` + // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is + // not available, we decide whether to override the shims or not by checking for existence of a + // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. + private def overrideHadoopShims(): Unit = { + val hadoopVersion = VersionInfo.getVersion + val VersionPattern = """(\d+)\.(\d+).*""".r + + hadoopVersion match { + case null => + logError("Failed to inspect Hadoop version") + + // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. + val probeMethod = "getPathWithoutSchemeAndAuthority" + if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { + logInfo( + s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + + s"we are probably using Hadoop 1.x or 2.0.x") + loadHadoop20SShims() + } + + case VersionPattern(majorVersion, minorVersion) => + logInfo(s"Inspected Hadoop version: $hadoopVersion") + + // Loads Hadoop20SShims for 1.x and 2.0.x versions + val (major, minor) = (majorVersion.toInt, minorVersion.toInt) + if (major < 2 || (major == 2 && minor == 0)) { + loadHadoop20SShims() + } + } + + // Logs the actual loaded Hadoop shims class + val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName + logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") + } + + private def loadHadoop20SShims(): Unit = { + val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" + logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") + + try { + val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") + // scalastyle:off classforname + val shimsClass = Class.forName(hadoop20SShimsClassName) + // scalastyle:on classforname + val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) + shimsField.setAccessible(true) + shimsField.set(null, shims) + } catch { case cause: Throwable => + throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) + } + } + + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. + private val outputBuffer = new CircularBuffer() + + private val shim = version match { + case hive.v12 => new Shim_v0_12() + case hive.v13 => new Shim_v0_13() + case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() + } + + // Create an internal session state for this ClientWrapper. + val state = { + val original = Thread.currentThread().getContextClassLoader + // Switch to the initClassLoader. + Thread.currentThread().setContextClassLoader(initClassLoader) + val ret = try { + val oldState = SessionState.get() + if (oldState == null) { + val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) + config.foreach { case (k, v) => + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") + } + initialConf.set(k, v) + } + val newState = new SessionState(initialConf) + SessionState.start(newState) + newState.out = new PrintStream(outputBuffer, true, "UTF-8") + newState.err = new PrintStream(outputBuffer, true, "UTF-8") + newState + } else { + oldState + } + } finally { + Thread.currentThread().setContextClassLoader(original) + } + ret + } + + /** Returns the configuration for the current session. */ + def conf: HiveConf = SessionState.get().getConf + + override def getConf(key: String, defaultValue: String): String = { + conf.get(key, defaultValue) + } + + // TODO: should be a def?s + // When we create this val client, the HiveConf of it (conf) is the one associated with state. + @GuardedBy("this") + private var client = Hive.get(conf) + + // We use hive's conf for compatibility. + private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClientWrapper got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + Thread.sleep(retryDelayMillis) + try { + client = Hive.get(state.getConf, true) + } catch { + case e: Exception if causedByThrift(e) => + logWarning("Failed to refresh hive client, will retry.", e) + } + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } + + /** + * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. + */ + private def withHiveState[A](f: => A): A = retryLocked { + val original = Thread.currentThread().getContextClassLoader + // Set the thread local metastore client to the client associated with this ClientWrapper. + Hive.set(client) + // setCurrentSessionState will use the classLoader associated + // with the HiveConf in `state` to override the context class loader of the current + // thread. + shim.setCurrentSessionState(state) + val ret = try f finally { + Thread.currentThread().setContextClassLoader(original) + } + ret + } + + def setOut(stream: PrintStream): Unit = withHiveState { + state.out = stream + } + + def setInfo(stream: PrintStream): Unit = withHiveState { + state.info = stream + } + + def setError(stream: PrintStream): Unit = withHiveState { + state.err = stream + } + + override def currentDatabase: String = withHiveState { + state.getCurrentDatabase + } + + override def createDatabase(database: HiveDatabase): Unit = withHiveState { + client.createDatabase( + new Database( + database.name, + "", + new File(database.location).toURI.toString, + new java.util.HashMap), + true) + } + + override def getDatabaseOption(name: String): Option[HiveDatabase] = withHiveState { + Option(client.getDatabase(name)).map { d => + HiveDatabase( + name = d.getName, + location = d.getLocationUri) + } + } + + override def getTableOption( + dbName: String, + tableName: String): Option[HiveTable] = withHiveState { + + logDebug(s"Looking up $dbName.$tableName") + + val hiveTable = Option(client.getTable(dbName, tableName, false)) + val converted = hiveTable.map { h => + + HiveTable( + name = h.getTableName, + specifiedDatabase = Option(h.getDbName), + schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + partitionColumns = h.getPartCols.asScala.map(f => + HiveColumn(f.getName, f.getType, f.getComment)), + properties = h.getParameters.asScala.toMap, + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, + tableType = h.getTableType match { + case HTableType.MANAGED_TABLE => ManagedTable + case HTableType.EXTERNAL_TABLE => ExternalTable + case HTableType.VIRTUAL_VIEW => VirtualView + case HTableType.INDEX_TABLE => IndexTable + }, + location = shim.getDataLocation(h), + inputFormat = Option(h.getInputFormatClass).map(_.getName), + outputFormat = Option(h.getOutputFormatClass).map(_.getName), + serde = Option(h.getSerializationLib), + viewText = Option(h.getViewExpandedText)).withClient(this) + } + converted + } + + private def toInputFormat(name: String) = + Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + + private def toOutputFormat(name: String) = + Utils.classForName(name) + .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] + + private def toQlTable(table: HiveTable): metadata.Table = { + val qlTable = new metadata.Table(table.database, table.name) + + qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + qlTable.setPartCols( + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } + table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } + + // set owner + qlTable.setOwner(conf.getUser) + // set create time + qlTable.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) + + table.location.foreach { loc => shim.setDataLocation(qlTable, loc) } + table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass) + table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass) + table.serde.foreach(qlTable.setSerializationLib) + + qlTable + } + + override def createTable(table: HiveTable): Unit = withHiveState { + val qlTable = toQlTable(table) + client.createTable(qlTable) + } + + override def alterTable(table: HiveTable): Unit = withHiveState { + val qlTable = toQlTable(table) + client.alterTable(table.qualifiedName, qlTable) + } + + private def toHivePartition(partition: metadata.Partition): HivePartition = { + val apiPartition = partition.getTPartition + HivePartition( + values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), + storage = HiveStorageDescriptor( + location = apiPartition.getSd.getLocation, + inputFormat = apiPartition.getSd.getInputFormat, + outputFormat = apiPartition.getSd.getOutputFormat, + serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) + } + + override def getPartitionOption( + table: HiveTable, + partitionSpec: JMap[String, String]): Option[HivePartition] = withHiveState { + + val qlTable = toQlTable(table) + val qlPartition = client.getPartition(qlTable, partitionSpec, false) + Option(qlPartition).map(toHivePartition) + } + + override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState { + val qlTable = toQlTable(hTable) + shim.getAllPartitions(client, qlTable).map(toHivePartition) + } + + override def getPartitionsByFilter( + hTable: HiveTable, + predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { + val qlTable = toQlTable(hTable) + shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) + } + + override def listTables(dbName: String): Seq[String] = withHiveState { + client.getAllTables(dbName).asScala + } + + /** + * Runs the specified SQL query using Hive. + */ + override def runSqlHive(sql: String): Seq[String] = { + val maxResults = 100000 + val results = runHive(sql, maxResults) + // It is very confusing when you only get back some of the results... + if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") + results + } + + /** + * Execute the command using Hive and return the results as a sequence. Each element + * in the sequence is one row. + */ + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { + logDebug(s"Running hiveql '$cmd'") + if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + try { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + // The remainder of the command. + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + val proc = shim.getCommandProcessor(tokens(0), conf) + proc match { + case driver: Driver => + val response: CommandProcessorResponse = driver.run(cmd) + // Throw an exception if there is an error in query processing. + if (response.getResponseCode != 0) { + driver.close() + throw new QueryExecutionException(response.getErrorMessage) + } + driver.setMaxRows(maxRows) + + val results = shim.getDriverResults(driver) + driver.close() + results + + case _ => + if (state.out != null) { + // scalastyle:off println + state.out.println(tokens(0) + " " + cmd_1) + // scalastyle:on println + } + Seq(proc.run(cmd_1).getResponseCode.toString) + } + } catch { + case e: Exception => + logError( + s""" + |====================== + |HIVE FAILURE OUTPUT + |====================== + |${outputBuffer.toString} + |====================== + |END HIVE FAILURE OUTPUT + |====================== + """.stripMargin) + throw e + } + } + + def loadPartition( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = withHiveState { + shim.loadPartition( + client, + new Path(loadPath), // TODO: Use URI + tableName, + partSpec, + replace, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + + def loadTable( + loadPath: String, // TODO URI + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = withHiveState { + shim.loadTable( + client, + new Path(loadPath), + tableName, + replace, + holdDDLTime) + } + + def loadDynamicPartitions( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = withHiveState { + shim.loadDynamicPartitions( + client, + new Path(loadPath), + tableName, + partSpec, + replace, + numDP, + holdDDLTime, + listBucketingEnabled) + } + + def reset(): Unit = withHiveState { + client.getAllTables("default").asScala.foreach { t => + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) + } + if (!table.isIndexTable) { + client.dropTable("default", t) + } + } + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala new file mode 100644 index 000000000000..48bbb21e6c1d --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -0,0 +1,525 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} +import java.lang.reflect.{Method, Modifier} +import java.net.URI +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, IntegralType} + +/** + * A shim that defines the interface between ClientWrapper and the underlying Hive library used to + * talk to the metastore. Each Hive version has its own implementation of this class, defining + * version-specific version of needed functions. + * + * The guideline for writing shims is: + * - always extend from the previous version unless really not possible + * - initialize methods in lazy vals, both for quicker access for multiple invocations, and to + * avoid runtime errors due to the above guideline. + */ +private[client] sealed abstract class Shim { + + /** + * Set the current SessionState to the given SessionState. Also, set the context classloader of + * the current thread to the one set in the HiveConf of this given `state`. + * @param state + */ + def setCurrentSessionState(state: SessionState): Unit + + /** + * This shim is necessary because the return type is different on different versions of Hive. + * All parameters are the same, though. + */ + def getDataLocation(table: Table): Option[String] + + def setDataLocation(table: Table, loc: String): Unit + + def getAllPartitions(hive: Hive, table: Table): Seq[Partition] + + def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] + + def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor + + def getDriverResults(driver: Driver): Seq[String] + + def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + + def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit + + def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit + + def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit + + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { + val method = findMethod(klass, name, args: _*) + require(Modifier.isStatic(method.getModifiers()), + s"Method $name of class $klass is not static.") + method + } + + protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = { + klass.getMethod(name, args: _*) + } + +} + +private[client] class Shim_v0_12 extends Shim with Logging { + + private lazy val startMethod = + findStaticMethod( + classOf[SessionState], + "start", + classOf[SessionState]) + private lazy val getDataLocationMethod = findMethod(classOf[Table], "getDataLocation") + private lazy val setDataLocationMethod = + findMethod( + classOf[Table], + "setDataLocation", + classOf[URI]) + private lazy val getAllPartitionsMethod = + findMethod( + classOf[Hive], + "getAllPartitionsForPruner", + classOf[Table]) + private lazy val getCommandProcessorMethod = + findStaticMethod( + classOf[CommandProcessorFactory], + "get", + classOf[String], + classOf[HiveConf]) + private lazy val getDriverResultsMethod = + findMethod( + classOf[Driver], + "getResults", + classOf[JArrayList[String]]) + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) + + override def setCurrentSessionState(state: SessionState): Unit = { + // Starting from Hive 0.13, setCurrentSessionState will internally override + // the context class loader of the current thread by the class loader set in + // the conf of the SessionState. So, for this Hive 0.12 shim, we add the same + // behavior and make shim.setCurrentSessionState of all Hive versions have the + // consistent behavior. + Thread.currentThread().setContextClassLoader(state.getConf.getClassLoader) + startMethod.invoke(null, state) + } + + override def getDataLocation(table: Table): Option[String] = + Option(getDataLocationMethod.invoke(table)).map(_.toString()) + + override def setDataLocation(table: Table, loc: String): Unit = + setDataLocationMethod.invoke(table, new URI(loc)) + + override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq + + override def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression]): Seq[Partition] = { + // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. + // See HIVE-4888. + logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " + + "Please use Hive 0.13 or higher.") + getAllPartitions(hive, table) + } + + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = + getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] + + override def getDriverResults(driver: Driver): Seq[String] = { + val res = new JArrayList[String]() + getDriverResultsMethod.invoke(driver, res) + res.asScala + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + } + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) + } + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + +} + +private[client] class Shim_v0_13 extends Shim_v0_12 { + + private lazy val setCurrentSessionStateMethod = + findStaticMethod( + classOf[SessionState], + "setCurrentSessionState", + classOf[SessionState]) + private lazy val setDataLocationMethod = + findMethod( + classOf[Table], + "setDataLocation", + classOf[Path]) + private lazy val getAllPartitionsMethod = + findMethod( + classOf[Hive], + "getAllPartitionsOf", + classOf[Table]) + private lazy val getPartitionsByFilterMethod = + findMethod( + classOf[Hive], + "getPartitionsByFilter", + classOf[Table], + classOf[String]) + private lazy val getCommandProcessorMethod = + findStaticMethod( + classOf[CommandProcessorFactory], + "get", + classOf[Array[String]], + classOf[HiveConf]) + private lazy val getDriverResultsMethod = + findMethod( + classOf[Driver], + "getResults", + classOf[JList[Object]]) + + override def setCurrentSessionState(state: SessionState): Unit = + setCurrentSessionStateMethod.invoke(null, state) + + override def setDataLocation(table: Table, loc: String): Unit = + setDataLocationMethod.invoke(table, new Path(loc)) + + override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq + + /** + * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. + * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...". + * + * Unsupported predicates are skipped. + */ + def convertFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} "$v"""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s""""$v" ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + override def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression]): Seq[Partition] = { + + // Hive getPartitionsByFilter() takes a string that represents partition + // predicates like "str_key=\"value\" and int_key=1 ..." + val filter = convertFilters(table, predicates) + val partitions = + if (filter.isEmpty) { + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] + } else { + logDebug(s"Hive metastore filter is '$filter'.") + getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] + } + + partitions.asScala.toSeq + } + + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = + getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] + + override def getDriverResults(driver: Driver): Seq[String] = { + val res = new JArrayList[Object]() + getDriverResultsMethod.invoke(driver, res) + res.asScala.map { r => + r match { + case s: String => s + case a: Array[Object] => a(0).asInstanceOf[String] + } + } + } + +} + +private[client] class Shim_v0_14 extends Shim_v0_13 { + + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val getTimeVarMethod = + findMethod( + classOf[HiveConf], + "getTimeVar", + classOf[HiveConf.ConfVars], + classOf[TimeUnit]) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + getTimeVarMethod.invoke( + conf, + HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, + TimeUnit.MILLISECONDS).asInstanceOf[Long] + } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + +} + +private[client] class Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0L: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala new file mode 100644 index 000000000000..785603750841 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.File +import java.lang.reflect.InvocationTargetException +import java.net.{URL, URLClassLoader} +import java.util + +import scala.language.reflectiveCalls +import scala.util.Try + +import org.apache.commons.io.{FileUtils, IOUtils} + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkSubmitUtils +import org.apache.spark.util.Utils + +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.HiveContext + +/** Factory for `IsolatedClientLoader` with specific versions of hive. */ +private[hive] object IsolatedClientLoader { + /** + * Creates isolated Hive client loaders by downloading the requested version from maven. + */ + def forVersion( + version: String, + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { + val resolvedVersion = hiveVersion(version) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) + new IsolatedClientLoader( + version = hiveVersion(version), + execJars = files, + config = config, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) + } + + def hiveVersion(version: String): HiveVersion = version match { + case "12" | "0.12" | "0.12.0" => hive.v12 + case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 + case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 + } + + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + val hiveArtifacts = version.extraDeps ++ + Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") + .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ + Seq("com.google.guava:guava:14.0.1", + "org.apache.hadoop:hadoop-client:2.4.0") + + val classpath = quietly { + SparkSubmitUtils.resolveMavenCoordinates( + hiveArtifacts.mkString(","), + Some("http://www.datanucleus.org/downloads/maven2"), + ivyPath, + exclusions = version.exclusions) + } + val allFiles = classpath.split(",").map(new File(_)).toSet + + // TODO: Remove copy logic. + val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}") + allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) + tempDir.listFiles().map(_.toURI.toURL) + } + + private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] +} + +/** + * Creates a Hive `ClientInterface` using a classloader that works according to the following rules: + * - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` + * allowing the results of calls to the `ClientInterface` to be visible externally. + * - Hive classes: new instances are loaded from `execJars`. These classes are not + * accessible externally due to their custom loading. + * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`. + * This new instance is able to see a specific version of hive without using reflection. Since + * this is a unique instance, it is not visible externally other than as a generic + * `ClientInterface`, unless `isolationOn` is set to `false`. + * + * @param version The version of hive on the classpath. used to pick specific function signatures + * that are not compatible across versions. + * @param execJars A collection of jar files that must include hive and hadoop. + * @param config A set of options that will be added to the HiveConf of the constructed client. + * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be + * true unless loading the version of hive that is on Sparks classloader. + * @param rootClassLoader The system root classloader. Must not know about Hive classes. + * @param baseClassLoader The spark classloader that is used to load shared classes. + */ +private[hive] class IsolatedClientLoader( + val version: HiveVersion, + val execJars: Seq[URL] = Seq.empty, + val config: Map[String, String] = Map.empty, + val isolationOn: Boolean = true, + val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, + val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, + val sharedPrefixes: Seq[String] = Seq.empty, + val barrierPrefixes: Seq[String] = Seq.empty) + extends Logging { + + // Check to make sure that the root classloader does not know about Hive. + assert(Try(rootClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) + + /** All jars used by the hive specific classloader. */ + protected def allJars = execJars.toArray + + protected def isSharedClass(name: String): Boolean = + name.contains("slf4j") || + name.contains("log4j") || + name.startsWith("org.apache.spark.") || + name.startsWith("scala.") || + (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || + name.startsWith("java.lang.") || + name.startsWith("java.net") || + sharedPrefixes.exists(name.startsWith) + + /** True if `name` refers to a spark class that must see specific version of Hive. */ + protected def isBarrierClass(name: String): Boolean = + name.startsWith(classOf[ClientWrapper].getName) || + name.startsWith(classOf[Shim].getName) || + barrierPrefixes.exists(name.startsWith) + + protected def classToPath(name: String): String = + name.replaceAll("\\.", "/") + ".class" + + /** The classloader that is used to load an isolated version of Hive. */ + protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + val loaded = findLoadedClass(name) + if (loaded == null) doLoadClass(name, resolve) else loaded + } + + def doLoadClass(name: String, resolve: Boolean): Class[_] = { + val classFileName = name.replaceAll("\\.", "/") + ".class" + if (isBarrierClass(name) && isolationOn) { + // For barrier classes, we construct a new copy of the class. + val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) + logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") + defineClass(name, bytes, 0, bytes.length) + } else if (!isSharedClass(name)) { + logDebug(s"hive class: $name - ${getResource(classToPath(name))}") + super.loadClass(name, resolve) + } else { + // For shared classes, we delegate to baseClassLoader. + logDebug(s"shared class: $name") + baseClassLoader.loadClass(name) + } + } + } + + // Pre-reflective instantiation setup. + logDebug("Initializing the logger to avoid disaster...") + Thread.currentThread.setContextClassLoader(classLoader) + + /** The isolated client interface to Hive. */ + val client: ClientInterface = try { + classLoader + .loadClass(classOf[ClientWrapper].getName) + .getConstructors.head + .newInstance(version, config, classLoader) + .asInstanceOf[ClientInterface] + } catch { + case e: InvocationTargetException => + if (e.getCause().isInstanceOf[NoClassDefFoundError]) { + val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] + throw new ClassNotFoundException( + s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + + "Please make sure that jars for your version of hive and hadoop are included in the " + + s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + } else { + throw e + } + } finally { + Thread.currentThread.setContextClassLoader(baseClassLoader) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala new file mode 100644 index 000000000000..b1b8439efa01 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +/** Support for interacting with different versions of the HiveMetastoreClient */ +package object client { + private[client] abstract class HiveVersion( + val fullVersion: String, + val extraDeps: Seq[String] = Nil, + val exclusions: Seq[String] = Nil) + + // scalastyle:off + private[hive] object hive { + case object v12 extends HiveVersion("0.12.0") + case object v13 extends HiveVersion("0.13.1") + + // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in + // maven central anymore, so override those with a version that exists. + // + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. + case object v14 extends HiveVersion("0.14.0", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + "org.apache.calcite:calcite-avatica:1.3.0-incubating"), + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. + case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.1", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + } + // scalastyle:on + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index a547babcebff..8422287e177e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,40 +17,59 @@ package org.apache.spark.sql.hive.execution -import org.apache.hadoop.hive.ql.plan.CreateTableDesc - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** - * :: Experimental :: * Create table and insert the query result into it. - * @param database the database name of the new relation - * @param tableName the table name of the new relation + * @param tableDesc the Table Describe, which may contains serde, storage handler etc. * @param query the query whose result will be insert into the new relation * @param allowExisting allow continue working if it's already exists, otherwise * raise exception - * @param desc the CreateTableDesc, which may contains serde, storage handler etc. - */ -@Experimental +private[hive] case class CreateTableAsSelect( - database: String, - tableName: String, + tableDesc: HiveTable, query: LogicalPlan, - allowExisting: Boolean, - desc: Option[CreateTableDesc]) extends RunnableCommand { + allowExisting: Boolean) + extends RunnableCommand { + + def database: String = tableDesc.database + def tableName: String = tableDesc.name - override def run(sqlContext: SQLContext) = { + override def children: Seq[LogicalPlan] = Seq(query) + + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { - // Create Hive Table - hiveContext.catalog.createTable(database, tableName, query.output, allowExisting, desc) + import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + import org.apache.hadoop.io.Text + import org.apache.hadoop.mapred.TextInputFormat + + val withFormat = + tableDesc.copy( + inputFormat = + tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), + outputFormat = + tableDesc.outputFormat + .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)), + serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName()))) + + val withSchema = if (withFormat.schema.isEmpty) { + // Hive doesn't support specifying the column list for target table in CTAS + // However we don't think SparkSQL should follow that. + tableDesc.copy(schema = + query.output.map(c => + HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null))) + } else { + withFormat + } + + hiveContext.catalog.client.createTable(withSchema) // Get the Metastore Relation hiveContext.catalog.lookupRelation(Seq(database, tableName), None) match { @@ -64,17 +83,16 @@ case class CreateTableAsSelect( if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { - throw - new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") + throw new AnalysisException(s"$database.$tableName already exists.") } } else { - hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } Seq.empty[Row] } override def argString: String = { - s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index bfacc51ef57a..441b6b6033e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -17,34 +17,30 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} -import org.apache.spark.sql.execution.{SparkPlan, RunnableCommand} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} -import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". - * - * :: DeveloperApi :: */ -@DeveloperApi +private[hive] case class DescribeHiveTableCommand( table: MetastoreRelation, override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala results ++= columns.map(field => (field.getName, field.getType, field.getComment)) if (partitionColumns.nonEmpty) { val partColumnInfo = @@ -52,7 +48,7 @@ case class DescribeHiveTableCommand( results ++= partColumnInfo ++ Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ partColumnInfo } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 781a2e9164c8..41b645b2c9c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -1,38 +1,34 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.types.StringType - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class HiveNativeCommand(sql: String) extends RunnableCommand { - - override def output = - Seq(AttributeReference("result", StringType, nullable = false)()) - - override def run(sqlContext: SQLContext) = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{Row, SQLContext} + +private[hive] +case class HiveNativeCommand(sql: String) extends RunnableCommand { + + override def output: Seq[AttributeReference] = + Seq(AttributeReference("result", StringType, nullable = false)()) + + override def run(sqlContext: SQLContext): Seq[Row] = + sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index b56175fe7637..806d2b9b0b7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} @@ -26,25 +26,25 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.{BooleanType, DataType} /** - * :: DeveloperApi :: * The Hive table scan operator. Column and partition pruning are both handled. * * @param requestedAttributes Attributes to be fetched from the Hive table. * @param relation The Hive table be be scanned. * @param partitionPruningPred An optional partition pruning predicate for partitioned table. */ -@DeveloperApi +private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, - partitionPruningPred: Option[Expression])( + partitionPruningPred: Seq[Expression])( @transient val context: HiveContext) extends LeafNode { @@ -56,7 +56,7 @@ case class HiveTableScan( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private[this] val boundPruningPred = partitionPruningPred.map { pred => + private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -64,7 +64,7 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } - // Create a local copy of hiveconf,so that scan specific modifications should not impact + // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) @@ -73,7 +73,7 @@ case class HiveTableScan( addColumnMetadataToConf(hiveExtraConf) @transient - private[this] val hadoopReader = + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { @@ -98,7 +98,7 @@ case class HiveTableScan( .asInstanceOf[StructObjectInspector] val columnTypeNames = structOI - .getAllStructFieldRefs + .getAllStructFieldRefs.asScala .map(_.getFieldObjectInspector) .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) .mkString(",") @@ -118,23 +118,23 @@ case class HiveTableScan( case None => partitions case Some(shouldKeep) => partitions.filter { part => val dataTypes = relation.partitionKeys.map(_.dataType) - val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { - castFromString(value, dataType) - } + val castedValues = part.getValues.asScala.zip(dataTypes) + .map { case (value, dataType) => castFromString(value, dataType) } // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. - val row = new GenericRow(castedValues.toArray) + val row = InternalRow.fromSeq(castedValues) shouldKeep.eval(row).asInstanceOf[Boolean] } } } - override def execute() = if (!relation.hiveQlTable.isPartitioned) { + protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) } - override def output = attributes + override def output: Seq[Attribute] = attributes } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 91af35f0965c..58f7fa640e8a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,42 +19,40 @@ package org.apache.spark.sql.hive.execution import java.util -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import org.apache.spark.sql.types.DataType +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.util.SerializableJobConf -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, partition: Map[String, Option[String]], child: SparkPlan, - overwrite: Boolean) extends UnaryNode with HiveInspectors { + overwrite: Boolean, + ifNotExists: Boolean) extends UnaryNode with HiveInspectors { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -62,13 +60,13 @@ case class InsertIntoHiveTable( serializer } - def output = child.output + def output: Seq[Attribute] = Seq.empty def saveAsHiveFile( - rdd: RDD[Row], + rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: SerializableWritable[JobConf], + conf: SerializableJobConf, writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -76,7 +74,6 @@ case class InsertIntoHiveTable( val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName assert(outputFileFormatClassName != null, "Output format class not set") conf.value.set("mapred.output.format.class", outputFileFormatClassName) - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( conf.value, @@ -88,7 +85,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -96,8 +93,10 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val wrappers = fieldOIs.map(wrapperFor) + val fieldOIs = standardOI.getAllStructFieldRefs.asScala + .map(_.getFieldObjectInspector).toArray + val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray + val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} val outputData = new Array[Any](fieldOIs.length) writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) @@ -105,12 +104,12 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) i += 1 } writerContainer - .getLocalFileWriter(row) + .getLocalFileWriter(row, table.schema) .write(serializer.serialize(outputData, standardOI)) } @@ -130,7 +129,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) @@ -177,7 +176,7 @@ case class InsertIntoHiveTable( } val jobConf = new JobConf(sc.hiveconf) - val jobConfSer = new SerializableWritable(jobConf) + val jobConfSer = new SerializableJobConf(jobConf) val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) @@ -198,41 +197,51 @@ case class InsertIntoHiveTable( if (partition.nonEmpty) { // loadPartition call orders directories created on the iteration order of the this map - val orderedPartitionSpec = new util.LinkedHashMap[String,String]() - table.hiveQlTable.getPartCols().foreach{ - entry=> - orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + val orderedPartitionSpec = new util.LinkedHashMap[String, String]() + table.hiveQlTable.getPartCols.asScala.foreach { entry => + orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } - val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath.toString, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + // scalastyle:off + // ifNotExists is only valid with static partition, refer to + // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries + // scalastyle:on + val oldPart = + catalog.client.getPartitionOption( + catalog.client.getTable(table.databaseName, table.tableName), + partitionSpec.asJava) + + if (oldPart.isEmpty || !ifNotExists) { + catalog.client.loadPartition( + outputPath.toString, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } } else { - db.loadTable( - outputPath, + catalog.client.loadTable( + outputPath.toString, // TODO: URI qualifiedTableName, overwrite, holdDDLTime) @@ -250,5 +259,7 @@ case class InsertIntoHiveTable( override def executeCollect(): Array[Row] = sideEffectResult.toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + protected override def doExecute(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index c54fbb6e2469..c7651daffe36 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,40 +17,37 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, InputStreamReader} -import java.io.{DataInputStream, DataOutputStream, EOFException} +import java.io._ import java.util.Properties +import javax.annotation.Nullable + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe -import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.io.Writable -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.util.Utils - - -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.{Logging, TaskContext} /** - * :: DeveloperApi :: * Transforms the input by forking and running the specified script. * * @param input the set of expression that should be passed to the script. * @param script the command that should be executed. * @param output the attributes that are produced by the script. */ -@DeveloperApi +private[hive] case class ScriptTransformation( input: Seq[Expression], script: String, @@ -59,115 +56,225 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) extends UnaryNode { - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs: Seq[HiveContext] = sc :: Nil - def execute() = { - child.execute().mapPartitions { iter => + protected override def doExecute(): RDD[InternalRow] = { + def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd) + val builder = new ProcessBuilder(cmd.asJava) + val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream - val reader = new BufferedReader(new InputStreamReader(inputStream)) - - val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) + val errorStream = proc.getErrorStream - val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { - var cacheRow: Row = null + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = new ScriptTransformationWriterThread( + inputIterator, + input.map(_.dataType), + outputProjection, + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get() + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } + + val reader = new BufferedReader(new InputStreamReader(inputStream)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null - var eof: Boolean = false + val scriptOutputStream = new DataInputStream(inputStream) + var scriptOutputWritable: Writable = null + val reusedWritableObject: Writable = if (null != outputSerde) { + outputSerde.getSerializedClass().newInstance + } else { + null + } + val mutableRow = new SpecificMutableRow(output.map(_.dataType)) override def hasNext: Boolean = { if (outputSerde == null) { if (curLine == null) { curLine = reader.readLine() - curLine != null + if (curLine == null) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } else { true } + } else if (scriptOutputWritable == null) { + scriptOutputWritable = reusedWritableObject + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } else { - !eof - } - } - - def deserialize(): Row = { - if (cacheRow != null) return cacheRow - - val mutableRow = new SpecificMutableRow(output.map(_.dataType)) - try { - val dataInputStream = new DataInputStream(inputStream) - val writable = outputSerde.getSerializedClass().newInstance - writable.readFields(dataInputStream) - - val raw = outputSerde.deserialize(writable) - val dataList = outputSoi.getStructFieldsDataAsList(raw) - val fieldList = outputSoi.getAllStructFieldRefs() - - var i = 0 - dataList.foreach( element => { - if (element == null) { - mutableRow.setNullAt(i) - } else { - mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector) - } - i += 1 - }) - return mutableRow - } catch { - case e: EOFException => - eof = true - return null + true } } - override def next(): Row = { + override def next(): InternalRow = { if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericRow( + new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .asInstanceOf[Array[Any]]) + .map(CatalystTypeConverters.convertToCatalyst)) } else { - new GenericRow( + new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .asInstanceOf[Array[Any]]) + .map(CatalystTypeConverters.convertToCatalyst)) } } else { - val ret = deserialize() - if (!eof) { - cacheRow = null - cacheRow = deserialize() + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + val fieldList = outputSoi.getAllStructFieldRefs() + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector) + } + i += 1 } - ret + mutableRow } } } - val (inputSerde, inputSoi) = ioschema.initInputSerDe(input) - val dataOutputStream = new DataOutputStream(outputStream) - val outputProjection = new InterpretedProjection(input, child.output) + writerThread.start() + + outputIterator + } + + child.execute().mapPartitions { iter => + if (iter.hasNext) { + processIterator(iter) + } else { + // If the input iterator has no rows then do not launch the external script. + Iterator.empty + } + } + } +} + +private class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + outputProjection: Projection, + @Nullable inputSerde: AbstractSerDe, + @Nullable inputSoi: ObjectInspector, + ioschema: HiveScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext + ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + + setDaemon(true) + + @volatile private var _exception: Throwable = null + + /** Contains the exception thrown while writing the parent iterator to the external process. */ + def exception: Option[Throwable] = Option(_exception) - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) + override def run(): Unit = Utils.logUncaughtExceptions { + TaskContext.setTaskContext(taskContext) + + val dataOutputStream = new DataOutputStream(outputStream) + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + val len = inputSchema.length + try { + iter.map(outputProjection).foreach { row => + if (inputSerde == null) { + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() } + outputStream.write(data.getBytes("utf-8")) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) } + } outputStream.close() - iterator + threwException = false + } catch { + case NonFatal(e) => + // An error occurred while writing input, so kill the child process. According to the + // Javadoc this call will not throw an exception: + _exception = e + proc.destroy() + throw e + } finally { + try { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } } } } @@ -175,96 +282,69 @@ case class ScriptTransformation( /** * The wrapper class of Hive input and output schema properties */ +private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], - inputSerdeClass: String, - outputSerdeClass: String, + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { - val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n")) + private val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - - def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) - (serde, initInputSoi(serde, columns, columnTypes)) - } - def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = { - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps) - (serde, initOutputputSoi(serde)) + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) + .asInstanceOf[ObjectInspector] + (serde, objectInspector) + } } - def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - - val columns = attrs.map { - case aref: AttributeReference => aref.name - case e: NamedExpression => e.name - case _ => null - } - - val columnTypes = attrs.map { - case aref: AttributeReference => aref.dataType - case e: NamedExpression => e.dataType - case _ => null + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) } + } + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) (columns, columnTypes) } - - def initSerDe(serdeClassName: String, columns: Seq[String], - columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { - - val serde: AbstractSerDe = if (serdeClassName != "") { - val trimed_class = serdeClassName.split("'")(1) - Utils.classForName(trimed_class) - .newInstance.asInstanceOf[AbstractSerDe] - } else { - null - } - if (serde != null) { - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - - val properties = new Properties() - properties.putAll(propsMap) - serde.initialize(null, properties) - } + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { - serde - } + val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] - def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType]) - : ObjectInspector = { + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - if (inputSerde != null) { - val fieldObjectInspectors = columnTypes.map(toInspector(_)) - ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) - .asInstanceOf[ObjectInspector] - } else { - null - } - } - - def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { - if (outputSerde != null) { - outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] - } else { - null - } + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + + val properties = new Properties() + properties.putAll(propsMap.asJava) + serde.initialize(null, properties) + + serde } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 95dcaccefdc5..d1699dd53681 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,53 +17,57 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.sources.ResolvedDataSource -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser} +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. * * Right now, it only supports Hive tables and it only updates the size of a Hive table * in the Hive metastore. */ -@DeveloperApi +private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.asInstanceOf[HiveContext].analyze(tableName) Seq.empty[Row] } } /** - * :: DeveloperApi :: * Drops a table from the metastore and removes it if it is cached. */ -@DeveloperApi +private[hive] case class DropTable( tableName: String, ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - // Other exceptions can be caused by users providing wrong parameters in OPTIONS + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => + // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an exception. - case e: Exception => log.warn(s"${e.getMessage}") + // Users should be able to drop such kinds of tables regardless if there is an error. + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") @@ -72,27 +76,45 @@ case class DropTable( } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class AddJar(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] + val currentClassLoader = Utils.getContextOrSparkClassLoader + + // Add jar to current context + val jarURL = new java.io.File(path).toURI.toURL + val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader) + Thread.currentThread.setContextClassLoader(newClassLoader) + // We need to explicitly set the class loader associated with the conf in executionHive's + // state because this class loader will be used as the context class loader of the current + // thread to execute any Hive command. + // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get() + // returns the value of a thread local variable and its HiveConf may not be the HiveConf + // associated with `executionHive.state` (for example, HiveContext is created in one thread + // and then add jar is called from another thread). + hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader) + // Add jar to isolated hive (metadataHive) class loader. hiveContext.runSqlHive(s"ADD JAR $path") + + // Add jar to executors hiveContext.sparkContext.addJar(path) - Seq.empty[Row] + + Seq(Row(0)) } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD FILE $path") hiveContext.sparkContext.addFile(path) @@ -100,36 +122,55 @@ case class AddFile(path: String) extends RunnableCommand { } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). +private[hive] case class CreateMetastoreDataSource( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], - allowExisting: Boolean) extends RunnableCommand { + allowExisting: Boolean, + managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (hiveContext.catalog.tableExists(tableIdent.toSeq)) { if (allowExisting) { return Seq.empty[Row] } else { - sys.error(s"Table $tableName already exists.") + throw new AnalysisException(s"Table $tableName already exists.") } } var isExternal = true val optionsWithPath = - if (!options.contains("path")) { + if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, userSpecifiedSchema, + Array.empty[String], provider, optionsWithPath, isExternal) @@ -138,44 +179,120 @@ case class CreateMetastoreDataSource( } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). +private[hive] case class CreateMetastoreDataSourceAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, + partitionColumns: Array[String], + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - - if (hiveContext.catalog.tableExists(tableName :: Nil)) { - if (allowExisting) { - return Seq.empty[Row] - } else { - sys.error(s"Table $tableName already exists.") - } + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") } - val df = DataFrame(hiveContext, query) + val tableName = tableIdent.unquotedString + val hiveContext = sqlContext.asInstanceOf[HiveContext] + var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } + var existingSchema = None: Option[StructType] + if (sqlContext.catalog.tableExists(tableIdent.toSeq)) { + // Check if we need to throw an exception or just return. + mode match { + case SaveMode.ErrorIfExists => + throw new AnalysisException(s"Table $tableName already exists. " + + s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " + + s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" + + s"the existing data. " + + s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.") + case SaveMode.Ignore => + // Since the table already exists and the save mode is Ignore, we will just return. + return Seq.empty[Row] + case SaveMode.Append => + // Check if the specified data source match the data source of the existing table. + val resolved = ResolvedDataSource( + sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) + val createdRelation = LogicalRelation(resolved.relation) + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match { + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => + if (l.relation != createdRelation.relation) { + val errorDescription = + s"Cannot append to table $tableName because the resolved relation does not " + + s"match the existing relation of $tableName. " + + s"You can use insertInto($tableName, false) to append this DataFrame to the " + + s"table $tableName and using its data source and options." + val errorMessage = + s""" + |$errorDescription + |== Relations == + |${sideBySide( + s"== Expected Relation ==" :: l.toString :: Nil, + s"== Actual Relation ==" :: createdRelation.toString :: Nil + ).mkString("\n")} + """.stripMargin + throw new AnalysisException(errorMessage) + } + existingSchema = Some(l.schema) + case o => + throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") + } + case SaveMode.Overwrite => + hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + // Need to create the table again. + createMetastoreTable = true + } + } else { + // The table does not exist. We need to create it in metastore. + createMetastoreTable = true + } + + val data = DataFrame(hiveContext, query) + val df = existingSchema match { + // If we are inserting into an existing table, just use the existing schema. + case Some(schema) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema) + case None => data + } + // Create the relation based on the data of df. - ResolvedDataSource(sqlContext, provider, optionsWithPath, df) + val resolved = + ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) - hiveContext.catalog.createDataSourceTable( - tableName, - None, - provider, - optionsWithPath, - isExternal) + if (createMetastoreTable) { + // We will use the schema of resolved.relation as the schema of the table (instead of + // the schema of df). It is important since the nullability may be changed by the relation + // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). + hiveContext.catalog.createDataSourceTable( + tableIdent, + Some(resolved.relation.schema), + partitionColumns, + provider, + optionsWithPath, + isExternal) + } + // Refresh the cache of the table in the catalog. + hiveContext.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala new file mode 100644 index 000000000000..cad02373e5ba --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -0,0 +1,632 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory +import org.apache.hadoop.hive.ql.exec._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper + +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.types._ + + +private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) + extends analysis.FunctionRegistry with HiveInspectors { + + def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + Try(underlying.lookupFunction(name, children)).getOrElse { + // We only look it up to see if it exists, but do not include it in the HiveUDF since it is + // not always serializable. + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + throw new AnalysisException(s"undefined function $name")) + + val functionClassName = functionInfo.getFunctionClass.getName + + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + } else { + sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + } + } + } + + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = underlying.registerFunction(name, info, builder) + + /* List all of the registered function names. */ + override def listFunction(): Seq[String] = { + (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted + } + + /* Get the class of the registered function by specified name. */ + override def lookupFunction(name: String): Option[ExpressionInfo] = { + underlying.lookupFunction(name).orElse( + Try { + val info = FunctionRegistry.getFunctionInfo(name) + val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) + if (annotation != null) { + Some(new ExpressionInfo( + info.getFunctionClass.getCanonicalName, + annotation.name(), + annotation.value(), + annotation.extended())) + } else { + None + } + }.getOrElse(None)) + } +} + +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with CodegenFallback with Logging { + + override def deterministic: Boolean = isUDFDeterministic + + override def nullable: Boolean = true + + @transient + lazy val function = funcWrapper.createFunction[UDF]() + + @transient + private lazy val method = + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) + + @transient + private lazy val arguments = children.map(toInspector).toArray + + @transient + private lazy val isUDFDeterministic = { + val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() + } + + override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) + + // Create parameter converters + @transient + private lazy val conversionHelper = new ConversionHelper(method, arguments) + + val dataType = javaClassToDataType(method.getReturnType) + + @transient + lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( + method.getGenericReturnType(), ObjectInspectorOptions.JAVA) + + @transient + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + + // TODO: Finish input output types. + override def eval(input: InternalRow): Any = { + val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes) + val ret = FunctionRegistry.invoke( + method, + function, + conversionHelper.convertIfNecessary(inputs : _*): _*) + unwrap(ret, returnInspector) + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +// Adapter from Catalyst ExpressionResult to Hive DeferredObject +private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) + extends DeferredObject with HiveInspectors { + + private var func: () => Any = _ + def set(func: () => Any): Unit = { + this.func = func + } + override def prepare(i: Int): Unit = {} + override def get(): AnyRef = wrap(func(), oi, dataType) +} + +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with CodegenFallback with Logging { + + override def nullable: Boolean = true + + override def deterministic: Boolean = isUDFDeterministic + + override def foldable: Boolean = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] + + @transient + lazy val function = funcWrapper.createFunction[GenericUDF]() + + @transient + private lazy val argumentInspectors = children.map(toInspector) + + @transient + private lazy val returnInspector = { + function.initializeAndFoldConstants(argumentInspectors.toArray) + } + + @transient + private lazy val isUDFDeterministic = { + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() + } + + @transient + private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + new DeferredObjectAdapter(inspect, child.dataType) + }.toArray[DeferredObject] + + lazy val dataType: DataType = inspectorToDataType(returnInspector) + + override def eval(input: InternalRow): Any = { + returnInspector // Make sure initialized. + + var i = 0 + while (i < children.length) { + val idx = i + deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( + () => { + children(idx).eval(input) + }) + i += 1 + } + unwrap(function.evaluate(deferedObjects), returnInspector) + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +/** + * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. + */ +private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p: LogicalPlan if !p.childrenResolved => p + + // We are resolving WindowExpressions at here. When we get here, we have already + // replaced those WindowSpecReferences. + case p: LogicalPlan => + p transformExpressions { + case WindowExpression( + UnresolvedWindowFunction(name, children), + windowSpec: WindowSpecDefinition) => + // First, let's find the window function info. + val windowFunctionInfo: WindowFunctionInfo = + Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( + throw new AnalysisException(s"Couldn't find window function $name")) + + // Get the class of this function. + // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use + // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. + val functionClass = windowFunctionInfo.getFunctionClass() + val newChildren = + // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit + // input parameters and requires implicit parameters, which + // are expressions in Order By clause. + if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { + if (children.nonEmpty) { + throw new AnalysisException(s"$name does not take input parameters.") + } + windowSpec.orderSpec.map(_.child) + } else { + children + } + + // If the class is UDAF, we need to use UDAFBridge. + val isUDAFBridgeRequired = + if (classOf[UDAF].isAssignableFrom(functionClass)) { + true + } else { + false + } + + // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of + // HiveWindowFunction. + val windowFunction = + HiveWindowFunction( + new HiveFunctionWrapper(functionClass.getName), + windowFunctionInfo.isPivotResult, + isUDAFBridgeRequired, + newChildren) + + // Second, check if the specified window function can accept window definition. + windowSpec.frameSpecification match { + case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => + // This Hive window function does not support user-speficied window frame. + throw new AnalysisException( + s"Window function $name does not take a frame specification.") + case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && + windowFunctionInfo.isPivotResult => + // These two should not be true at the same time when a window frame is defined. + // If so, throw an exception. + throw new AnalysisException(s"Could not handle Hive window function $name because " + + s"it supports both a user specified window frame and pivot result.") + case _ => // OK + } + // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs + // a window frame specification to work. + val newWindowSpec = windowSpec.frameSpecification match { + case UnspecifiedFrame => + val newWindowFrame = + SpecifiedWindowFrame.defaultWindowFrame( + windowSpec.orderSpec.nonEmpty, + windowFunctionInfo.isSupportsWindow) + WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) + case _ => windowSpec + } + + // Finally, we create a WindowExpression with the resolved window function and + // specified window spec. + WindowExpression(windowFunction, newWindowSpec) + } + } +} + +/** + * A [[WindowFunction]] implementation wrapping Hive's window function. + * @param funcWrapper The wrapper for the Hive Window Function. + * @param pivotResult If it is true, the Hive function will return a list of values representing + * the values of the added columns. Otherwise, a single value is returned for + * current row. + * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's + * createFunction is UDAF, we need to use GenericUDAFBridge to wrap + * it as a GenericUDAFResolver2. + * @param children Input parameters. + */ +private[hive] case class HiveWindowFunction( + funcWrapper: HiveFunctionWrapper, + pivotResult: Boolean, + isUDAFBridgeRequired: Boolean, + children: Seq[Expression]) extends WindowFunction + with HiveInspectors with Unevaluable { + + // Hive window functions are based on GenericUDAFResolver2. + type UDFType = GenericUDAFResolver2 + + @transient + protected lazy val resolver: GenericUDAFResolver2 = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) + } else { + funcWrapper.createFunction[GenericUDAFResolver2]() + } + + @transient + protected lazy val inputInspectors = children.map(toInspector).toArray + + // The GenericUDAFEvaluator used to evaluate the window function. + @transient + protected lazy val evaluator: GenericUDAFEvaluator = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) + } + + // The object inspector of values returned from the Hive window function. + @transient + protected lazy val returnInspector = { + evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) + } + + override def dataType: DataType = + if (!pivotResult) { + inspectorToDataType(returnInspector) + } else { + // If pivotResult is true, we should take the element type out as the data type of this + // function. + inspectorToDataType(returnInspector) match { + case ArrayType(dt, _) => dt + case _ => + sys.error( + s"error resolve the data type of window function ${funcWrapper.functionClassName}") + } + } + + override def nullable: Boolean = true + + @transient + lazy val inputProjection = new InterpretedProjection(children) + + @transient + private var hiveEvaluatorBuffer: AggregationBuffer = _ + // Output buffer. + private var outputBuffer: Any = _ + + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + + override def init(): Unit = { + evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) + } + + // Reset the hiveEvaluatorBuffer and outputPosition + override def reset(): Unit = { + // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. + // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. + // However, RowNumberBuffer.init does not really reset this buffer. + hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer + evaluator.reset(hiveEvaluatorBuffer) + } + + override def prepareInputParameters(input: InternalRow): AnyRef = { + wrap( + inputProjection(input), + inputInspectors, + new Array[AnyRef](children.length), + inputDataTypes) + } + + // Add input parameters for a single row. + override def update(input: AnyRef): Unit = { + evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) + } + + override def batchUpdate(inputs: Array[AnyRef]): Unit = { + var i = 0 + while (i < inputs.length) { + evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) + i += 1 + } + } + + override def evaluate(): Unit = { + outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) + } + + override def get(index: Int): Any = { + if (!pivotResult) { + // if pivotResult is false, we will get a single value for all rows in the frame. + outputBuffer + } else { + // if pivotResult is true, we will get a ArrayData having the same size with the size + // of the window frame. At here, we will return the result at the position of + // index in the output buffer. + outputBuffer.asInstanceOf[ArrayData].get(index, dataType) + } + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } + + override def newInstance(): WindowFunction = + new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) +} + +private[hive] case class HiveGenericUDAF( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression]) extends AggregateExpression1 + with HiveInspectors { + + type UDFType = AbstractGenericUDAFResolver + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() + + @transient + protected lazy val objectInspector = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) + resolver.getEvaluator(parameterInfo) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } + + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) +} + +/** It is used as a wrapper for the hive functions which uses UDAF interface */ +private[hive] case class HiveUDAF( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression]) extends AggregateExpression1 + with HiveInspectors { + + type UDFType = UDAF + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = + new GenericUDAFBridge(funcWrapper.createFunction()) + + @transient + protected lazy val objectInspector = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) + resolver.getEvaluator(parameterInfo) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } + + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) +} + +/** + * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a + * [[Generator]]. Note that the semantics of Generators do not allow + * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning + * dependent operations like calls to `close()` before producing output will not operate the same as + * in Hive. However, in practice this should not affect compatibility for most sane UDTFs + * (e.g. explode or GenericUDTFParseUrlTuple). + * + * Operators that require maintaining state in between input rows should instead be implemented as + * user defined aggregations, which have clean semantics even in a partitioned execution. + */ +private[hive] case class HiveGenericUDTF( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression]) + extends Generator with HiveInspectors with CodegenFallback { + + @transient + protected lazy val function: GenericUDTF = { + val fun: GenericUDTF = funcWrapper.createFunction() + fun.setCollector(collector) + fun + } + + @transient + protected lazy val inputInspectors = children.map(toInspector) + + @transient + protected lazy val outputInspector = function.initialize(inputInspectors.toArray) + + @transient + protected lazy val udtInput = new Array[AnyRef](children.length) + + @transient + protected lazy val collector = new UDTFCollector + + lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true) + } + + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + outputInspector // Make sure initialized. + + val inputProjection = new InterpretedProjection(children) + + function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) + collector.collectRows() + } + + protected class UDTFCollector extends Collector { + var collected = new ArrayBuffer[InternalRow] + + override def collect(input: java.lang.Object) { + // We need to clone the input here because implementations of + // GenericUDTF reuse the same object. Luckily they are always an array, so + // it is easy to clone. + collected += unwrap(input, outputInspector).asInstanceOf[InternalRow] + } + + def collectRows(): Seq[InternalRow] = { + val toCollect = collected + collected = new ArrayBuffer[InternalRow] + toCollect + } + } + + override def terminate(): TraversableOnce[InternalRow] = { + outputInspector // Make sure initialized. + function.close() + collector.collectRows() + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +private[hive] case class HiveUDAFFunction( + funcWrapper: HiveFunctionWrapper, + exprs: Seq[Expression], + base: AggregateExpression1, + isUDAFBridgeRequired: Boolean = false) + extends AggregateFunction1 + with HiveInspectors { + + def this() = this(null, null, null) + + private val resolver = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) + } else { + funcWrapper.createFunction[AbstractGenericUDAFResolver]() + } + + private val inspectors = exprs.map(toInspector).toArray + + private val function = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) + resolver.getEvaluator(parameterInfo) + } + + private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + + private val buffer = + function.getNewAggregationBuffer + + override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector) + + @transient + val inputProjection = new InterpretedProjection(exprs) + + @transient + protected lazy val cached = new Array[AnyRef](exprs.length) + + @transient + private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray + + def update(input: InternalRow): Unit = { + val inputs = inputProjection(input) + function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) + } +} + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala deleted file mode 100644 index 76d214037219..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper - -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory -import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils.getContextOrSparkClassLoader - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - -private[hive] abstract class HiveFunctionRegistry - extends analysis.FunctionRegistry with HiveInspectors { - - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) - - def lookupFunction(name: String, children: Seq[Expression]): Expression = { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) - } else { - sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") - } - } -} - -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { - type EvaluatedType = Any - type UDFType = UDF - - def nullable = true - - @transient - lazy val function = funcWrapper.createFunction[UDFType]() - - @transient - protected lazy val method = - function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) - - @transient - protected lazy val arguments = children.map(toInspector).toArray - - @transient - protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() - } - - override def foldable = isUDFDeterministic && children.forall(_.foldable) - - // Create parameter converters - @transient - protected lazy val conversionHelper = new ConversionHelper(method, arguments) - - @transient - lazy val dataType = javaClassToDataType(method.getReturnType) - - @transient - lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( - method.getGenericReturnType(), ObjectInspectorOptions.JAVA) - - @transient - protected lazy val cached = new Array[AnyRef](children.length) - - // TODO: Finish input output types. - override def eval(input: Row): Any = { - unwrap( - FunctionRegistry.invoke(method, function, conversionHelper - .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), - returnInspector) - } - - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" -} - -// Adapter from Catalyst ExpressionResult to Hive DeferredObject -private[hive] class DeferredObjectAdapter(oi: ObjectInspector) - extends DeferredObject with HiveInspectors { - private var func: () => Any = _ - def set(func: () => Any) { - this.func = func - } - override def prepare(i: Int) = {} - override def get(): AnyRef = wrap(func(), oi) -} - -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { - type UDFType = GenericUDF - type EvaluatedType = Any - - def nullable = true - - @transient - lazy val function = funcWrapper.createFunction[UDFType]() - - @transient - protected lazy val argumentInspectors = children.map(toInspector) - - @transient - protected lazy val returnInspector = { - function.initializeAndFoldConstants(argumentInspectors.toArray) - } - - @transient - protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - (udfType != null && udfType.deterministic()) - } - - override def foldable = - isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - - @transient - protected lazy val deferedObjects = - argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] - - lazy val dataType: DataType = inspectorToDataType(returnInspector) - - override def eval(input: Row): Any = { - returnInspector // Make sure initialized. - - var i = 0 - while (i < children.length) { - val idx = i - deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( - () => { - children(idx).eval(input) - }) - i += 1 - } - unwrap(function.evaluate(deferedObjects), returnInspector) - } - - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" -} - -private[hive] case class HiveGenericUdaf( - funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression - with HiveInspectors { - - type UDFType = AbstractGenericUDAFResolver - - @transient - protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() - - @transient - protected lazy val objectInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) - resolver.getEvaluator(parameterInfo) - .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) - } - - @transient - protected lazy val inspectors = children.map(toInspector) - - def dataType: DataType = inspectorToDataType(objectInspector) - - def nullable: Boolean = true - - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - - def newInstance() = new HiveUdafFunction(funcWrapper, children, this) -} - -/** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( - funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression - with HiveInspectors { - - type UDFType = UDAF - - @transient - protected lazy val resolver: AbstractGenericUDAFResolver = - new GenericUDAFBridge(funcWrapper.createFunction()) - - @transient - protected lazy val objectInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) - resolver.getEvaluator(parameterInfo) - .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) - } - - @transient - protected lazy val inspectors = children.map(toInspector) - - def dataType: DataType = inspectorToDataType(objectInspector) - - def nullable: Boolean = true - - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - - def newInstance() = - new HiveUdafFunction(funcWrapper, children, this, true) -} - -/** - * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a - * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow - * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning - * dependent operations like calls to `close()` before producing output will not operate the same as - * in Hive. However, in practice this should not affect compatibility for most sane UDTFs - * (e.g. explode or GenericUDTFParseUrlTuple). - * - * Operators that require maintaining state in between input rows should instead be implemented as - * user defined aggregations, which have clean semantics even in a partitioned execution. - */ -private[hive] case class HiveGenericUdtf( - funcWrapper: HiveFunctionWrapper, - aliasNames: Seq[String], - children: Seq[Expression]) - extends Generator with HiveInspectors { - - @transient - protected lazy val function: GenericUDTF = funcWrapper.createFunction() - - @transient - protected lazy val inputInspectors = children.map(toInspector) - - @transient - protected lazy val outputInspector = function.initialize(inputInspectors.toArray) - - @transient - protected lazy val udtInput = new Array[AnyRef](children.length) - - protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map { - field => inspectorToDataType(field.getFieldObjectInspector) - } - - override protected def makeOutput() = { - // Use column names when given, otherwise _c1, _c2, ... _cn. - if (aliasNames.size == outputDataTypes.size) { - aliasNames.zip(outputDataTypes).map { - case (attrName, attrDataType) => - AttributeReference(attrName, attrDataType, nullable = true)() - } - } else { - outputDataTypes.zipWithIndex.map { - case (attrDataType, i) => - AttributeReference(s"_c$i", attrDataType, nullable = true)() - } - } - } - - override def eval(input: Row): TraversableOnce[Row] = { - outputInspector // Make sure initialized. - - val inputProjection = new InterpretedProjection(children) - val collector = new UDTFCollector - function.setCollector(collector) - function.process(wrap(inputProjection(input), inputInspectors, udtInput)) - collector.collectRows() - } - - protected class UDTFCollector extends Collector { - var collected = new ArrayBuffer[Row] - - override def collect(input: java.lang.Object) { - // We need to clone the input here because implementations of - // GenericUDTF reuse the same object. Luckily they are always an array, so - // it is easy to clone. - collected += unwrap(input, outputInspector).asInstanceOf[Row] - } - - def collectRows() = { - val toCollect = collected - collected = new ArrayBuffer[Row] - toCollect - } - } - - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" -} - -private[hive] case class HiveUdafFunction( - funcWrapper: HiveFunctionWrapper, - exprs: Seq[Expression], - base: AggregateExpression, - isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction - with HiveInspectors { - - def this() = this(null, null, null) - - private val resolver = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[AbstractGenericUDAFResolver]() - } - - private val inspectors = exprs.map(toInspector).toArray - - private val function = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - - // Cast required to avoid type inference selecting a deprecated Hive API. - private val buffer = - function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] - - override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) - - @transient - val inputProjection = new InterpretedProjection(exprs) - - @transient - protected lazy val cached = new Array[AnyRef](exprs.length) - - def update(input: Row): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) - } -} - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index aae175e426ad..8dc796b056a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import java.io.IOException import java.text.NumberFormat import java.util.Date @@ -30,12 +29,16 @@ import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} -import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hive OutputFormat. @@ -56,7 +59,7 @@ private[hive] class SparkHiveWriterContainer( PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } - protected val conf = new SerializableWritable(jobConf) + protected val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 @@ -69,7 +72,7 @@ private[hive] class SparkHiveWriterContainer( @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = - conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] def driverSideSetup() { setIDs(0, 0, 0) @@ -92,7 +95,9 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { + writer + } def close() { // Seems the boolean value passed into close does not matter. @@ -117,19 +122,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -178,7 +171,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( - ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ @@ -207,16 +200,31 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else String.valueOf(rawVal) - s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" + override def getLocalFileWriter(row: InternalRow, schema: StructType) + : FileSinkOperator.RecordWriter = { + def convertToHiveRawString(col: String, value: Any): String = { + val raw = String.valueOf(value) + schema(col).dataType match { + case DateType => DateTimeUtils.dateToString(raw.toInt) + case _: DecimalType => BigDecimal(raw).toString() + case _ => raw } - .mkString + } - def newWriter = { + val nonDynamicPartLen = row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString + + def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( fileSinkConf.getDirName + dynamicPartPath, fileSinkConf.getTableInfo, @@ -224,12 +232,11 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) - val path = { - val outputPath = FileOutputFormat.getOutputPath(conf.value) - assert(outputPath != null, "Undefined job output-path") - val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) - new Path(workPath, getOutputName) - } + // use the path like ${hive_tmp}/_temporary/${attemptId}/ + // to avoid write to the same file when `spark.speculation=true` + val path = FileOutputFormat.getTaskOutputPath( + conf.value, + dynamicPartPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, @@ -240,6 +247,6 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( Reporter.NULL) } - writers.getOrElseUpdate(dynamicPartPath, newWriter) + writers.getOrElseUpdate(dynamicPartPath, newWriter()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala new file mode 100644 index 000000000000..0f9a1a6ef3b2 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType + +private[orc] object OrcFileOperator extends Logging { + /** + * Retrieves a ORC file reader from a given path. The path can point to either a directory or a + * single ORC file. If it points to an directory, it picks any non-empty ORC file within that + * directory. + * + * The reader returned by this method is mainly used for two purposes: + * + * 1. Retrieving file metadata (schema and compression codecs, etc.) + * 2. Read the actual file content (in this case, the given path should point to the target file) + * + * @note As recorded by SPARK-8501, ORC writes an empty schema (struct<> + logInfo( + s"ORC file $path has empty schema, it probably contains no rows. " + + "Trying to read another ORC file to figure out the schema.") + false + case _ => true + } + } + + val conf = config.getOrElse(new Configuration) + val fs = { + val hdfsPath = new Path(basePath) + hdfsPath.getFileSystem(conf) + } + + listOrcFiles(basePath, conf).iterator.map { path => + path -> OrcFile.createReader(fs, path) + }.collectFirst { + case (path, reader) if isWithNonEmptySchema(path, reader) => reader + } + } + + def readSchema(path: String, conf: Option[Configuration]): StructType = { + val reader = getFileReader(path, conf).getOrElse { + throw new AnalysisException( + s"Failed to discover schema from ORC files stored in $path. " + + "Probably there are either no ORC files or only empty ORC files.") + } + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $path, got Hive schema string: $schema") + HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] + } + + def getObjectInspector( + path: String, conf: Option[Configuration]): Option[StructObjectInspector] = { + getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) + } + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDir) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + + if (paths == null || paths.isEmpty) { + throw new IllegalArgumentException( + s"orcFileOperator: path $path does not have valid orc files matching the pattern") + } + + paths + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala new file mode 100644 index 000000000000..b3d9f7f71a27 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.serde2.io.DateWritable + +import org.apache.spark.Logging +import org.apache.spark.sql.sources._ + +/** + * It may be optimized by push down partial filters. But we are conservative here. + * Because if some filters fail to be parsed, the tree may be corrupted, + * and cannot be used anymore. + */ +private[orc] object OrcFilters extends Logging { + def createFilter(expr: Array[Filter]): Option[SearchArgument] = { + expr.reduceOption(And).flatMap { conjunction => + val builder = SearchArgumentFactory.newBuilder() + buildSearchArgument(conjunction, builder).map(_.build()) + } + } + + private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { + def newBuilder = SearchArgumentFactory.newBuilder() + + def isSearchableLiteral(value: Any): Boolean = value match { + // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. + case _: String | _: Long | _: Double | _: Byte | _: Short | _: Integer | _: Float => true + case _: DateWritable | _: HiveDecimal | _: HiveChar | _: HiveVarchar => true + case _ => false + } + + // lian: I probably missed something here, and had to end up with a pretty weird double-checking + // pattern when converting `And`/`Or`/`Not` filters. + // + // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, + // and `startNot()` mutate internal state of the builder instance. This forces us to translate + // all convertible filters with a single builder instance. However, before actually converting a + // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible + // filter is found, we may already end up with a builder whose internal state is inconsistent. + // + // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and + // then try to convert its children. Say we convert `left` child successfully, but find that + // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is + // inconsistent now. + // + // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + // children with brand new builders, and only do the actual conversion with the right builder + // instance when the children are proven to be convertible. + // + // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. + // Usage of builder methods mentioned above can only be found in test code, where all tested + // filters are known to be convertible. + + expression match { + case And(left, right) => + val tryLeft = buildSearchArgument(left, newBuilder) + val tryRight = buildSearchArgument(right, newBuilder) + + val conjunction = for { + _ <- tryLeft + _ <- tryRight + lhs <- buildSearchArgument(left, builder.startAnd()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + // For filter `left AND right`, we can still push down `left` even if `right` is not + // convertible, and vice versa. + conjunction + .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) + .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) + + case Or(left, right) => + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) + lhs <- buildSearchArgument(left, builder.startOr()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(child, newBuilder) + negate <- buildSearchArgument(child, builder.startNot()) + } yield negate.end() + + case EqualTo(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.equals(attribute, _)) + + case EqualNullSafe(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.nullSafeEquals(attribute, _)) + + case LessThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThan(attribute, _)) + + case LessThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThanEquals(attribute, _)) + + case GreaterThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThanEquals(attribute, _).end()) + + case GreaterThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThan(attribute, _).end()) + + case IsNull(attribute) => + Some(builder.isNull(attribute)) + + case IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + + case In(attribute, values) => + Option(values) + .filter(_.forall(isSearchableLiteral)) + .map(builder.in(attribute, _)) + + case _ => None + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala new file mode 100644 index 000000000000..1cff5cf9c354 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.util.Properties + +import scala.collection.JavaConverters._ + +import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + +import org.apache.spark.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.sources.{Filter, _} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.SerializableConfiguration + +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "orc" + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + assert( + sqlContext.isInstanceOf[HiveContext], + "The ORC data source can only be used with HiveContext.") + + new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + } +} + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + + private val serializer = { + val table = new Properties() + table.setProperty("columns", dataSchema.fieldNames.mkString(",")) + table.setProperty("columns.types", dataSchema.map { f => + HiveMetastoreTypes.toMetastoreType(f.dataType) + }.mkString(":")) + + val serde = new OrcSerde + serde.initialize(context.getConfiguration, table) + serde + } + + // Object inspector converted from the schema of the relation to be written. + private val structOI = { + val typeInfo = + TypeInfoUtils.getTypeInfoFromTypeString( + HiveMetastoreTypes.toMetastoreType(dataSchema)) + + TypeInfoUtils + .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) + .asInstanceOf[StructObjectInspector] + } + + // Used to hold temporary `Writable` fields of the next row to be written. + private val reusableOutputBuffer = new Array[Any](dataSchema.length) + + // Used to convert Catalyst values into Hadoop `Writable`s. + private val wrappers = structOI.getAllStructFieldRefs.asScala + .zip(dataSchema.fields.map(_.dataType)) + .map { case (ref, dt) => + wrapperFor(ref.getFieldObjectInspector, dt) + }.toArray + + // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this + // flag to decide whether `OrcRecordWriter.close()` needs to be called. + private var recordWriterInstantiated = false + + private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + recordWriterInstantiated = true + + val conf = context.getConfiguration + val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") + val partition = context.getTaskAttemptID.getTaskID.getId + val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + + new OrcOutputFormat().getRecordWriter( + new Path(path, filename).getFileSystem(conf), + conf.asInstanceOf[JobConf], + new Path(path, filename).toString, + Reporter.NULL + ).asInstanceOf[RecordWriter[NullWritable, Writable]] + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + var i = 0 + while (i < row.numFields) { + reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) + i += 1 + } + + recordWriter.write( + NullWritable.get(), + serializer.serialize(reusableOutputBuffer, structOI)) + } + + override def close(): Unit = { + if (recordWriterInstantiated) { + recordWriter.close(Reporter.NULL) + } + } +} + +private[sql] class OrcRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) + with Logging { + + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + + override val dataSchema: StructType = maybeDataSchema.getOrElse { + OrcFileOperator.readSchema( + paths.head, Some(sqlContext.sparkContext.hadoopConfiguration)) + } + + override def needConversion: Boolean = false + + override def equals(other: Any): Boolean = other match { + case that: OrcRelation => + paths.toSet == that.paths.toSet && + dataSchema == that.dataSchema && + schema == that.schema && + partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode( + paths.toSet, + dataSchema, + schema, + partitionColumns) + } + + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[FileStatus]): RDD[Row] = { + val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes + OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + job.getConfiguration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + } + } +} + +private[orc] case class OrcTableScan( + attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + @transient inputPaths: Array[FileStatus]) + extends Logging + with HiveInspectors { + + @transient private val sqlContext = relation.sqlContext + + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) + } + + // Transform all given raw `Writable`s into `InternalRow`s. + private def fillObject( + path: String, + conf: Configuration, + iterator: Iterator[Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[InternalRow] = { + val deserializer = new OrcSerde + val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) + + // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero + // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty + // partition since we know that this file is empty. + maybeStructOI.map { soi => + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = fieldRefs.map(unwrapperFor) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: InternalRow + } + }.getOrElse { + Iterator.empty + } + } + + def execute(): RDD[InternalRow] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration + + // Tries to push down filters if ORC filter push-down is enabled + if (sqlContext.conf.orcFilterPushDown) { + OrcFilters.createFilter(filters).foreach { f => + conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + // Sets requested columns + addColumnIds(attributes, relation, conf) + + if (inputPaths.isEmpty) { + // the input path probably be pruned, return an empty RDD. + return sqlContext.sparkContext.emptyRDD[InternalRow] + } + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + + val inputFormatClass = + classOf[OrcInputFormat] + .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]] + + val rdd = sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], + inputFormatClass, + classOf[NullWritable], + classOf[Writable] + ).asInstanceOf[HadoopRDD[NullWritable, Writable]] + + val wrappedConf = new SerializableConfiguration(conf) + + rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject( + split.getPath.toString, + wrappedConf.value, + iterator.map(_._2), + attributes.zipWithIndex, + mutableRow) + } + } +} + +private[orc] object OrcTableScan { + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + private[orc] val SARG_PUSHDOWN = "sarg.pushdown" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala index a6c8ed4f7e86..db074361ef03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala @@ -17,4 +17,14 @@ package org.apache.spark.sql +/** + * Support for running Spark SQL queries using functionality from Apache Hive (does not require an + * existing Hive installation). Supported Hive features include: + * - Using HiveQL to express queries. + * - Reading metadata from the Hive Metastore using HiveSerDes. + * - Hive UDFs, UDAs, UDTs + * + * Users that would like access to this functionality should create a + * [[hive.HiveContext HiveContext]] instead of a [[SQLContext]]. + */ package object hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala deleted file mode 100644 index 2a16c9d1a27c..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.parquet - -import java.util.Properties - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category -import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector -import org.apache.hadoop.io.Writable - -/** - * A placeholder that allows Spark SQL users to create metastore tables that are stored as - * parquet files. It is only intended to pass the checks that the serde is valid and exists - * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan - * when "spark.sql.hive.convertMetastoreParquet" is set to true. - */ -@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + - "placeholder in the Hive MetaStore", "1.2.0") -class FakeParquetSerDe extends SerDe { - override def getObjectInspector: ObjectInspector = new ObjectInspector { - override def getCategory: Category = Category.PRIMITIVE - - override def getTypeName: String = "string" - } - - override def deserialize(p1: Writable): AnyRef = throwError - - override def initialize(p1: Configuration, p2: Properties): Unit = {} - - override def getSerializedClass: Class[_ <: Writable] = throwError - - override def getSerDeStats: SerDeStats = throwError - - override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError - - private def throwError = - sys.error( - "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe") -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala new file mode 100644 index 000000000000..572eaebe81ac --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import java.io.File +import java.util.{Set => JavaSet} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.FunctionRegistry +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.CacheTableCommand +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.{SparkConf, SparkContext} + +// SPARK-3729: Test key required to check for initialization errors with config. +object TestHive + extends TestHiveContext( + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) + +/** + * A locally running test instance of Spark's Hive execution engine. + * + * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. + * Calling [[reset]] will delete all tables and other state in the database, leaving the database + * in a "clean" state. + * + * TestHive is singleton object version of this class because instantiating multiple copies of the + * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of + * test cases that rely on TestHive must be serialized. + */ +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { + self => + + import HiveContext._ + + // By clearing the port we force Spark to pick a new one. This allows us to rerun tests + // without restarting the JVM. + System.clearProperty("spark.hostPort") + CommandProcessorFactory.clean(hiveconf) + + hiveconf.set("hive.plan.serialization.format", "javaXML") + + lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") + + lazy val scratchDirPath = { + val dir = Utils.createTempDir(namePrefix = "scratch-") + dir.delete() + dir + } + + private lazy val temporaryConfig = newTemporaryConfiguration() + + /** Sets up the system initially or after a RESET command */ + protected override def configure(): Map[String, String] = { + super.configure() ++ temporaryConfig ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" + ) + } + + val testTempDir = Utils.createTempDir() + + // For some hive test case which contain ${system:test.tmp.dir} + System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + + /** The location of the compiled hive distribution */ + lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ + lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") + + // Override so we can intercept relative paths and rewrite them to point at hive. + override def runSqlHive(sql: String): Seq[String] = + super.runSqlHive(rewritePaths(substitutor.substitute(this.hiveconf, sql))) + + override def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution(plan) + + override protected[sql] def createSession(): SQLSession = { + new this.SQLSession() + } + + protected[hive] class SQLSession extends super.SQLSession { + /** Fewer partitions to speed up testing. */ + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) + // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. + // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" + override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + } + } + + /** + * Returns the value of specified environmental variable as a [[java.io.File]] after checking + * to ensure it exists + */ + private def envVarToFile(envVar: String): Option[File] = { + Option(System.getenv(envVar)).map(new File(_)) + } + + /** + * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the + * hive test cases assume the system is set up. + */ + private def rewritePaths(cmd: String): String = + if (cmd.toUpperCase contains "LOAD DATA") { + val testDataLocation = + hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) + cmd.replaceAll("\\.\\./\\.\\./", testDataLocation + "/") + } else { + cmd + } + + val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") + hiveFilesTemp.delete() + hiveFilesTemp.mkdir() + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) + + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { + new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) + } else { + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + File.separator + "resources") + } + + def getHiveFile(path: String): File = { + val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) + hiveDevHome + .map(new File(_, stripped)) + .filter(_.exists) + .getOrElse(new File(inRepoTests, stripped)) + } + + val describedTable = "DESCRIBE (\\w+)".r + + /** + * Override QueryExecution with special debug workflow. + */ + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { + def this(sql: String) = this(parseSql(sql)) + override lazy val analyzed = { + val describedTables = logical match { + case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil + case CacheTableCommand(tbl, _, _) => tbl :: Nil + case _ => Nil + } + + // Make sure any test tables referenced are loaded. + val referencedTables = + describedTables ++ + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } + val referencedTestTables = referencedTables.filter(testTables.contains) + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + referencedTestTables.foreach(loadTestTable) + // Proceed with analysis. + analyzer.execute(logical) + } + } + + case class TestTable(name: String, commands: (() => Unit)*) + + protected[hive] implicit class SqlCmd(sql: String) { + def cmd: () => Unit = { + () => new QueryExecution(sql).stringResult(): Unit + } + } + + /** + * A list of test tables and the DDL required to initialize them. A test table is loaded on + * demand when a query are run against it. + */ + @transient + lazy val testTables = new mutable.HashMap[String, TestTable]() + + def registerTestTable(testTable: TestTable): Unit = { + testTables += (testTable.name -> testTable) + } + + // The test tables that are defined in the Hive QTestUtil. + // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java + // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql + @transient + val hiveQTestUtilTables = Seq( + TestTable("src", + "CREATE TABLE src (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + TestTable("src1", + "CREATE TABLE src1 (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + TestTable("srcpart", () => { + runSqlHive( + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("srcpart1", () => { + runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("src_thrift", () => { + import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer + import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} + import org.apache.thrift.protocol.TBinaryProtocol + + runSqlHive( + s""" + |CREATE TABLE src_thrift(fake INT) + |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' + |WITH SERDEPROPERTIES( + | 'serialization.class'='org.apache.spark.sql.hive.test.Complex', + | 'serialization.format'='${classOf[TBinaryProtocol].getName}' + |) + |STORED AS + |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' + |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' + """.stripMargin) + + runSqlHive( + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") + }), + TestTable("serdeins", + s"""CREATE TABLE serdeins (key INT, value STRING) + |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ('field.delim'='\\t') + """.stripMargin.cmd, + "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), + TestTable("episodes", + s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) + |STORED AS avro + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd + ), + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING + // IS NOT YET SUPPORTED + TestTable("episodes_part", + s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) + |PARTITIONED BY (doctor_pt INT) + |STORED AS avro + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + // WORKAROUND: Required to pass schema to SerDe for partitioned tables. + // TODO: Pass this automatically from the table to partitions. + s""" + |ALTER TABLE episodes_part SET SERDEPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s""" + INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) + SELECT title, air_date, doctor FROM episodes + """.cmd + ), + TestTable("src_json", + s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + ) + + hiveQTestUtilTables.foreach(registerTestTable) + + private val loadedTables = new collection.mutable.HashSet[String] + + var cacheTables: Boolean = false + def loadTestTable(name: String) { + if (!(loadedTables contains name)) { + // Marks the table as loaded first to prevent infinite mutually recursive table loading. + loadedTables += name + logDebug(s"Loading test table $name") + val createCmds = + testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) + createCmds.foreach(_()) + + if (cacheTables) { + cacheTable(name) + } + } + } + + /** + * Records the UDFs present when the server starts, so we can delete ones that are created by + * tests. + */ + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames + + /** + * Resets the test instance by deleting any tables that have been created. + * TODO: also clear out UDFs, views, etc. + */ + def reset() { + try { + // HACK: Hive is too noisy by default. + org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + } + + cacheManager.clearCache() + loadedTables.clear() + catalog.cachedDataSourceTables.invalidateAll() + catalog.client.reset() + catalog.unregisterAllTables() + + FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). + foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } + + // Some tests corrupt this value on purpose, which breaks the RESET call below. + hiveconf.set("fs.default.name", new File(".").toURI.toString) + // It is important that we RESET first as broken hooks that might have been set could break + // other sql exec here. + executionHive.runSqlHive("RESET") + metadataHive.runSqlHive("RESET") + // For some reason, RESET does not reset the following variables... + // https://issues.apache.org/jira/browse/HIVE-9004 + runSqlHive("set hive.table.parameters.default=") + runSqlHive("set datanucleus.cache.collections=true") + runSqlHive("set datanucleus.cache.collections.lazy=true") + // Lots of tests fail if we do not change the partition whitelist from the default. + runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + + configure().foreach { + case (k, v) => + metadataHive.runSqlHive(s"SET $k=$v") + } + defaultOverides() + + runSqlHive("USE default") + + // Just loading src makes a lot of tests pass. This is because some tests do something like + // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our + // Analyzer and thus the test table auto-loading mechanism. + // Remove after we handle more DDL operations natively. + loadTestTable("src") + loadTestTable("srcpart") + } catch { + case e: Exception => + logError("FATAL ERROR: Failed to reset TestDB state.", e) + } + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java new file mode 100644 index 000000000000..019d8a30266e --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import static org.apache.spark.sql.functions.*; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.hive.aggregate.MyDoubleSum; + +public class JavaDataFrameSuite { + private transient JavaSparkContext sc; + private transient HiveContext hc; + + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); + } + df = hc.read().json(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + if (hc != null) { + hc.sql("DROP TABLE IF EXISTS window_table"); + } + } + + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select(functions.avg("key").over( + Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } + + @Test + public void testUDAF() { + DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value")); + UserDefinedAggregateFunction udaf = new MyDoubleSum(); + UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf); + // Create Columns for the UDAF. For now, callUDF does not take an argument to specific if + // we want to use distinct aggregation. + DataFrame aggregatedDF = + df.groupBy() + .agg( + udaf.distinct(col("value")), + udaf.apply(col("value")), + registeredUDAF.apply(col("value")), + callUDF("mydoublesum", col("value"))); + + List expectedResult = new ArrayList(); + expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); + checkAnswer( + aggregatedDF, + expectedResult); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java new file mode 100644 index 000000000000..4192155975c4 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.SaveMode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.QueryTest$; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaMetastoreDataSourcesSuite { + private transient JavaSparkContext sc; + private transient HiveContext sqlContext; + + String originalDefaultSource; + File path; + Path hiveManagedPath; + FileSystem fs; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestHive$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); + if (fs.exists(hiveManagedPath)){ + fs.delete(hiveManagedPath, true); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.read().json(rdd); + df.registerTempTable("jsonTable"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + if (sqlContext != null) { + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } + } + + @Test + public void saveExternalTableAndQueryIt() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + df.collectAsList()); + } + + @Test + public void saveExternalTableWithSchemaAndQueryIt() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options); + + checkAnswer( + loadedDF, + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + } + + @Test + public void saveTableAndQueryIt() { + Map options = new HashMap(); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java new file mode 100644 index 000000000000..5a167edd8959 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum + * of the average value of input values and 100.0. + */ +public class MyDoubleAvg extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleAvg() { + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have been contribuetd + // to the current sum. + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); + bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType dataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. + buffer.update(0, null); + // The initial value of the count is 0. + buffer.update(1, 0L); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. + if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer and set the bufferCount to 1. + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + buffer.update(1, 1L); + } else { + // Otherwise, update the bufferSum and increment bufferCount. + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + buffer.update(1, buffer.getLong(1) + 1L); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's sum value is not null. + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. + buffer1.update(0, buffer2.getDouble(0)); + buffer1.update(1, buffer2.getLong(1)); + } else { + // Otherwise, we update the bufferSum and bufferCount. + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not got + // any input row. + return null; + } else { + // Otherwise, we calculate the special average value. + return buffer.getDouble(0) / buffer.getLong(1) + 100.0; + } + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java new file mode 100644 index 000000000000..c3b7768e71bf --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.Row; + +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. + */ +public class MyDoubleSum extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleSum() { + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType dataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. + buffer.update(0, null); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer. + buffer.update(0, input.getDouble(0)); + } else { + // Otherwise, we add the input value to the buffer value. + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's value is not null. + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. + buffer1.update(0, buffer2.getDouble(0)); + } else { + // Otherwise, we add the input buffer's value (buffer1) to the mutable + // buffer's value (buffer2). + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null. + return null; + } else { + // Otherwise, the intermediate sum is the final result. + return buffer.getDouble(0); + } + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java index efd34df293c8..f33210ebdae1 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.execution; -import org.apache.hadoop.hive.ql.exec.UDF; - import java.util.List; -import org.apache.commons.lang.StringUtils; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.hive.ql.exec.UDF; public class UDFListString extends UDF { diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToIntIntMap.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToIntIntMap.java new file mode 100644 index 000000000000..b3e8bcbbd822 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToIntIntMap.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.HashMap; +import java.util.Map; + +public class UDFToIntIntMap extends UDF { + public Map evaluate(Object o) { + return new HashMap() { + { + put(1, 1); + put(2, 1); + put(3, 1); + } + }; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java new file mode 100644 index 000000000000..67576a72f198 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Arrays; +import java.util.List; + +public class UDFToListInt extends UDF { + public List evaluate(Object o) { + return Arrays.asList(1, 2, 3); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java new file mode 100644 index 000000000000..f02395cbba88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Arrays; +import java.util.List; + +public class UDFToListString extends UDF { + public List evaluate(Object o) { + return Arrays.asList("data1", "data2", "data3"); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToStringIntMap.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToStringIntMap.java new file mode 100644 index 000000000000..9eea5c9a881f --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToStringIntMap.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.HashMap; +import java.util.Map; + +public class UDFToStringIntMap extends UDF { + public Map evaluate(Object o) { + return new HashMap() { + { + put("key1", 1); + put("key2", 2); + put("key3", 3); + } + }; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java new file mode 100644 index 000000000000..e010112bb932 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -0,0 +1,1139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.test; + +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.hadoop.hive.serde2.thrift.test.IntString; +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.EncodingUtils; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; + +/** + * This is a fork of Hive 0.13's org/apache/hadoop/hive/serde2/thrift/test/Complex.java, which + * does not contain union fields that are not supported by Spark SQL. + */ + +@SuppressWarnings({"ALL", "unchecked"}) +public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); + + private static final org.apache.thrift.protocol.TField AINT_FIELD_DESC = new org.apache.thrift.protocol.TField("aint", org.apache.thrift.protocol.TType.I32, (short)1); + private static final org.apache.thrift.protocol.TField A_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("aString", org.apache.thrift.protocol.TType.STRING, (short)2); + private static final org.apache.thrift.protocol.TField LINT_FIELD_DESC = new org.apache.thrift.protocol.TField("lint", org.apache.thrift.protocol.TType.LIST, (short)3); + private static final org.apache.thrift.protocol.TField L_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lString", org.apache.thrift.protocol.TType.LIST, (short)4); + private static final org.apache.thrift.protocol.TField LINT_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lintString", org.apache.thrift.protocol.TType.LIST, (short)5); + private static final org.apache.thrift.protocol.TField M_STRING_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("mStringString", org.apache.thrift.protocol.TType.MAP, (short)6); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ComplexStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ComplexTupleSchemeFactory()); + } + + private int aint; // required + private String aString; // required + private List lint; // required + private List lString; // required + private List lintString; // required + private Map mStringString; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + AINT((short)1, "aint"), + A_STRING((short)2, "aString"), + LINT((short)3, "lint"), + L_STRING((short)4, "lString"), + LINT_STRING((short)5, "lintString"), + M_STRING_STRING((short)6, "mStringString"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // AINT + return AINT; + case 2: // A_STRING + return A_STRING; + case 3: // LINT + return LINT; + case 4: // L_STRING + return L_STRING; + case 5: // LINT_STRING + return LINT_STRING; + case 6: // M_STRING_STRING + return M_STRING_STRING; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. + */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __AINT_ISSET_ID = 0; + private byte __isset_bitfield = 0; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.AINT, new org.apache.thrift.meta_data.FieldMetaData("aint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.A_STRING, new org.apache.thrift.meta_data.FieldMetaData("aString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.LINT, new org.apache.thrift.meta_data.FieldMetaData("lint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.L_STRING, new org.apache.thrift.meta_data.FieldMetaData("lString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.LINT_STRING, new org.apache.thrift.meta_data.FieldMetaData("lintString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, IntString.class)))); + tmpMap.put(_Fields.M_STRING_STRING, new org.apache.thrift.meta_data.FieldMetaData("mStringString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Complex.class, metaDataMap); + } + + public Complex() { + } + + public Complex( + int aint, + String aString, + List lint, + List lString, + List lintString, + Map mStringString) + { + this(); + this.aint = aint; + setAintIsSet(true); + this.aString = aString; + this.lint = lint; + this.lString = lString; + this.lintString = lintString; + this.mStringString = mStringString; + } + + /** + * Performs a deep copy on other. + */ + public Complex(Complex other) { + __isset_bitfield = other.__isset_bitfield; + this.aint = other.aint; + if (other.isSetAString()) { + this.aString = other.aString; + } + if (other.isSetLint()) { + List __this__lint = new ArrayList(); + for (Integer other_element : other.lint) { + __this__lint.add(other_element); + } + this.lint = __this__lint; + } + if (other.isSetLString()) { + List __this__lString = new ArrayList(); + for (String other_element : other.lString) { + __this__lString.add(other_element); + } + this.lString = __this__lString; + } + if (other.isSetLintString()) { + List __this__lintString = new ArrayList(); + for (IntString other_element : other.lintString) { + __this__lintString.add(new IntString(other_element)); + } + this.lintString = __this__lintString; + } + if (other.isSetMStringString()) { + Map __this__mStringString = new HashMap(); + for (Map.Entry other_element : other.mStringString.entrySet()) { + + String other_element_key = other_element.getKey(); + String other_element_value = other_element.getValue(); + + String __this__mStringString_copy_key = other_element_key; + + String __this__mStringString_copy_value = other_element_value; + + __this__mStringString.put(__this__mStringString_copy_key, __this__mStringString_copy_value); + } + this.mStringString = __this__mStringString; + } + } + + public Complex deepCopy() { + return new Complex(this); + } + + @Override + public void clear() { + setAintIsSet(false); + this.aint = 0; + this.aString = null; + this.lint = null; + this.lString = null; + this.lintString = null; + this.mStringString = null; + } + + public int getAint() { + return this.aint; + } + + public void setAint(int aint) { + this.aint = aint; + setAintIsSet(true); + } + + public void unsetAint() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __AINT_ISSET_ID); + } + + /** Returns true if field aint is set (has been assigned a value) and false otherwise */ + public boolean isSetAint() { + return EncodingUtils.testBit(__isset_bitfield, __AINT_ISSET_ID); + } + + public void setAintIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __AINT_ISSET_ID, value); + } + + public String getAString() { + return this.aString; + } + + public void setAString(String aString) { + this.aString = aString; + } + + public void unsetAString() { + this.aString = null; + } + + /** Returns true if field aString is set (has been assigned a value) and false otherwise */ + public boolean isSetAString() { + return this.aString != null; + } + + public void setAStringIsSet(boolean value) { + if (!value) { + this.aString = null; + } + } + + public int getLintSize() { + return (this.lint == null) ? 0 : this.lint.size(); + } + + public java.util.Iterator getLintIterator() { + return (this.lint == null) ? null : this.lint.iterator(); + } + + public void addToLint(int elem) { + if (this.lint == null) { + this.lint = new ArrayList<>(); + } + this.lint.add(elem); + } + + public List getLint() { + return this.lint; + } + + public void setLint(List lint) { + this.lint = lint; + } + + public void unsetLint() { + this.lint = null; + } + + /** Returns true if field lint is set (has been assigned a value) and false otherwise */ + public boolean isSetLint() { + return this.lint != null; + } + + public void setLintIsSet(boolean value) { + if (!value) { + this.lint = null; + } + } + + public int getLStringSize() { + return (this.lString == null) ? 0 : this.lString.size(); + } + + public java.util.Iterator getLStringIterator() { + return (this.lString == null) ? null : this.lString.iterator(); + } + + public void addToLString(String elem) { + if (this.lString == null) { + this.lString = new ArrayList(); + } + this.lString.add(elem); + } + + public List getLString() { + return this.lString; + } + + public void setLString(List lString) { + this.lString = lString; + } + + public void unsetLString() { + this.lString = null; + } + + /** Returns true if field lString is set (has been assigned a value) and false otherwise */ + public boolean isSetLString() { + return this.lString != null; + } + + public void setLStringIsSet(boolean value) { + if (!value) { + this.lString = null; + } + } + + public int getLintStringSize() { + return (this.lintString == null) ? 0 : this.lintString.size(); + } + + public java.util.Iterator getLintStringIterator() { + return (this.lintString == null) ? null : this.lintString.iterator(); + } + + public void addToLintString(IntString elem) { + if (this.lintString == null) { + this.lintString = new ArrayList<>(); + } + this.lintString.add(elem); + } + + public List getLintString() { + return this.lintString; + } + + public void setLintString(List lintString) { + this.lintString = lintString; + } + + public void unsetLintString() { + this.lintString = null; + } + + /** Returns true if field lintString is set (has been assigned a value) and false otherwise */ + public boolean isSetLintString() { + return this.lintString != null; + } + + public void setLintStringIsSet(boolean value) { + if (!value) { + this.lintString = null; + } + } + + public int getMStringStringSize() { + return (this.mStringString == null) ? 0 : this.mStringString.size(); + } + + public void putToMStringString(String key, String val) { + if (this.mStringString == null) { + this.mStringString = new HashMap(); + } + this.mStringString.put(key, val); + } + + public Map getMStringString() { + return this.mStringString; + } + + public void setMStringString(Map mStringString) { + this.mStringString = mStringString; + } + + public void unsetMStringString() { + this.mStringString = null; + } + + /** Returns true if field mStringString is set (has been assigned a value) and false otherwise */ + public boolean isSetMStringString() { + return this.mStringString != null; + } + + public void setMStringStringIsSet(boolean value) { + if (!value) { + this.mStringString = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case AINT: + if (value == null) { + unsetAint(); + } else { + setAint((Integer)value); + } + break; + + case A_STRING: + if (value == null) { + unsetAString(); + } else { + setAString((String)value); + } + break; + + case LINT: + if (value == null) { + unsetLint(); + } else { + setLint((List)value); + } + break; + + case L_STRING: + if (value == null) { + unsetLString(); + } else { + setLString((List)value); + } + break; + + case LINT_STRING: + if (value == null) { + unsetLintString(); + } else { + setLintString((List)value); + } + break; + + case M_STRING_STRING: + if (value == null) { + unsetMStringString(); + } else { + setMStringString((Map)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case AINT: + return Integer.valueOf(getAint()); + + case A_STRING: + return getAString(); + + case LINT: + return getLint(); + + case L_STRING: + return getLString(); + + case LINT_STRING: + return getLintString(); + + case M_STRING_STRING: + return getMStringString(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case AINT: + return isSetAint(); + case A_STRING: + return isSetAString(); + case LINT: + return isSetLint(); + case L_STRING: + return isSetLString(); + case LINT_STRING: + return isSetLintString(); + case M_STRING_STRING: + return isSetMStringString(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof Complex) + return this.equals((Complex)that); + return false; + } + + public boolean equals(Complex that) { + if (that == null) + return false; + + boolean this_present_aint = true; + boolean that_present_aint = true; + if (this_present_aint || that_present_aint) { + if (!(this_present_aint && that_present_aint)) + return false; + if (this.aint != that.aint) + return false; + } + + boolean this_present_aString = true && this.isSetAString(); + boolean that_present_aString = true && that.isSetAString(); + if (this_present_aString || that_present_aString) { + if (!(this_present_aString && that_present_aString)) + return false; + if (!this.aString.equals(that.aString)) + return false; + } + + boolean this_present_lint = true && this.isSetLint(); + boolean that_present_lint = true && that.isSetLint(); + if (this_present_lint || that_present_lint) { + if (!(this_present_lint && that_present_lint)) + return false; + if (!this.lint.equals(that.lint)) + return false; + } + + boolean this_present_lString = true && this.isSetLString(); + boolean that_present_lString = true && that.isSetLString(); + if (this_present_lString || that_present_lString) { + if (!(this_present_lString && that_present_lString)) + return false; + if (!this.lString.equals(that.lString)) + return false; + } + + boolean this_present_lintString = true && this.isSetLintString(); + boolean that_present_lintString = true && that.isSetLintString(); + if (this_present_lintString || that_present_lintString) { + if (!(this_present_lintString && that_present_lintString)) + return false; + if (!this.lintString.equals(that.lintString)) + return false; + } + + boolean this_present_mStringString = true && this.isSetMStringString(); + boolean that_present_mStringString = true && that.isSetMStringString(); + if (this_present_mStringString || that_present_mStringString) { + if (!(this_present_mStringString && that_present_mStringString)) + return false; + if (!this.mStringString.equals(that.mStringString)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + HashCodeBuilder builder = new HashCodeBuilder(); + + boolean present_aint = true; + builder.append(present_aint); + if (present_aint) + builder.append(aint); + + boolean present_aString = true && (isSetAString()); + builder.append(present_aString); + if (present_aString) + builder.append(aString); + + boolean present_lint = true && (isSetLint()); + builder.append(present_lint); + if (present_lint) + builder.append(lint); + + boolean present_lString = true && (isSetLString()); + builder.append(present_lString); + if (present_lString) + builder.append(lString); + + boolean present_lintString = true && (isSetLintString()); + builder.append(present_lintString); + if (present_lintString) + builder.append(lintString); + + boolean present_mStringString = true && (isSetMStringString()); + builder.append(present_mStringString); + if (present_mStringString) + builder.append(mStringString); + + return builder.toHashCode(); + } + + public int compareTo(Complex other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + Complex typedOther = (Complex)other; + + lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aint, typedOther.aint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetAString()).compareTo(typedOther.isSetAString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aString, typedOther.aString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLint()).compareTo(typedOther.isSetLint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lint, typedOther.lint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLString()).compareTo(typedOther.isSetLString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lString, typedOther.lString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLintString()).compareTo(typedOther.isSetLintString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLintString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lintString, typedOther.lintString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMStringString()).compareTo(typedOther.isSetMStringString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMStringString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.mStringString, typedOther.mStringString); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Complex("); + boolean first = true; + + sb.append("aint:"); + sb.append(this.aint); + first = false; + if (!first) sb.append(", "); + sb.append("aString:"); + if (this.aString == null) { + sb.append("null"); + } else { + sb.append(this.aString); + } + first = false; + if (!first) sb.append(", "); + sb.append("lint:"); + if (this.lint == null) { + sb.append("null"); + } else { + sb.append(this.lint); + } + first = false; + if (!first) sb.append(", "); + sb.append("lString:"); + if (this.lString == null) { + sb.append("null"); + } else { + sb.append(this.lString); + } + first = false; + if (!first) sb.append(", "); + sb.append("lintString:"); + if (this.lintString == null) { + sb.append("null"); + } else { + sb.append(this.lintString); + } + first = false; + if (!first) sb.append(", "); + sb.append("mStringString:"); + if (this.mStringString == null) { + sb.append("null"); + } else { + sb.append(this.mStringString); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. + __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ComplexStandardSchemeFactory implements SchemeFactory { + public ComplexStandardScheme getScheme() { + return new ComplexStandardScheme(); + } + } + + private static class ComplexStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // AINT + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // A_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // LINT + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.lint = new ArrayList(_list0.size); + for (int _i1 = 0; _i1 < _list0.size; ++_i1) + { + int _elem2; // required + _elem2 = iprot.readI32(); + struct.lint.add(_elem2); + } + iprot.readListEnd(); + } + struct.setLintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // L_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list3 = iprot.readListBegin(); + struct.lString = new ArrayList(_list3.size); + for (int _i4 = 0; _i4 < _list3.size; ++_i4) + { + String _elem5; // required + _elem5 = iprot.readString(); + struct.lString.add(_elem5); + } + iprot.readListEnd(); + } + struct.setLStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LINT_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list6 = iprot.readListBegin(); + struct.lintString = new ArrayList(_list6.size); + for (int _i7 = 0; _i7 < _list6.size; ++_i7) + { + IntString _elem8; // required + _elem8 = new IntString(); + _elem8.read(iprot); + struct.lintString.add(_elem8); + } + iprot.readListEnd(); + } + struct.setLintStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // M_STRING_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map9 = iprot.readMapBegin(); + struct.mStringString = new HashMap(2*_map9.size); + for (int _i10 = 0; _i10 < _map9.size; ++_i10) + { + String _key11; // required + String _val12; // required + _key11 = iprot.readString(); + _val12 = iprot.readString(); + struct.mStringString.put(_key11, _val12); + } + iprot.readMapEnd(); + } + struct.setMStringStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Complex struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(AINT_FIELD_DESC); + oprot.writeI32(struct.aint); + oprot.writeFieldEnd(); + if (struct.aString != null) { + oprot.writeFieldBegin(A_STRING_FIELD_DESC); + oprot.writeString(struct.aString); + oprot.writeFieldEnd(); + } + if (struct.lint != null) { + oprot.writeFieldBegin(LINT_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.lint.size())); + for (int _iter13 : struct.lint) + { + oprot.writeI32(_iter13); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lString != null) { + oprot.writeFieldBegin(L_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.lString.size())); + for (String _iter14 : struct.lString) + { + oprot.writeString(_iter14); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lintString != null) { + oprot.writeFieldBegin(LINT_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, struct.lintString.size())); + for (IntString _iter15 : struct.lintString) + { + _iter15.write(oprot); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.mStringString != null) { + oprot.writeFieldBegin(M_STRING_STRING_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, struct.mStringString.size())); + for (Map.Entry _iter16 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter16.getKey()); + oprot.writeString(_iter16.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ComplexTupleSchemeFactory implements SchemeFactory { + public ComplexTupleScheme getScheme() { + return new ComplexTupleScheme(); + } + } + + private static class ComplexTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + BitSet optionals = new BitSet(); + if (struct.isSetAint()) { + optionals.set(0); + } + if (struct.isSetAString()) { + optionals.set(1); + } + if (struct.isSetLint()) { + optionals.set(2); + } + if (struct.isSetLString()) { + optionals.set(3); + } + if (struct.isSetLintString()) { + optionals.set(4); + } + if (struct.isSetMStringString()) { + optionals.set(5); + } + oprot.writeBitSet(optionals, 6); + if (struct.isSetAint()) { + oprot.writeI32(struct.aint); + } + if (struct.isSetAString()) { + oprot.writeString(struct.aString); + } + if (struct.isSetLint()) { + { + oprot.writeI32(struct.lint.size()); + for (int _iter17 : struct.lint) + { + oprot.writeI32(_iter17); + } + } + } + if (struct.isSetLString()) { + { + oprot.writeI32(struct.lString.size()); + for (String _iter18 : struct.lString) + { + oprot.writeString(_iter18); + } + } + } + if (struct.isSetLintString()) { + { + oprot.writeI32(struct.lintString.size()); + for (IntString _iter19 : struct.lintString) + { + _iter19.write(oprot); + } + } + } + if (struct.isSetMStringString()) { + { + oprot.writeI32(struct.mStringString.size()); + for (Map.Entry _iter20 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter20.getKey()); + oprot.writeString(_iter20.getValue()); + } + } + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + BitSet incoming = iprot.readBitSet(6); + if (incoming.get(0)) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } + if (incoming.get(1)) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } + if (incoming.get(2)) { + { + org.apache.thrift.protocol.TList _list21 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.lint = new ArrayList(_list21.size); + for (int _i22 = 0; _i22 < _list21.size; ++_i22) + { + int _elem23; // required + _elem23 = iprot.readI32(); + struct.lint.add(_elem23); + } + } + struct.setLintIsSet(true); + } + if (incoming.get(3)) { + { + org.apache.thrift.protocol.TList _list24 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.lString = new ArrayList(_list24.size); + for (int _i25 = 0; _i25 < _list24.size; ++_i25) + { + String _elem26; // required + _elem26 = iprot.readString(); + struct.lString.add(_elem26); + } + } + struct.setLStringIsSet(true); + } + if (incoming.get(4)) { + { + org.apache.thrift.protocol.TList _list27 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + struct.lintString = new ArrayList(_list27.size); + for (int _i28 = 0; _i28 < _list27.size; ++_i28) + { + IntString _elem29; // required + _elem29 = new IntString(); + _elem29.read(iprot); + struct.lintString.add(_elem29); + } + } + struct.setLintStringIsSet(true); + } + if (incoming.get(5)) { + { + org.apache.thrift.protocol.TMap _map30 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.mStringString = new HashMap(2*_map30.size); + for (int _i31 = 0; _i31 < _map30.size; ++_i31) + { + String _key32; // required + String _val33; // required + _key32 = iprot.readString(); + _val33 = iprot.readString(); + struct.mStringString.put(_key32, _val33); + } + } + struct.setMStringStringIsSet(true); + } + } + } + +} + diff --git a/sql/hive/src/test/resources/TestUDTF.jar b/sql/hive/src/test/resources/TestUDTF.jar new file mode 100644 index 000000000000..514f2d5d26fd Binary files /dev/null and b/sql/hive/src/test/resources/TestUDTF.jar differ diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 new file mode 100644 index 000000000000..f6ba75da254c --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 @@ -0,0 +1,3 @@ +5 +5 +5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 new file mode 100644 index 000000000000..ca7b591095e2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 @@ -0,0 +1,4 @@ +val_4 +val_5 +val_5 +val_5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 new file mode 100644 index 000000000000..b8626c4cff28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 @@ -0,0 +1 @@ +4 diff --git a/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 new file mode 100644 index 000000000000..9a276bc794c0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 @@ -0,0 +1,3 @@ +476 +172 +622 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e new file mode 100644 index 000000000000..27ba77ddaf61 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f new file mode 100644 index 000000000000..444039e75fba --- /dev/null +++ b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f @@ -0,0 +1,500 @@ +476 +172 +622 +54 +330 +818 +510 +556 +196 +968 +530 +386 +802 +300 +546 +448 +738 +132 +256 +426 +292 +812 +858 +748 +304 +938 +290 +990 +74 +654 +562 +554 +418 +30 +164 +806 +332 +834 +860 +504 +584 +438 +574 +306 +386 +676 +892 +918 +788 +474 +964 +348 +826 +988 +414 +398 +932 +416 +348 +798 +792 +494 +834 +978 +324 +754 +794 +618 +730 +532 +878 +684 +734 +650 +334 +390 +950 +34 +226 +310 +406 +678 +0 +910 +256 +622 +632 +114 +604 +410 +298 +876 +690 +258 +340 +40 +978 +314 +756 +442 +184 +222 +94 +144 +8 +560 +70 +854 +554 +416 +712 +798 +338 +764 +996 +250 +772 +874 +938 +384 +572 +374 +352 +108 +918 +102 +276 +206 +478 +426 +432 +860 +556 +352 +578 +442 +130 +636 +664 +622 +550 +274 +482 +166 +666 +360 +568 +24 +460 +362 +134 +520 +808 +768 +978 +706 +746 +544 +276 +434 +168 +696 +932 +116 +16 +822 +460 +416 +696 +48 +926 +862 +358 +344 +84 +258 +316 +238 +992 +0 +644 +394 +936 +786 +908 +200 +596 +398 +382 +836 +192 +52 +330 +654 +460 +410 +240 +262 +102 +808 +86 +872 +312 +938 +936 +616 +190 +392 +576 +962 +914 +196 +564 +394 +374 +636 +636 +818 +940 +274 +738 +632 +338 +826 +170 +154 +0 +980 +174 +728 +358 +236 +268 +790 +564 +276 +476 +838 +30 +236 +144 +180 +614 +38 +870 +20 +554 +546 +612 +448 +618 +778 +654 +484 +738 +784 +544 +662 +802 +484 +904 +354 +452 +10 +994 +804 +792 +634 +790 +116 +70 +672 +190 +22 +336 +68 +458 +466 +286 +944 +644 +996 +320 +390 +84 +642 +860 +238 +978 +916 +156 +152 +82 +446 +984 +298 +898 +436 +456 +276 +906 +60 +418 +128 +936 +152 +148 +684 +138 +460 +66 +736 +206 +592 +226 +432 +734 +688 +334 +548 +438 +478 +970 +232 +446 +512 +526 +140 +974 +960 +802 +576 +382 +10 +488 +876 +256 +934 +864 +404 +632 +458 +938 +926 +560 +4 +70 +566 +662 +470 +160 +88 +386 +642 +670 +208 +932 +732 +350 +806 +966 +106 +210 +514 +812 +818 +380 +812 +802 +228 +516 +180 +406 +524 +696 +848 +24 +792 +402 +434 +328 +862 +908 +956 +596 +250 +862 +328 +848 +374 +764 +10 +140 +794 +960 +582 +48 +702 +510 +208 +140 +326 +876 +238 +828 +400 +982 +474 +878 +720 +496 +958 +610 +834 +398 +888 +240 +858 +338 +886 +646 +650 +554 +460 +956 +356 +936 +620 +634 +666 +986 +920 +414 +498 +530 +960 +166 +272 +706 +344 +428 +924 +466 +812 +266 +350 +378 +908 +750 +802 +842 +814 +768 +512 +52 +268 +134 +768 +758 +36 +924 +984 +200 +596 +18 +682 +996 +292 +916 +724 +372 +570 +696 +334 +36 +546 +366 +562 +688 +194 +938 +630 +168 +56 +74 +896 +304 +696 +614 +388 +828 +954 +444 +252 +180 +338 +806 +800 +400 +194 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 new file mode 100644 index 000000000000..dac1b84b916d --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 new file mode 100644 index 000000000000..c7cb747c0a65 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c new file mode 100644 index 000000000000..c7cb747c0a65 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a new file mode 100644 index 000000000000..dac1b84b916d --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 new file mode 100644 index 000000000000..1eea4a9b2368 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce new file mode 100644 index 000000000000..1eea4a9b2368 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 b/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 new file mode 100644 index 000000000000..946e72fc87c2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 @@ -0,0 +1,2 @@ +97 500 +97 500 diff --git a/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 b/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 new file mode 100644 index 000000000000..a5c8806279fa --- /dev/null +++ b/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 @@ -0,0 +1,2 @@ +3 +3 diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad similarity index 100% rename from sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 rename to sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad diff --git a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 index d35bf9093ca9..2383bef94097 100644 --- a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 +++ b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 @@ -15,9 +15,9 @@ my_enum_structlist_map map from deserializer my_structlist array>> from deserializer my_enumlist array from deserializer -my_stringset struct<> from deserializer -my_enumset struct<> from deserializer -my_structset struct<> from deserializer +my_stringset array from deserializer +my_enumset array from deserializer +my_structset array>> from deserializer optionals struct<> from deserializer b string diff --git a/sql/hive/src/test/resources/golden/create table as with db name within backticks-0-a253b1ed35dbf503d1b8902dacbe23ac b/sql/hive/src/test/resources/golden/create table as with db name within backticks-0-a253b1ed35dbf503d1b8902dacbe23ac new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create table as with db name within backticks-1-61dc640dfeaff471f3d2b730f9cbf959 b/sql/hive/src/test/resources/golden/create table as with db name within backticks-1-61dc640dfeaff471f3d2b730f9cbf959 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create table as with db name within backticks-2-ce780d068b8d24786e639e361101a0c7 b/sql/hive/src/test/resources/golden/create table as with db name within backticks-2-ce780d068b8d24786e639e361101a0c7 new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/create table as with db name within backticks-2-ce780d068b8d24786e639e361101a0c7 @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/create table as with db name within backticks-3-afd6e46b6a289c3c24a8eec75a94043c b/sql/hive/src/test/resources/golden/create table as with db name within backticks-3-afd6e46b6a289c3c24a8eec75a94043c new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 rename to sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 rename to sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea rename to sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b rename to sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b b/sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b rename to sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b diff --git a/sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 b/sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 b/sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 b/sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f b/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f new file mode 100644 index 000000000000..1dcda4315a14 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f @@ -0,0 +1 @@ +{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}},"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025","fb:testid":"1234"} diff --git a/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc b/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc new file mode 100644 index 000000000000..81c545efebe5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc @@ -0,0 +1 @@ +1234 diff --git a/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a b/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a new file mode 100644 index 000000000000..99127db9e311 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a @@ -0,0 +1 @@ +amy {"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}} diff --git a/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb b/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb new file mode 100644 index 000000000000..0bc03998296a --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb @@ -0,0 +1 @@ +{"price":19.95,"color":"red"} [{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 b/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 new file mode 100644 index 000000000000..4f7e09bd3fa7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 @@ -0,0 +1 @@ +{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95} [{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab b/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab new file mode 100644 index 000000000000..b2d212a597d9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab @@ -0,0 +1 @@ +reference ["reference","fiction","fiction"] ["0-553-21311-3","0-395-19395-8"] [{"age":25,"name":"bob"},{"age":26,"name":"jack"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da b/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da new file mode 100644 index 000000000000..21d88629fcdb --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da @@ -0,0 +1 @@ +25 [25,26] diff --git a/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 b/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 new file mode 100644 index 000000000000..e60721e1dd24 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 @@ -0,0 +1 @@ +2 [[1,2,{"b":"y","a":"x"}],[3,4],[5,6]] 1 [1,2,{"b":"y","a":"x"}] [1,2,{"b":"y","a":"x"},3,4,5,6] y ["y"] diff --git a/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 b/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 new file mode 100644 index 000000000000..356fcdf7139b --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 @@ -0,0 +1 @@ +NULL NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b b/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b new file mode 100644 index 000000000000..ef4a39675ed6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b @@ -0,0 +1 @@ +94025 diff --git a/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 b/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 b/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b new file mode 100644 index 000000000000..518a70918b2c --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b @@ -0,0 +1 @@ +name string diff --git a/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce new file mode 100644 index 000000000000..33398360345d --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce @@ -0,0 +1 @@ +邵铮 diff --git a/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 000000000000..01e79c32a8c9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 new file mode 100644 index 000000000000..0c7520f2090d --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 @@ -0,0 +1,3 @@ +86 val_86 +238 val_238 +311 val_311 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 000000000000..01e79c32a8c9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 000000000000..185a91c110d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 000000000000..185a91c110d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea index 25ce912507d5..a1963ba81e0d 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea +++ b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 index 25ce912507d5..a1963ba81e0d 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 +++ b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a b/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 b/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-10-692a197bd688b48f762e72978f54aa32 b/sql/hive/src/test/resources/golden/merge4-10-692a197bd688b48f762e72978f54aa32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b b/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b new file mode 100644 index 000000000000..5d2cddc42f27 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b @@ -0,0 +1,1500 @@ +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 12 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 12 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 12 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 12 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 12 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 12 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 12 +12 val_12 2010-08-15 12 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 12 +15 val_15 2010-08-15 12 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 12 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 12 +18 val_18 2010-08-15 12 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 12 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 12 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 12 +24 val_24 2010-08-15 12 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 12 +26 val_26 2010-08-15 12 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 12 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 12 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 12 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 12 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 12 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 12 +37 val_37 2010-08-15 12 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 12 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 12 +42 val_42 2010-08-15 12 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 12 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 12 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 12 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 12 +51 val_51 2010-08-15 12 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 12 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 12 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 12 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 12 +58 val_58 2010-08-15 12 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 12 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 12 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 12 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 12 +67 val_67 2010-08-15 12 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 12 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 12 +72 val_72 2010-08-15 12 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 12 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 12 +76 val_76 2010-08-15 12 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 12 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 12 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 12 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 12 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 12 +83 val_83 2010-08-15 12 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 12 +84 val_84 2010-08-15 12 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 12 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 12 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 12 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 12 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 12 +95 val_95 2010-08-15 12 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 12 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 12 +97 val_97 2010-08-15 12 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 12 +98 val_98 2010-08-15 12 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 12 +100 val_100 2010-08-15 12 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 12 +103 val_103 2010-08-15 12 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 12 +104 val_104 2010-08-15 12 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 12 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 12 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 12 +113 val_113 2010-08-15 12 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 12 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 12 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 12 +118 val_118 2010-08-15 12 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 12 +120 val_120 2010-08-15 12 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 12 +125 val_125 2010-08-15 12 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 12 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 12 +129 val_129 2010-08-15 12 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 12 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 12 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 12 +134 val_134 2010-08-15 12 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 12 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 12 +137 val_137 2010-08-15 12 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 12 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 12 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 12 +146 val_146 2010-08-15 12 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 12 +149 val_149 2010-08-15 12 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 12 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 12 +152 val_152 2010-08-15 12 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 12 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 12 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 12 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 12 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 12 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 12 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 12 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 12 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 12 +164 val_164 2010-08-15 12 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 12 +165 val_165 2010-08-15 12 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 12 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 12 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 12 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 12 +172 val_172 2010-08-15 12 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 12 +174 val_174 2010-08-15 12 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 12 +175 val_175 2010-08-15 12 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 12 +176 val_176 2010-08-15 12 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 12 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 12 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 12 +179 val_179 2010-08-15 12 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 12 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 12 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 12 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 12 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 12 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 12 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 12 +191 val_191 2010-08-15 12 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 12 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 12 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 12 +195 val_195 2010-08-15 12 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 12 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 12 +197 val_197 2010-08-15 12 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 12 +200 val_200 2010-08-15 12 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 12 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 12 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 12 +203 val_203 2010-08-15 12 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 12 +205 val_205 2010-08-15 12 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 12 +207 val_207 2010-08-15 12 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 12 +209 val_209 2010-08-15 12 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 12 +213 val_213 2010-08-15 12 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 12 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 12 +216 val_216 2010-08-15 12 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 12 +217 val_217 2010-08-15 12 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 12 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 12 +219 val_219 2010-08-15 12 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 12 +221 val_221 2010-08-15 12 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 12 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 12 +223 val_223 2010-08-15 12 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 12 +224 val_224 2010-08-15 12 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 12 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 12 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 12 +229 val_229 2010-08-15 12 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 12 +233 val_233 2010-08-15 12 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 12 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 12 +237 val_237 2010-08-15 12 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 12 +238 val_238 2010-08-15 12 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 12 +239 val_239 2010-08-15 12 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 12 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 12 +242 val_242 2010-08-15 12 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 12 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 12 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 12 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 12 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 12 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 12 +255 val_255 2010-08-15 12 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 12 +256 val_256 2010-08-15 12 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 12 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 12 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 12 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 12 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 12 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 12 +265 val_265 2010-08-15 12 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 12 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 12 +272 val_272 2010-08-15 12 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 12 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 12 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 12 +278 val_278 2010-08-15 12 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 12 +280 val_280 2010-08-15 12 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 12 +281 val_281 2010-08-15 12 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 12 +282 val_282 2010-08-15 12 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 12 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 12 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 12 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 12 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 12 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 12 +288 val_288 2010-08-15 12 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 12 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 12 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 12 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 12 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 12 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 12 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 12 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 12 +307 val_307 2010-08-15 12 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 12 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 12 +309 val_309 2010-08-15 12 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 12 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 12 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 12 +317 val_317 2010-08-15 12 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 12 +321 val_321 2010-08-15 12 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 12 +322 val_322 2010-08-15 12 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 12 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 12 +325 val_325 2010-08-15 12 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 12 +331 val_331 2010-08-15 12 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 12 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 12 +333 val_333 2010-08-15 12 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 12 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 12 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 12 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 12 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 12 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 12 +342 val_342 2010-08-15 12 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 12 +344 val_344 2010-08-15 12 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 12 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 12 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 12 +353 val_353 2010-08-15 12 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 12 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 12 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 12 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 12 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 12 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 12 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 12 +367 val_367 2010-08-15 12 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 12 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 12 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 12 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 12 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 12 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 12 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 12 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 12 +382 val_382 2010-08-15 12 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 12 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 12 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 12 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 12 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 12 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 12 +395 val_395 2010-08-15 12 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 12 +397 val_397 2010-08-15 12 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 12 +399 val_399 2010-08-15 12 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 12 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 12 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 12 +404 val_404 2010-08-15 12 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 12 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 12 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 12 +413 val_413 2010-08-15 12 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 12 +414 val_414 2010-08-15 12 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 12 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 12 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 12 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 12 +424 val_424 2010-08-15 12 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 12 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 12 +429 val_429 2010-08-15 12 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 12 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 12 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 12 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 12 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 12 +439 val_439 2010-08-15 12 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 12 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 12 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 12 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 12 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 12 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 12 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 12 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 12 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 12 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 12 +458 val_458 2010-08-15 12 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 12 +459 val_459 2010-08-15 12 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 12 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 12 +462 val_462 2010-08-15 12 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 12 +463 val_463 2010-08-15 12 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 12 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 12 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 12 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 12 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 12 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 12 +478 val_478 2010-08-15 12 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 12 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 12 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 12 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 12 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 12 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 12 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 12 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 12 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 12 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 12 +492 val_492 2010-08-15 12 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 12 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 12 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 12 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 12 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 12 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 diff --git a/sql/hive/src/test/resources/golden/merge4-12-62541540a18d68a3cb8497a741061d11 b/sql/hive/src/test/resources/golden/merge4-12-62541540a18d68a3cb8497a741061d11 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-13-ed1103f06609365b40e78d13c654cc71 b/sql/hive/src/test/resources/golden/merge4-13-ed1103f06609365b40e78d13c654cc71 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 b/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 new file mode 100644 index 000000000000..30becc42d7b5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 @@ -0,0 +1,3 @@ +ds=2010-08-15/hr=11 +ds=2010-08-15/hr=12 +ds=2010-08-15/hr=file, diff --git a/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a b/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a new file mode 100644 index 000000000000..4c867a5deff0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a @@ -0,0 +1 @@ +1 1 2010-08-15 file, diff --git a/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a b/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 b/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe b/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-5-3d24d877366c42030f6d9a596665720d b/sql/hive/src/test/resources/golden/merge4-5-3d24d877366c42030f6d9a596665720d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-6-b3a76420183795720ab3a384046e5af b/sql/hive/src/test/resources/golden/merge4-6-b3a76420183795720ab3a384046e5af new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-7-631a45828eae3f5f562d992efe4cd56d b/sql/hive/src/test/resources/golden/merge4-7-631a45828eae3f5f562d992efe4cd56d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b b/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b new file mode 100644 index 000000000000..aa972caa5665 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b @@ -0,0 +1,1000 @@ +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 12 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 12 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 12 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 12 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 12 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 12 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 12 +12 val_12 2010-08-15 12 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 12 +15 val_15 2010-08-15 12 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 12 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 12 +18 val_18 2010-08-15 12 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 12 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 12 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 12 +24 val_24 2010-08-15 12 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 12 +26 val_26 2010-08-15 12 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 12 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 12 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 12 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 12 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 12 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 12 +37 val_37 2010-08-15 12 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 12 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 12 +42 val_42 2010-08-15 12 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 12 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 12 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 12 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 12 +51 val_51 2010-08-15 12 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 12 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 12 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 12 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 12 +58 val_58 2010-08-15 12 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 12 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 12 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 12 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 12 +67 val_67 2010-08-15 12 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 12 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 12 +72 val_72 2010-08-15 12 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 12 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 12 +76 val_76 2010-08-15 12 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 12 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 12 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 12 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 12 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 12 +83 val_83 2010-08-15 12 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 12 +84 val_84 2010-08-15 12 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 12 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 12 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 12 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 12 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 12 +95 val_95 2010-08-15 12 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 12 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 12 +97 val_97 2010-08-15 12 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 12 +98 val_98 2010-08-15 12 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 12 +100 val_100 2010-08-15 12 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 12 +103 val_103 2010-08-15 12 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 12 +104 val_104 2010-08-15 12 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 12 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 12 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 12 +113 val_113 2010-08-15 12 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 12 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 12 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 12 +118 val_118 2010-08-15 12 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 12 +120 val_120 2010-08-15 12 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 12 +125 val_125 2010-08-15 12 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 12 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 12 +129 val_129 2010-08-15 12 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 12 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 12 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 12 +134 val_134 2010-08-15 12 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 12 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 12 +137 val_137 2010-08-15 12 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 12 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 12 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 12 +146 val_146 2010-08-15 12 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 12 +149 val_149 2010-08-15 12 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 12 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 12 +152 val_152 2010-08-15 12 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 12 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 12 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 12 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 12 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 12 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 12 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 12 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 12 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 12 +164 val_164 2010-08-15 12 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 12 +165 val_165 2010-08-15 12 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 12 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 12 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 12 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 12 +172 val_172 2010-08-15 12 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 12 +174 val_174 2010-08-15 12 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 12 +175 val_175 2010-08-15 12 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 12 +176 val_176 2010-08-15 12 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 12 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 12 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 12 +179 val_179 2010-08-15 12 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 12 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 12 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 12 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 12 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 12 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 12 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 12 +191 val_191 2010-08-15 12 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 12 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 12 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 12 +195 val_195 2010-08-15 12 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 12 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 12 +197 val_197 2010-08-15 12 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 12 +200 val_200 2010-08-15 12 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 12 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 12 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 12 +203 val_203 2010-08-15 12 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 12 +205 val_205 2010-08-15 12 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 12 +207 val_207 2010-08-15 12 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 12 +209 val_209 2010-08-15 12 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 12 +213 val_213 2010-08-15 12 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 12 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 12 +216 val_216 2010-08-15 12 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 12 +217 val_217 2010-08-15 12 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 12 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 12 +219 val_219 2010-08-15 12 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 12 +221 val_221 2010-08-15 12 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 12 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 12 +223 val_223 2010-08-15 12 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 12 +224 val_224 2010-08-15 12 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 12 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 12 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 12 +229 val_229 2010-08-15 12 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 12 +233 val_233 2010-08-15 12 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 12 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 12 +237 val_237 2010-08-15 12 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 12 +238 val_238 2010-08-15 12 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 12 +239 val_239 2010-08-15 12 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 12 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 12 +242 val_242 2010-08-15 12 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 12 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 12 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 12 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 12 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 12 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 12 +255 val_255 2010-08-15 12 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 12 +256 val_256 2010-08-15 12 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 12 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 12 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 12 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 12 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 12 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 12 +265 val_265 2010-08-15 12 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 12 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 12 +272 val_272 2010-08-15 12 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 12 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 12 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 12 +278 val_278 2010-08-15 12 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 12 +280 val_280 2010-08-15 12 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 12 +281 val_281 2010-08-15 12 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 12 +282 val_282 2010-08-15 12 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 12 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 12 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 12 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 12 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 12 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 12 +288 val_288 2010-08-15 12 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 12 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 12 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 12 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 12 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 12 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 12 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 12 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 12 +307 val_307 2010-08-15 12 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 12 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 12 +309 val_309 2010-08-15 12 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 12 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 12 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 12 +317 val_317 2010-08-15 12 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 12 +321 val_321 2010-08-15 12 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 12 +322 val_322 2010-08-15 12 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 12 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 12 +325 val_325 2010-08-15 12 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 12 +331 val_331 2010-08-15 12 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 12 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 12 +333 val_333 2010-08-15 12 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 12 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 12 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 12 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 12 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 12 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 12 +342 val_342 2010-08-15 12 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 12 +344 val_344 2010-08-15 12 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 12 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 12 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 12 +353 val_353 2010-08-15 12 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 12 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 12 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 12 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 12 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 12 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 12 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 12 +367 val_367 2010-08-15 12 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 12 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 12 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 12 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 12 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 12 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 12 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 12 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 12 +382 val_382 2010-08-15 12 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 12 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 12 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 12 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 12 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 12 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 12 +395 val_395 2010-08-15 12 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 12 +397 val_397 2010-08-15 12 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 12 +399 val_399 2010-08-15 12 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 12 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 12 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 12 +404 val_404 2010-08-15 12 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 12 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 12 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 12 +413 val_413 2010-08-15 12 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 12 +414 val_414 2010-08-15 12 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 12 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 12 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 12 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 12 +424 val_424 2010-08-15 12 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 12 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 12 +429 val_429 2010-08-15 12 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 12 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 12 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 12 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 12 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 12 +439 val_439 2010-08-15 12 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 12 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 12 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 12 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 12 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 12 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 12 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 12 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 12 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 12 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 12 +458 val_458 2010-08-15 12 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 12 +459 val_459 2010-08-15 12 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 12 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 12 +462 val_462 2010-08-15 12 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 12 +463 val_463 2010-08-15 12 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 12 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 12 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 12 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 12 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 12 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 12 +478 val_478 2010-08-15 12 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 12 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 12 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 12 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 12 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 12 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 12 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 12 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 12 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 12 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 12 +492 val_492 2010-08-15 12 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 12 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 12 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 12 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 12 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 12 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 diff --git a/sql/hive/src/test/resources/golden/merge4-9-ad3dc168c8b6f048717e39ab16b0a319 b/sql/hive/src/test/resources/golden/merge4-9-ad3dc168c8b6f048717e39ab16b0a319 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a b/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a new file mode 100644 index 000000000000..390d344ecb9d --- /dev/null +++ b/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a @@ -0,0 +1 @@ +1 1 -1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e b/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b b/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b new file mode 100644 index 000000000000..e74deff51c9b --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b @@ -0,0 +1,10 @@ +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +NULL 1 +NULL NULL +1.0 NULL +1.0 1 +1.0 1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 b/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b new file mode 100644 index 000000000000..00ebb521970d --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b @@ -0,0 +1,10 @@ +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +fooNull 1 +fooNull fooNull +1.0 fooNull +1.0 1 +1.0 1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 b/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 b/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d b/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 new file mode 100644 index 000000000000..b00bcb362453 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 @@ -0,0 +1,6 @@ +a string +b string +c string +d string + +Detailed Table Information Table(tableName:base_tab, dbName:default, owner:animal, createTime:1423973915, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null), FieldSchema(name:c, type:string, comment:null), FieldSchema(name:d, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/base_tab, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973915, COLUMN_STATS_ACCURATE=true, totalSize=130, numRows=0, rawDataSize=0}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE) diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd b/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 b/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 b/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 new file mode 100644 index 000000000000..264c973ff7af --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 @@ -0,0 +1,4 @@ +a string +b string + +Detailed Table Information Table(tableName:null_tab3, dbName:default, owner:animal, createTime:1423973928, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.null.format=fooNull, serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973928, COLUMN_STATS_ACCURATE=true, totalSize=80, numRows=10, rawDataSize=70}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE) diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 new file mode 100644 index 000000000000..881917bcf1c6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 @@ -0,0 +1,18 @@ +CREATE TABLE `null_tab3`( + `a` string, + `b` string) +ROW FORMAT DELIMITED + NULL DEFINED AS 'fooNull' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3' +TBLPROPERTIES ( + 'numFiles'='1', + 'transient_lastDdlTime'='1423973928', + 'COLUMN_STATS_ACCURATE'='true', + 'totalSize'='80', + 'numRows'='10', + 'rawDataSize'='70') diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 new file mode 100644 index 000000000000..3a2e3f4984a0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 @@ -0,0 +1 @@ +-1 diff --git a/sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 b/sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 similarity index 100% rename from sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 rename to sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 diff --git a/sql/hive/src/test/resources/golden/semicolon-0-f104632770dc96b81f00ccdac51fe5a8 b/sql/hive/src/test/resources/golden/semicolon-0-f104632770dc96b81f00ccdac51fe5a8 new file mode 100644 index 000000000000..1b79f38e25b2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semicolon-0-f104632770dc96b81f00ccdac51fe5a8 @@ -0,0 +1 @@ +500 diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 index 501bb6ab32f2..7bb2c0ab4398 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` smallint, `value` float) COMMENT 'temporary table' diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 index 90f8415a1c6b..3cc1a57ee3a4 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_feng.tmp_showcrt`( +CREATE TABLE `tmp_feng.tmp_showcrt`( `key` string, `value` int) ROW FORMAT SERDE diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 index 4ee22e523031..b51c71a71f91 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 index 6fda2570b53f..29189e1d860a 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 index 3049cd6243ad..1b283db3e774 100644 --- a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 +++ b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 @@ -17,6 +17,7 @@ ^ abs acos +add_months and array array_contains @@ -29,6 +30,7 @@ base64 between bin case +cbrt ceil ceiling coalesce @@ -47,7 +49,11 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user date_add +date_format date_sub datediff day @@ -65,6 +71,7 @@ ewah_bitmap_empty ewah_bitmap_or exp explode +factorial field find_in_set first_value @@ -73,6 +80,7 @@ format_number from_unixtime from_utc_timestamp get_json_object +greatest hash hex histogram_numeric @@ -81,6 +89,7 @@ if in in_file index +initcap inline instr isnotnull @@ -88,10 +97,13 @@ isnull java_method json_tuple lag +last_day last_value lcase lead +least length +levenshtein like ln locate @@ -109,11 +121,15 @@ max min minute month +months_between named_struct negative +next_day ngrams noop +noopstreaming noopwithmap +noopwithmapstreaming not ntile nvl @@ -147,10 +163,14 @@ rpad rtrim second sentences +shiftleft +shiftright +shiftrightunsigned sign sin size sort_array +soundex space split sqrt @@ -170,6 +190,7 @@ to_unix_timestamp to_utc_timestamp translate trim +trunc ucase unbase64 unhex diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae index 0f6cc6f44f1f..fdf701f96280 100644 --- a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae +++ b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae @@ -1 +1 @@ -Table tmpfoo does not have property: bar +Table default.tmpfoo does not have property: bar diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 index 27de46fdf22a..84a31a5a6970 100644 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -1 +1 @@ --0.0010000000000000009 +-0.001 diff --git a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 index 3c91e138d7bd..d8ec084f0b2b 100644 --- a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 +++ b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 @@ -1,5 +1,5 @@ date_add(start_date, num_days) - Returns the date that is num_days after start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_add('2009-30-07', 1) FROM src LIMIT 1; - '2009-31-07' + > SELECT date_add('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-31' diff --git a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 index 29d663f35c58..169c50003625 100644 --- a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 +++ b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 @@ -1,5 +1,5 @@ date_sub(start_date, num_days) - Returns the date that is num_days before start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_sub('2009-30-07', 1) FROM src LIMIT 1; - '2009-29-07' + > SELECT date_sub('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-29' diff --git a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 index 7ccaee7ad3bd..42197f7ad3e5 100644 --- a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 +++ b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 @@ -1,5 +1,5 @@ datediff(date1, date2) - Returns the number of days between date1 and date2 date1 and date2 are strings in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. The time parts are ignored.If date1 is earlier than date2, the result is negative. Example: - > SELECT datediff('2009-30-07', '2009-31-07') FROM src LIMIT 1; + > SELECT datediff('2009-07-30', '2009-07-31') FROM src LIMIT 1; 1 diff --git a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 index d4017178b4e6..09703d10eab7 100644 --- a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 +++ b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 @@ -1 +1 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 index 6135aafa5086..7c0ec1dc3be5 100644 --- a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 +++ b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 @@ -1,6 +1,9 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: dayofmonth -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT day('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. A day-time interval valueExample: + > SELECT day('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 index 47a7018d9d5a..c37eb0ec2e96 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 @@ -1 +1 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 index d9490e20a3b6..9e931f649914 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 @@ -1,6 +1,9 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: day -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT dayofmonth('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. A day-time interval valueExample: + > SELECT dayofmonth('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 new file mode 100644 index 000000000000..cd35e5b290db --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 @@ -0,0 +1 @@ +reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 new file mode 100644 index 000000000000..48ef97292ab6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 @@ -0,0 +1,3 @@ +reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection +Use this UDF to call Java methods by matching the argument signature + diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca b/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 new file mode 100644 index 000000000000..176ea0358d7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 @@ -0,0 +1,5 @@ +238 -18 238 238 238 238.0 238.0 238 val_238 val_238_concat false true false false false val_238 -1 -1 VALUE_238 al_238 al_2 VAL_238 val_238 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +86 86 86 86 86 86.0 86.0 86 val_86 val_86_concat true true true true true val_86 -1 -1 VALUE_86 al_86 al_8 VAL_86 val_86 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +311 55 311 311 311 311.0 311.0 311 val_311 val_311_concat false true false false false val_311 5 6 VALUE_311 al_311 al_3 VAL_311 val_311 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +27 27 27 27 27 27.0 27.0 27 val_27 val_27_concat false true false false false val_27 -1 -1 VALUE_27 al_27 al_2 VAL_27 val_27 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +165 -91 165 165 165 165.0 165.0 165 val_165 val_165_concat false true false false false val_165 4 4 VALUE_165 al_165 al_1 VAL_165 val_165 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 diff --git a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 index d54ebfbd6fb1..a529b107ff21 100644 --- a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 +++ b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 @@ -1,2 +1,2 @@ std(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, stddev +Synonyms: stddev, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d index 5f674788180e..ac3176a38254 100644 --- a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d +++ b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d @@ -1,2 +1,2 @@ stddev(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, std +Synonyms: std, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 new file mode 100644 index 000000000000..44b2a42cc26c --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 @@ -0,0 +1 @@ +unhex(str) - Converts hexadecimal argument to binary diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 new file mode 100644 index 000000000000..97af3b812a42 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 @@ -0,0 +1,14 @@ +unhex(str) - Converts hexadecimal argument to binary +Performs the inverse operation of HEX(str). That is, it interprets +each pair of hexadecimal digits in the argument as a number and +converts it to the byte representation of the number. The +resulting characters are returned as a binary string. + +Example: +> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1; +'MySQL' + +The characters in the argument string must be legal hexadecimal +digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters +any nonhexadecimal digits in the argument, it returns NULL. Also, +if there are an odd number of characters a leading 0 is appended. diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e new file mode 100644 index 000000000000..b4a6f2b69222 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e @@ -0,0 +1 @@ +MySQL 1267 a -4 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 new file mode 100644 index 000000000000..3a67adaf0a9a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 @@ -0,0 +1 @@ +NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c b/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 b/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 b/sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 rename to sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 1. testWindowing-0-327a8cd39fe30255ff492ee86f660522 b/sql/hive/src/test/resources/golden/windowing.q -- 1. testWindowing-0-327a8cd39fe30255ff492ee86f660522 new file mode 100644 index 000000000000..850c41c8115d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 1. testWindowing-0-327a8cd39fe30255ff492ee86f660522 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 2346.3 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 4100.06 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 5702.650000000001 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 8749.730000000001 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 3491.38 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 5523.360000000001 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 7222.02 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 8923.62 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 2861.95 +Manufacturer#3 almond antique metallic orange dim 19 3 3 4272.34 +Manufacturer#3 almond antique misty red olive 1 4 4 6195.32 +Manufacturer#3 almond antique olive coral navajo 45 5 5 7532.61 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 2 2 2996.09 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 4202.35 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 6047.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 7337.620000000001 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 2 2 3401.3500000000004 +Manufacturer#5 almond antique sky peru orange 2 3 3 5190.08 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 6208.18 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 10. testHavingWithWindowingCondRankNoGBY-0-fef4bf638d52a9a601845347010602fd b/sql/hive/src/test/resources/golden/windowing.q -- 10. testHavingWithWindowingCondRankNoGBY-0-fef4bf638d52a9a601845347010602fd new file mode 100644 index 000000000000..850c41c8115d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 10. testHavingWithWindowingCondRankNoGBY-0-fef4bf638d52a9a601845347010602fd @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 2346.3 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 4100.06 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 5702.650000000001 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 8749.730000000001 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 3491.38 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 5523.360000000001 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 7222.02 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 8923.62 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 2861.95 +Manufacturer#3 almond antique metallic orange dim 19 3 3 4272.34 +Manufacturer#3 almond antique misty red olive 1 4 4 6195.32 +Manufacturer#3 almond antique olive coral navajo 45 5 5 7532.61 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 2 2 2996.09 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 4202.35 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 6047.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 7337.620000000001 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 2 2 3401.3500000000004 +Manufacturer#5 almond antique sky peru orange 2 3 3 5190.08 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 6208.18 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 11. testFirstLast-0-86bb9c97d92fdcd941bcb5143513e2e6 b/sql/hive/src/test/resources/golden/windowing.q -- 11. testFirstLast-0-86bb9c97d92fdcd941bcb5143513e2e6 new file mode 100644 index 000000000000..921679cdcf56 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 11. testFirstLast-0-86bb9c97d92fdcd941bcb5143513e2e6 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 2 2 34 +Manufacturer#1 almond antique burnished rose metallic 2 2 2 6 +Manufacturer#1 almond antique chartreuse lavender yellow 34 34 2 28 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 6 2 42 +Manufacturer#1 almond aquamarine burnished black steel 28 28 34 42 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 42 6 42 +Manufacturer#2 almond antique violet chocolate turquoise 14 14 14 2 +Manufacturer#2 almond antique violet turquoise frosted 40 40 14 25 +Manufacturer#2 almond aquamarine midnight light salmon 2 2 14 18 +Manufacturer#2 almond aquamarine rose maroon antique 25 25 40 18 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 18 2 18 +Manufacturer#3 almond antique chartreuse khaki white 17 17 17 19 +Manufacturer#3 almond antique forest lavender goldenrod 14 14 17 1 +Manufacturer#3 almond antique metallic orange dim 19 19 17 45 +Manufacturer#3 almond antique misty red olive 1 1 14 45 +Manufacturer#3 almond antique olive coral navajo 45 45 19 45 +Manufacturer#4 almond antique gainsboro frosted violet 10 10 10 27 +Manufacturer#4 almond antique violet mint lemon 39 39 10 7 +Manufacturer#4 almond aquamarine floral ivory bisque 27 27 10 12 +Manufacturer#4 almond aquamarine yellow dodger mint 7 7 39 12 +Manufacturer#4 almond azure aquamarine papaya violet 12 12 27 12 +Manufacturer#5 almond antique blue firebrick mint 31 31 31 2 +Manufacturer#5 almond antique medium spring khaki 6 6 31 46 +Manufacturer#5 almond antique sky peru orange 2 2 31 23 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 46 6 23 +Manufacturer#5 almond azure blanched chiffon midnight 23 23 2 23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 12. testFirstLastWithWhere-0-84345a9f685ba63b87caa4bb16b122b5 b/sql/hive/src/test/resources/golden/windowing.q -- 12. testFirstLastWithWhere-0-84345a9f685ba63b87caa4bb16b122b5 new file mode 100644 index 000000000000..09e30c7c5734 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 12. testFirstLastWithWhere-0-84345a9f685ba63b87caa4bb16b122b5 @@ -0,0 +1,5 @@ +Manufacturer#3 almond antique chartreuse khaki white 17 1 17 17 19 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 14 17 1 +Manufacturer#3 almond antique metallic orange dim 19 3 19 17 45 +Manufacturer#3 almond antique misty red olive 1 4 1 14 45 +Manufacturer#3 almond antique olive coral navajo 45 5 45 19 45 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 13. testSumWindow-0-6cfc8840d3a4469b0fe11d63182cb59f b/sql/hive/src/test/resources/golden/windowing.q -- 13. testSumWindow-0-6cfc8840d3a4469b0fe11d63182cb59f new file mode 100644 index 000000000000..01ee88ff2330 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 13. testSumWindow-0-6cfc8840d3a4469b0fe11d63182cb59f @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 38 2 +Manufacturer#1 almond antique burnished rose metallic 2 44 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 72 34 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 112 6 +Manufacturer#1 almond aquamarine burnished black steel 28 110 28 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 76 42 +Manufacturer#2 almond antique violet chocolate turquoise 14 56 14 +Manufacturer#2 almond antique violet turquoise frosted 40 81 40 +Manufacturer#2 almond aquamarine midnight light salmon 2 99 2 +Manufacturer#2 almond aquamarine rose maroon antique 25 85 25 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 45 18 +Manufacturer#3 almond antique chartreuse khaki white 17 50 17 +Manufacturer#3 almond antique forest lavender goldenrod 14 51 14 +Manufacturer#3 almond antique metallic orange dim 19 96 19 +Manufacturer#3 almond antique misty red olive 1 79 1 +Manufacturer#3 almond antique olive coral navajo 45 65 45 +Manufacturer#4 almond antique gainsboro frosted violet 10 76 10 +Manufacturer#4 almond antique violet mint lemon 39 83 39 +Manufacturer#4 almond aquamarine floral ivory bisque 27 95 27 +Manufacturer#4 almond aquamarine yellow dodger mint 7 85 7 +Manufacturer#4 almond azure aquamarine papaya violet 12 46 12 +Manufacturer#5 almond antique blue firebrick mint 31 39 31 +Manufacturer#5 almond antique medium spring khaki 6 85 6 +Manufacturer#5 almond antique sky peru orange 2 108 2 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 77 46 +Manufacturer#5 almond azure blanched chiffon midnight 23 71 23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 14. testNoSortClause-0-88d96a526d3cae6ed8168c5b228974d1 b/sql/hive/src/test/resources/golden/windowing.q -- 14. testNoSortClause-0-88d96a526d3cae6ed8168c5b228974d1 new file mode 100644 index 000000000000..c78eb640c9c2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 14. testNoSortClause-0-88d96a526d3cae6ed8168c5b228974d1 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 +Manufacturer#3 almond antique metallic orange dim 19 3 3 +Manufacturer#3 almond antique misty red olive 1 4 4 +Manufacturer#3 almond antique olive coral navajo 45 5 5 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 +Manufacturer#4 almond antique violet mint lemon 39 2 2 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 +Manufacturer#5 almond antique medium spring khaki 6 2 2 +Manufacturer#5 almond antique sky peru orange 2 3 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 15. testExpressions-0-11f6c13cf2710ce7054654cca136e73e b/sql/hive/src/test/resources/golden/windowing.q -- 15. testExpressions-0-11f6c13cf2710ce7054654cca136e73e new file mode 100644 index 000000000000..050138ccf04c --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 15. testExpressions-0-11f6c13cf2710ce7054654cca136e73e @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 0.3333333333333333 0.0 1 2 2.0 0.0 2 2 2 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 0.3333333333333333 0.0 1 2 2.0 0.0 2 2 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 0.5 0.4 2 3 12.666666666666666 15.084944665313014 2 34 2 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 0.6666666666666666 0.6 2 4 11.0 13.379088160259652 2 6 2 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 0.8333333333333334 0.8 3 5 14.4 13.763720427268202 2 28 34 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 1.0 1.0 3 6 19.0 16.237815945091466 2 42 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 0.2 0.0 1 1 14.0 0.0 4 14 14 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 0.4 0.25 1 2 27.0 13.0 4 40 14 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 0.6 0.5 2 3 18.666666666666668 15.86050300449376 4 2 14 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 0.8 0.75 2 4 20.25 14.00669482783144 4 25 40 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 1.0 1.0 3 5 19.8 12.560254774486067 4 18 2 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 0.2 0.0 1 1 17.0 0.0 2 17 17 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 0.4 0.25 1 2 15.5 1.5 2 14 17 +Manufacturer#3 almond antique metallic orange dim 19 3 3 0.6 0.5 2 3 16.666666666666668 2.0548046676563256 2 19 17 +Manufacturer#3 almond antique misty red olive 1 4 4 0.8 0.75 2 4 12.75 7.013380069552769 2 1 14 +Manufacturer#3 almond antique olive coral navajo 45 5 5 1.0 1.0 3 5 19.2 14.344336861632886 2 45 19 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 0.2 0.0 1 1 10.0 0.0 0 10 10 +Manufacturer#4 almond antique violet mint lemon 39 2 2 0.4 0.25 1 2 24.5 14.5 0 39 10 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 0.6 0.5 2 3 25.333333333333332 11.897712198383164 0 27 10 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 0.8 0.75 2 4 20.75 13.007209539328564 0 7 39 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 1.0 1.0 3 5 19.0 12.149074038789951 0 12 27 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 0.2 0.0 1 1 31.0 0.0 1 31 31 +Manufacturer#5 almond antique medium spring khaki 6 2 2 0.4 0.25 1 2 18.5 12.5 1 6 31 +Manufacturer#5 almond antique sky peru orange 2 3 3 0.6 0.5 2 3 13.0 12.832251036613439 1 2 31 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 0.8 0.75 2 4 21.25 18.102140757380052 1 46 6 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 1.0 1.0 3 5 21.6 16.206171663906314 1 23 2 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 16. testMultipleWindows-0-efd1476255eeb1b1961149144f574b7a b/sql/hive/src/test/resources/golden/windowing.q -- 16. testMultipleWindows-0-efd1476255eeb1b1961149144f574b7a new file mode 100644 index 000000000000..c10888852b50 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 16. testMultipleWindows-0-efd1476255eeb1b1961149144f574b7a @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 0.3333333333333333 4 4 2 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 0.3333333333333333 4 4 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 0.5 38 34 2 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 0.6666666666666666 44 10 2 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 0.8333333333333334 72 28 34 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 1.0 114 42 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 0.2 14 14 14 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 0.4 54 40 14 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 0.6 56 2 14 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 0.8 81 25 40 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 1.0 99 32 2 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 0.2 17 31 17 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 0.4 31 14 17 +Manufacturer#3 almond antique metallic orange dim 19 3 3 0.6 50 50 17 +Manufacturer#3 almond antique misty red olive 1 4 4 0.8 51 1 14 +Manufacturer#3 almond antique olive coral navajo 45 5 5 1.0 96 45 19 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 0.2 10 17 10 +Manufacturer#4 almond antique violet mint lemon 39 2 2 0.4 49 39 10 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 0.6 76 27 10 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 0.8 83 7 39 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 1.0 95 29 27 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 0.2 31 31 31 +Manufacturer#5 almond antique medium spring khaki 6 2 2 0.4 37 8 31 +Manufacturer#5 almond antique sky peru orange 2 3 3 0.6 39 2 31 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 0.8 85 46 6 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 1.0 108 23 2 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 17. testCountStar-0-1b1fc185c8fddf68e58e92f29052ab2d b/sql/hive/src/test/resources/golden/windowing.q -- 17. testCountStar-0-1b1fc185c8fddf68e58e92f29052ab2d new file mode 100644 index 000000000000..b1309a497d68 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 17. testCountStar-0-1b1fc185c8fddf68e58e92f29052ab2d @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 2 2 2 +Manufacturer#1 almond antique burnished rose metallic 2 2 2 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 3 2 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 4 2 +Manufacturer#1 almond aquamarine burnished black steel 28 5 5 34 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 6 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 14 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 14 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 14 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 40 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 2 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 17 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 17 +Manufacturer#3 almond antique metallic orange dim 19 3 3 17 +Manufacturer#3 almond antique misty red olive 1 4 4 14 +Manufacturer#3 almond antique olive coral navajo 45 5 5 19 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 10 +Manufacturer#4 almond antique violet mint lemon 39 2 2 10 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 10 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 39 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 27 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 31 +Manufacturer#5 almond antique medium spring khaki 6 2 2 31 +Manufacturer#5 almond antique sky peru orange 2 3 3 31 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 6 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 2 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 18. testUDAFs-0-6974e5959e41a661e09db18547fef58a b/sql/hive/src/test/resources/golden/windowing.q -- 18. testUDAFs-0-6974e5959e41a661e09db18547fef58a new file mode 100644 index 000000000000..52d2ee8d0cd3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 18. testUDAFs-0-6974e5959e41a661e09db18547fef58a @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 4100.06 1173.15 1753.76 1366.6866666666667 +Manufacturer#1 almond antique burnished rose metallic 2 5702.650000000001 1173.15 1753.76 1425.6625000000001 +Manufacturer#1 almond antique chartreuse lavender yellow 34 7117.070000000001 1173.15 1753.76 1423.4140000000002 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 7576.58 1173.15 1753.76 1515.316 +Manufacturer#1 almond aquamarine burnished black steel 28 6403.43 1414.42 1753.76 1600.8575 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 4649.67 1414.42 1632.66 1549.89 +Manufacturer#2 almond antique violet chocolate turquoise 14 5523.360000000001 1690.68 2031.98 1841.1200000000001 +Manufacturer#2 almond antique violet turquoise frosted 40 7222.02 1690.68 2031.98 1805.505 +Manufacturer#2 almond aquamarine midnight light salmon 2 8923.62 1690.68 2031.98 1784.7240000000002 +Manufacturer#2 almond aquamarine rose maroon antique 25 7232.9400000000005 1698.66 2031.98 1808.2350000000001 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5432.24 1698.66 2031.98 1810.7466666666667 +Manufacturer#3 almond antique chartreuse khaki white 17 4272.34 1190.27 1671.68 1424.1133333333335 +Manufacturer#3 almond antique forest lavender goldenrod 14 6195.32 1190.27 1922.98 1548.83 +Manufacturer#3 almond antique metallic orange dim 19 7532.61 1190.27 1922.98 1506.522 +Manufacturer#3 almond antique misty red olive 1 5860.929999999999 1190.27 1922.98 1465.2324999999998 +Manufacturer#3 almond antique olive coral navajo 45 4670.66 1337.29 1922.98 1556.8866666666665 +Manufacturer#4 almond antique gainsboro frosted violet 10 4202.35 1206.26 1620.67 1400.7833333333335 +Manufacturer#4 almond antique violet mint lemon 39 6047.27 1206.26 1844.92 1511.8175 +Manufacturer#4 almond aquamarine floral ivory bisque 27 7337.620000000001 1206.26 1844.92 1467.5240000000001 +Manufacturer#4 almond aquamarine yellow dodger mint 7 5716.950000000001 1206.26 1844.92 1429.2375000000002 +Manufacturer#4 almond azure aquamarine papaya violet 12 4341.530000000001 1206.26 1844.92 1447.176666666667 +Manufacturer#5 almond antique blue firebrick mint 31 5190.08 1611.66 1789.69 1730.0266666666666 +Manufacturer#5 almond antique medium spring khaki 6 6208.18 1018.1 1789.69 1552.045 +Manufacturer#5 almond antique sky peru orange 2 7672.66 1018.1 1789.69 1534.532 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 5882.970000000001 1018.1 1788.73 1470.7425000000003 +Manufacturer#5 almond azure blanched chiffon midnight 23 4271.3099999999995 1018.1 1788.73 1423.7699999999998 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 19. testUDAFsWithGBY-0-67d15ee5915ac64a738fd4b60d75eb35 b/sql/hive/src/test/resources/golden/windowing.q -- 19. testUDAFsWithGBY-0-67d15ee5915ac64a738fd4b60d75eb35 new file mode 100644 index 000000000000..6461642d34a2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 19. testUDAFsWithGBY-0-67d15ee5915ac64a738fd4b60d75eb35 @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 4529.5 1173.15 1173.15 1509.8333333333333 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 5943.92 1753.76 1753.76 1485.98 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 7576.58 1602.59 1602.59 1515.316 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 6403.43 1414.42 1414.42 1600.8575 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 4649.67 1632.66 1632.66 1549.89 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 5523.360000000001 1690.68 1690.68 1841.1200000000001 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 7222.02 1800.7 1800.7 1805.505 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 8923.62 2031.98 2031.98 1784.7240000000002 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 7232.9400000000005 1698.66 1698.66 1808.2350000000001 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5432.24 1701.6 1701.6 1810.7466666666667 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 4272.34 1671.68 1671.68 1424.1133333333335 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 6195.32 1190.27 1190.27 1548.83 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 7532.61 1410.39 1410.39 1506.522 +Manufacturer#3 almond antique misty red olive 1 1922.98 5860.929999999999 1922.98 1922.98 1465.2324999999998 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 4670.66 1337.29 1337.29 1556.8866666666665 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 4202.35 1620.67 1620.67 1400.7833333333335 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 6047.27 1375.42 1375.42 1511.8175 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 7337.620000000001 1206.26 1206.26 1467.5240000000001 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 5716.950000000001 1844.92 1844.92 1429.2375000000002 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 4341.530000000001 1290.35 1290.35 1447.176666666667 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 5190.08 1789.69 1789.69 1730.0266666666666 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 6208.18 1611.66 1611.66 1552.045 +Manufacturer#5 almond antique sky peru orange 2 1788.73 7672.66 1788.73 1788.73 1534.532 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 5882.970000000001 1018.1 1018.1 1470.7425000000003 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 4271.3099999999995 1464.48 1464.48 1423.7699999999998 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-cb5618b1e626f3a9d4a030b508b5d251 b/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-cb5618b1e626f3a9d4a030b508b5d251 new file mode 100644 index 000000000000..2c30e652aa26 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-cb5618b1e626f3a9d4a030b508b5d251 @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 1 1 2 0 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 2 2 34 32 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 3 3 6 -28 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 4 4 28 22 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 5 5 42 14 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 1 1 14 0 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 2 2 40 26 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 3 3 2 -38 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 4 4 25 23 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5 5 18 -7 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 1 1 17 0 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 2 2 14 -3 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 3 3 19 5 +Manufacturer#3 almond antique misty red olive 1 1922.98 4 4 1 -18 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 5 5 45 44 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 1 1 10 0 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 2 2 39 29 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 3 3 27 -12 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 4 4 7 -20 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 5 5 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 1 1 31 0 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 2 2 6 -25 +Manufacturer#5 almond antique sky peru orange 2 1788.73 3 3 2 -4 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 4 4 46 44 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 5 5 23 -23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 new file mode 100644 index 000000000000..7e5fceeddeee --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 @@ -0,0 +1,97 @@ +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 2 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 6 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 34 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 2 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 34 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 2 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 6 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 28 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 34 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 2 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 6 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 28 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 34 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 42 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 6 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 28 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 34 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 42 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 6 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 28 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 42 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 2 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 14 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 40 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 2 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 14 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 25 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 40 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 2 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 14 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 18 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 25 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 40 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 2 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 18 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 25 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 40 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 2 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 18 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 25 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 14 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 17 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 19 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 1 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 14 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 17 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 19 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 1 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 14 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 17 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 19 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 45 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 1 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 14 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 19 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 45 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 1 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 19 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 45 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 10 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 27 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 39 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 7 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 10 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 27 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 39 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 7 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 10 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 12 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 27 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 39 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 7 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 12 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 27 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 39 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 7 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 12 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 27 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 2 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 6 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 31 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 2 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 6 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 31 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 46 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 2 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 6 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 23 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 31 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 46 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 2 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 6 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 23 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 46 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 2 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 23 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 46 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad new file mode 100644 index 000000000000..e7c39f454fb3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1753.76,"y":1.0}] 121152.0 1 +Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] 115872.0 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 [{"x":1173.15,"y":2.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] 110592.0 3 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 [{"x":1173.15,"y":1.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] 86428.0 4 +Manufacturer#1 almond aquamarine burnished black steel 28 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] 86098.0 5 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0}] 86428.0 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 [{"x":1690.68,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 1 +Manufacturer#2 almond antique violet turquoise frosted 40 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 139825.5 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 3 +Manufacturer#2 almond aquamarine rose maroon antique 25 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 169347.0 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 5 +Manufacturer#3 almond antique chartreuse khaki white 17 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0}] 90681.0 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] 65831.5 2 +Manufacturer#3 almond antique metallic orange dim 19 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] 90681.0 3 +Manufacturer#3 almond antique misty red olive 1 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] 76690.0 4 +Manufacturer#3 almond antique olive coral navajo 45 [{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] 112398.0 5 +Manufacturer#4 almond antique gainsboro frosted violet 10 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0}] 48427.0 1 +Manufacturer#4 almond antique violet mint lemon 39 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] 46844.0 2 +Manufacturer#4 almond aquamarine floral ivory bisque 27 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] 45261.0 3 +Manufacturer#4 almond aquamarine yellow dodger mint 7 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1844.92,"y":1.0}] 39309.0 4 +Manufacturer#4 almond azure aquamarine papaya violet 12 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1844.92,"y":1.0}] 33357.0 5 +Manufacturer#5 almond antique blue firebrick mint 31 [{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 155733.0 1 +Manufacturer#5 almond antique medium spring khaki 6 [{"x":1018.1,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 99201.0 2 +Manufacturer#5 almond antique sky peru orange 2 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 78486.0 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0}] 60577.5 4 +Manufacturer#5 almond azure blanched chiffon midnight 23 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1788.73,"y":1.0}] 78486.0 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 24. testLateralViews-0-dea06072f0a64fe4537fae854944ed5a b/sql/hive/src/test/resources/golden/windowing.q -- 24. testLateralViews-0-dea06072f0a64fe4537fae854944ed5a new file mode 100644 index 000000000000..dc83c9fffe93 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 24. testLateralViews-0-dea06072f0a64fe4537fae854944ed5a @@ -0,0 +1,78 @@ +Manufacturer#1 almond antique burnished rose metallic 1 2 2 +Manufacturer#1 almond antique burnished rose metallic 1 2 4 +Manufacturer#1 almond antique burnished rose metallic 2 2 6 +Manufacturer#1 almond antique burnished rose metallic 2 2 6 +Manufacturer#1 almond antique burnished rose metallic 3 2 6 +Manufacturer#1 almond antique burnished rose metallic 3 2 6 +Manufacturer#1 almond antique salmon chartreuse burlywood 1 6 10 +Manufacturer#1 almond antique salmon chartreuse burlywood 2 6 14 +Manufacturer#1 almond antique salmon chartreuse burlywood 3 6 18 +Manufacturer#1 almond aquamarine burnished black steel 1 28 40 +Manufacturer#1 almond aquamarine burnished black steel 2 28 62 +Manufacturer#1 almond aquamarine burnished black steel 3 28 84 +Manufacturer#1 almond antique chartreuse lavender yellow 1 34 90 +Manufacturer#1 almond antique chartreuse lavender yellow 2 34 96 +Manufacturer#1 almond antique chartreuse lavender yellow 3 34 102 +Manufacturer#1 almond aquamarine pink moccasin thistle 1 42 110 +Manufacturer#1 almond aquamarine pink moccasin thistle 2 42 118 +Manufacturer#1 almond aquamarine pink moccasin thistle 3 42 126 +Manufacturer#2 almond aquamarine midnight light salmon 1 2 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 2 4 +Manufacturer#2 almond aquamarine midnight light salmon 3 2 6 +Manufacturer#2 almond antique violet chocolate turquoise 1 14 18 +Manufacturer#2 almond antique violet chocolate turquoise 2 14 30 +Manufacturer#2 almond antique violet chocolate turquoise 3 14 42 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 1 18 46 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 2 18 50 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 3 18 54 +Manufacturer#2 almond aquamarine rose maroon antique 1 25 61 +Manufacturer#2 almond aquamarine rose maroon antique 2 25 68 +Manufacturer#2 almond aquamarine rose maroon antique 3 25 75 +Manufacturer#2 almond antique violet turquoise frosted 1 40 90 +Manufacturer#2 almond antique violet turquoise frosted 2 40 105 +Manufacturer#2 almond antique violet turquoise frosted 3 40 120 +Manufacturer#3 almond antique misty red olive 1 1 1 +Manufacturer#3 almond antique misty red olive 2 1 2 +Manufacturer#3 almond antique misty red olive 3 1 3 +Manufacturer#3 almond antique forest lavender goldenrod 1 14 16 +Manufacturer#3 almond antique forest lavender goldenrod 2 14 29 +Manufacturer#3 almond antique forest lavender goldenrod 3 14 42 +Manufacturer#3 almond antique chartreuse khaki white 1 17 45 +Manufacturer#3 almond antique chartreuse khaki white 2 17 48 +Manufacturer#3 almond antique chartreuse khaki white 3 17 51 +Manufacturer#3 almond antique metallic orange dim 1 19 53 +Manufacturer#3 almond antique metallic orange dim 2 19 55 +Manufacturer#3 almond antique metallic orange dim 3 19 57 +Manufacturer#3 almond antique olive coral navajo 1 45 83 +Manufacturer#3 almond antique olive coral navajo 2 45 109 +Manufacturer#3 almond antique olive coral navajo 3 45 135 +Manufacturer#4 almond aquamarine yellow dodger mint 1 7 7 +Manufacturer#4 almond aquamarine yellow dodger mint 2 7 14 +Manufacturer#4 almond aquamarine yellow dodger mint 3 7 21 +Manufacturer#4 almond antique gainsboro frosted violet 1 10 24 +Manufacturer#4 almond antique gainsboro frosted violet 2 10 27 +Manufacturer#4 almond antique gainsboro frosted violet 3 10 30 +Manufacturer#4 almond azure aquamarine papaya violet 1 12 32 +Manufacturer#4 almond azure aquamarine papaya violet 2 12 34 +Manufacturer#4 almond azure aquamarine papaya violet 3 12 36 +Manufacturer#4 almond aquamarine floral ivory bisque 1 27 51 +Manufacturer#4 almond aquamarine floral ivory bisque 2 27 66 +Manufacturer#4 almond aquamarine floral ivory bisque 3 27 81 +Manufacturer#4 almond antique violet mint lemon 1 39 93 +Manufacturer#4 almond antique violet mint lemon 2 39 105 +Manufacturer#4 almond antique violet mint lemon 3 39 117 +Manufacturer#5 almond antique sky peru orange 1 2 2 +Manufacturer#5 almond antique sky peru orange 2 2 4 +Manufacturer#5 almond antique sky peru orange 3 2 6 +Manufacturer#5 almond antique medium spring khaki 1 6 10 +Manufacturer#5 almond antique medium spring khaki 2 6 14 +Manufacturer#5 almond antique medium spring khaki 3 6 18 +Manufacturer#5 almond azure blanched chiffon midnight 1 23 35 +Manufacturer#5 almond azure blanched chiffon midnight 2 23 52 +Manufacturer#5 almond azure blanched chiffon midnight 3 23 69 +Manufacturer#5 almond antique blue firebrick mint 1 31 77 +Manufacturer#5 almond antique blue firebrick mint 2 31 85 +Manufacturer#5 almond antique blue firebrick mint 3 31 93 +Manufacturer#5 almond aquamarine dodger light gainsboro 1 46 108 +Manufacturer#5 almond aquamarine dodger light gainsboro 2 46 123 +Manufacturer#5 almond aquamarine dodger light gainsboro 3 46 138 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 26. testGroupByHavingWithSWQAndAlias-0-b996a664b06e5741c08079d5c38241bc b/sql/hive/src/test/resources/golden/windowing.q -- 26. testGroupByHavingWithSWQAndAlias-0-b996a664b06e5741c08079d5c38241bc new file mode 100644 index 000000000000..2c30e652aa26 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 26. testGroupByHavingWithSWQAndAlias-0-b996a664b06e5741c08079d5c38241bc @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 1 1 2 0 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 2 2 34 32 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 3 3 6 -28 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 4 4 28 22 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 5 5 42 14 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 1 1 14 0 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 2 2 40 26 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 3 3 2 -38 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 4 4 25 23 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5 5 18 -7 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 1 1 17 0 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 2 2 14 -3 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 3 3 19 5 +Manufacturer#3 almond antique misty red olive 1 1922.98 4 4 1 -18 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 5 5 45 44 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 1 1 10 0 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 2 2 39 29 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 3 3 27 -12 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 4 4 7 -20 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 5 5 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 1 1 31 0 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 2 2 6 -25 +Manufacturer#5 almond antique sky peru orange 2 1788.73 3 3 2 -4 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 4 4 46 44 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 5 5 23 -23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 27. testMultipleRangeWindows-0-227e080e337d734dd88ff814b3b412e4 b/sql/hive/src/test/resources/golden/windowing.q -- 27. testMultipleRangeWindows-0-227e080e337d734dd88ff814b3b412e4 new file mode 100644 index 000000000000..b2a91ba727a7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 27. testMultipleRangeWindows-0-227e080e337d734dd88ff814b3b412e4 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 4 10 +Manufacturer#1 almond antique burnished rose metallic 2 4 10 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 10 6 +Manufacturer#1 almond aquamarine burnished black steel 28 28 62 +Manufacturer#1 almond antique chartreuse lavender yellow 34 62 76 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 76 42 +Manufacturer#2 almond aquamarine midnight light salmon 2 2 2 +Manufacturer#2 almond antique violet chocolate turquoise 14 14 32 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 32 43 +Manufacturer#2 almond aquamarine rose maroon antique 25 43 25 +Manufacturer#2 almond antique violet turquoise frosted 40 40 40 +Manufacturer#3 almond antique misty red olive 1 1 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 14 50 +Manufacturer#3 almond antique chartreuse khaki white 17 31 36 +Manufacturer#3 almond antique metallic orange dim 19 50 19 +Manufacturer#3 almond antique olive coral navajo 45 45 45 +Manufacturer#4 almond aquamarine yellow dodger mint 7 7 29 +Manufacturer#4 almond antique gainsboro frosted violet 10 17 22 +Manufacturer#4 almond azure aquamarine papaya violet 12 29 12 +Manufacturer#4 almond aquamarine floral ivory bisque 27 27 27 +Manufacturer#4 almond antique violet mint lemon 39 39 39 +Manufacturer#5 almond antique sky peru orange 2 2 8 +Manufacturer#5 almond antique medium spring khaki 6 8 6 +Manufacturer#5 almond azure blanched chiffon midnight 23 23 54 +Manufacturer#5 almond antique blue firebrick mint 31 54 31 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 46 46 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 28. testPartOrderInUDAFInvoke-0-25912ae7d18c91cc09e17e57968fb5db b/sql/hive/src/test/resources/golden/windowing.q -- 28. testPartOrderInUDAFInvoke-0-25912ae7d18c91cc09e17e57968fb5db new file mode 100644 index 000000000000..5bcb0fa941d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 28. testPartOrderInUDAFInvoke-0-25912ae7d18c91cc09e17e57968fb5db @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 38 +Manufacturer#1 almond antique burnished rose metallic 2 44 +Manufacturer#1 almond antique chartreuse lavender yellow 34 72 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 112 +Manufacturer#1 almond aquamarine burnished black steel 28 110 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 76 +Manufacturer#2 almond antique violet chocolate turquoise 14 56 +Manufacturer#2 almond antique violet turquoise frosted 40 81 +Manufacturer#2 almond aquamarine midnight light salmon 2 99 +Manufacturer#2 almond aquamarine rose maroon antique 25 85 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 45 +Manufacturer#3 almond antique chartreuse khaki white 17 50 +Manufacturer#3 almond antique forest lavender goldenrod 14 51 +Manufacturer#3 almond antique metallic orange dim 19 96 +Manufacturer#3 almond antique misty red olive 1 79 +Manufacturer#3 almond antique olive coral navajo 45 65 +Manufacturer#4 almond antique gainsboro frosted violet 10 76 +Manufacturer#4 almond antique violet mint lemon 39 83 +Manufacturer#4 almond aquamarine floral ivory bisque 27 95 +Manufacturer#4 almond aquamarine yellow dodger mint 7 85 +Manufacturer#4 almond azure aquamarine papaya violet 12 46 +Manufacturer#5 almond antique blue firebrick mint 31 39 +Manufacturer#5 almond antique medium spring khaki 6 85 +Manufacturer#5 almond antique sky peru orange 2 108 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 77 +Manufacturer#5 almond azure blanched chiffon midnight 23 71 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 29. testPartOrderInWdwDef-0-88945892370ccbc1125a927a3d55342a b/sql/hive/src/test/resources/golden/windowing.q -- 29. testPartOrderInWdwDef-0-88945892370ccbc1125a927a3d55342a new file mode 100644 index 000000000000..5bcb0fa941d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 29. testPartOrderInWdwDef-0-88945892370ccbc1125a927a3d55342a @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 38 +Manufacturer#1 almond antique burnished rose metallic 2 44 +Manufacturer#1 almond antique chartreuse lavender yellow 34 72 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 112 +Manufacturer#1 almond aquamarine burnished black steel 28 110 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 76 +Manufacturer#2 almond antique violet chocolate turquoise 14 56 +Manufacturer#2 almond antique violet turquoise frosted 40 81 +Manufacturer#2 almond aquamarine midnight light salmon 2 99 +Manufacturer#2 almond aquamarine rose maroon antique 25 85 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 45 +Manufacturer#3 almond antique chartreuse khaki white 17 50 +Manufacturer#3 almond antique forest lavender goldenrod 14 51 +Manufacturer#3 almond antique metallic orange dim 19 96 +Manufacturer#3 almond antique misty red olive 1 79 +Manufacturer#3 almond antique olive coral navajo 45 65 +Manufacturer#4 almond antique gainsboro frosted violet 10 76 +Manufacturer#4 almond antique violet mint lemon 39 83 +Manufacturer#4 almond aquamarine floral ivory bisque 27 95 +Manufacturer#4 almond aquamarine yellow dodger mint 7 85 +Manufacturer#4 almond azure aquamarine papaya violet 12 46 +Manufacturer#5 almond antique blue firebrick mint 31 39 +Manufacturer#5 almond antique medium spring khaki 6 85 +Manufacturer#5 almond antique sky peru orange 2 108 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 77 +Manufacturer#5 almond azure blanched chiffon midnight 23 71 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 3. testGroupByHavingWithSWQ-0-a5a5339330a6a6660d32ccb0cc5d7100 b/sql/hive/src/test/resources/golden/windowing.q -- 3. testGroupByHavingWithSWQ-0-a5a5339330a6a6660d32ccb0cc5d7100 new file mode 100644 index 000000000000..2c30e652aa26 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 3. testGroupByHavingWithSWQ-0-a5a5339330a6a6660d32ccb0cc5d7100 @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 1 1 2 0 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 2 2 34 32 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 3 3 6 -28 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 4 4 28 22 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 5 5 42 14 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 1 1 14 0 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 2 2 40 26 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 3 3 2 -38 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 4 4 25 23 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5 5 18 -7 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 1 1 17 0 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 2 2 14 -3 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 3 3 19 5 +Manufacturer#3 almond antique misty red olive 1 1922.98 4 4 1 -18 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 5 5 45 44 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 1 1 10 0 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 2 2 39 29 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 3 3 27 -12 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 4 4 7 -20 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 5 5 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 1 1 31 0 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 2 2 6 -25 +Manufacturer#5 almond antique sky peru orange 2 1788.73 3 3 2 -4 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 4 4 46 44 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 5 5 23 -23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 30. testDefaultPartitioningSpecRules-0-fa80b09c99e3c1487de48ea71a88dada b/sql/hive/src/test/resources/golden/windowing.q -- 30. testDefaultPartitioningSpecRules-0-fa80b09c99e3c1487de48ea71a88dada new file mode 100644 index 000000000000..698a44349d2a --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 30. testDefaultPartitioningSpecRules-0-fa80b09c99e3c1487de48ea71a88dada @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 38 4 +Manufacturer#1 almond antique burnished rose metallic 2 44 4 +Manufacturer#1 almond antique chartreuse lavender yellow 34 72 38 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 112 44 +Manufacturer#1 almond aquamarine burnished black steel 28 110 72 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 76 114 +Manufacturer#2 almond antique violet chocolate turquoise 14 56 14 +Manufacturer#2 almond antique violet turquoise frosted 40 81 54 +Manufacturer#2 almond aquamarine midnight light salmon 2 99 56 +Manufacturer#2 almond aquamarine rose maroon antique 25 85 81 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 45 99 +Manufacturer#3 almond antique chartreuse khaki white 17 50 17 +Manufacturer#3 almond antique forest lavender goldenrod 14 51 31 +Manufacturer#3 almond antique metallic orange dim 19 96 50 +Manufacturer#3 almond antique misty red olive 1 79 51 +Manufacturer#3 almond antique olive coral navajo 45 65 96 +Manufacturer#4 almond antique gainsboro frosted violet 10 76 10 +Manufacturer#4 almond antique violet mint lemon 39 83 49 +Manufacturer#4 almond aquamarine floral ivory bisque 27 95 76 +Manufacturer#4 almond aquamarine yellow dodger mint 7 85 83 +Manufacturer#4 almond azure aquamarine papaya violet 12 46 95 +Manufacturer#5 almond antique blue firebrick mint 31 39 31 +Manufacturer#5 almond antique medium spring khaki 6 85 37 +Manufacturer#5 almond antique sky peru orange 2 108 39 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 77 85 +Manufacturer#5 almond azure blanched chiffon midnight 23 71 108 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 36. testRankWithPartitioning-0-45ccbaf0ee083858f7661c66b11d4768 b/sql/hive/src/test/resources/golden/windowing.q -- 36. testRankWithPartitioning-0-45ccbaf0ee083858f7661c66b11d4768 new file mode 100644 index 000000000000..e35257d98382 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 36. testRankWithPartitioning-0-45ccbaf0ee083858f7661c66b11d4768 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 +Manufacturer#1 almond antique burnished rose metallic 2 1 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 +Manufacturer#1 almond aquamarine burnished black steel 28 5 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 +Manufacturer#2 almond antique violet turquoise frosted 40 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 +Manufacturer#3 almond antique chartreuse khaki white 17 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 +Manufacturer#3 almond antique metallic orange dim 19 3 +Manufacturer#3 almond antique misty red olive 1 4 +Manufacturer#3 almond antique olive coral navajo 45 5 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 +Manufacturer#4 almond antique violet mint lemon 39 2 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1 +Manufacturer#5 almond antique medium spring khaki 6 2 +Manufacturer#5 almond antique sky peru orange 2 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 37. testPartitioningVariousForms-0-3436e50214f9afdec84334e10faa931a b/sql/hive/src/test/resources/golden/windowing.q -- 37. testPartitioningVariousForms-0-3436e50214f9afdec84334e10faa931a new file mode 100644 index 000000000000..9c0ca6c7a00b --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 37. testPartitioningVariousForms-0-3436e50214f9afdec84334e10faa931a @@ -0,0 +1,26 @@ +Manufacturer#1 8749.73 1173.15 1753.76 1458.29 6 +Manufacturer#1 8749.73 1173.15 1753.76 1458.29 6 +Manufacturer#1 8749.73 1173.15 1753.76 1458.29 6 +Manufacturer#1 8749.73 1173.15 1753.76 1458.29 6 +Manufacturer#1 8749.73 1173.15 1753.76 1458.29 6 +Manufacturer#1 8749.73 1173.15 1753.76 1458.29 6 +Manufacturer#2 8923.62 1690.68 2031.98 1784.72 5 +Manufacturer#2 8923.62 1690.68 2031.98 1784.72 5 +Manufacturer#2 8923.62 1690.68 2031.98 1784.72 5 +Manufacturer#2 8923.62 1690.68 2031.98 1784.72 5 +Manufacturer#2 8923.62 1690.68 2031.98 1784.72 5 +Manufacturer#3 7532.61 1190.27 1922.98 1506.52 5 +Manufacturer#3 7532.61 1190.27 1922.98 1506.52 5 +Manufacturer#3 7532.61 1190.27 1922.98 1506.52 5 +Manufacturer#3 7532.61 1190.27 1922.98 1506.52 5 +Manufacturer#3 7532.61 1190.27 1922.98 1506.52 5 +Manufacturer#4 7337.62 1206.26 1844.92 1467.52 5 +Manufacturer#4 7337.62 1206.26 1844.92 1467.52 5 +Manufacturer#4 7337.62 1206.26 1844.92 1467.52 5 +Manufacturer#4 7337.62 1206.26 1844.92 1467.52 5 +Manufacturer#4 7337.62 1206.26 1844.92 1467.52 5 +Manufacturer#5 7672.66 1018.1 1789.69 1534.53 5 +Manufacturer#5 7672.66 1018.1 1789.69 1534.53 5 +Manufacturer#5 7672.66 1018.1 1789.69 1534.53 5 +Manufacturer#5 7672.66 1018.1 1789.69 1534.53 5 +Manufacturer#5 7672.66 1018.1 1789.69 1534.53 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 38. testPartitioningVariousForms2-0-cba9d84a6b1bb5e36595338d4602377e b/sql/hive/src/test/resources/golden/windowing.q -- 38. testPartitioningVariousForms2-0-cba9d84a6b1bb5e36595338d4602377e new file mode 100644 index 000000000000..fc27df2f2b64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 38. testPartitioningVariousForms2-0-cba9d84a6b1bb5e36595338d4602377e @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 1173.15 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 2346.3 1173.15 1173.15 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 1753.76 1753.76 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 1602.59 1602.59 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 1414.42 1414.42 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 1632.66 1632.66 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 1690.68 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 1800.7 1800.7 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 2031.98 2031.98 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 1698.66 1698.66 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 1701.6 1701.6 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 1671.68 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 1190.27 1190.27 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 1410.39 1410.39 +Manufacturer#3 almond antique misty red olive 1 1922.98 1922.98 1922.98 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 1337.29 1337.29 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 1620.67 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 1375.42 1375.42 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 1206.26 1206.26 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 1844.92 1844.92 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 1290.35 1290.35 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 1789.69 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 1611.66 1611.66 +Manufacturer#5 almond antique sky peru orange 2 1788.73 1788.73 1788.73 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 1018.1 1018.1 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 1464.48 1464.48 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 39. testUDFOnOrderCols-0-7647562850dd367ef1e6c63117805423 b/sql/hive/src/test/resources/golden/windowing.q -- 39. testUDFOnOrderCols-0-7647562850dd367ef1e6c63117805423 new file mode 100644 index 000000000000..e5a541f56f6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 39. testUDFOnOrderCols-0-7647562850dd367ef1e6c63117805423 @@ -0,0 +1,26 @@ +Manufacturer#1 LARGE BRUSHED STEEL ARGE BRUSHED STEEL 1 +Manufacturer#1 LARGE BURNISHED STEEL ARGE BURNISHED STEEL 2 +Manufacturer#1 PROMO BURNISHED NICKEL ROMO BURNISHED NICKEL 3 +Manufacturer#1 PROMO PLATED TIN ROMO PLATED TIN 4 +Manufacturer#1 PROMO PLATED TIN ROMO PLATED TIN 4 +Manufacturer#1 STANDARD ANODIZED STEEL TANDARD ANODIZED STEEL 6 +Manufacturer#2 ECONOMY POLISHED STEEL CONOMY POLISHED STEEL 1 +Manufacturer#2 MEDIUM ANODIZED COPPER EDIUM ANODIZED COPPER 2 +Manufacturer#2 MEDIUM BURNISHED COPPER EDIUM BURNISHED COPPER 3 +Manufacturer#2 SMALL POLISHED NICKEL MALL POLISHED NICKEL 4 +Manufacturer#2 STANDARD PLATED TIN TANDARD PLATED TIN 5 +Manufacturer#3 ECONOMY PLATED COPPER CONOMY PLATED COPPER 1 +Manufacturer#3 MEDIUM BURNISHED BRASS EDIUM BURNISHED BRASS 2 +Manufacturer#3 MEDIUM BURNISHED TIN EDIUM BURNISHED TIN 3 +Manufacturer#3 PROMO ANODIZED TIN ROMO ANODIZED TIN 4 +Manufacturer#3 STANDARD POLISHED STEEL TANDARD POLISHED STEEL 5 +Manufacturer#4 ECONOMY BRUSHED COPPER CONOMY BRUSHED COPPER 1 +Manufacturer#4 SMALL BRUSHED BRASS MALL BRUSHED BRASS 2 +Manufacturer#4 SMALL PLATED STEEL MALL PLATED STEEL 3 +Manufacturer#4 PROMO POLISHED STEEL ROMO POLISHED STEEL 4 +Manufacturer#4 STANDARD ANODIZED TIN TANDARD ANODIZED TIN 5 +Manufacturer#5 LARGE BRUSHED BRASS ARGE BRUSHED BRASS 1 +Manufacturer#5 ECONOMY BURNISHED STEEL CONOMY BURNISHED STEEL 2 +Manufacturer#5 MEDIUM BURNISHED TIN EDIUM BURNISHED TIN 3 +Manufacturer#5 SMALL PLATED BRASS MALL PLATED BRASS 4 +Manufacturer#5 STANDARD BURNISHED TIN TANDARD BURNISHED TIN 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 4. testCount-0-e6e97e884327df86f16b870527ec026c b/sql/hive/src/test/resources/golden/windowing.q -- 4. testCount-0-e6e97e884327df86f16b870527ec026c new file mode 100644 index 000000000000..bf8e620a304a --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 4. testCount-0-e6e97e884327df86f16b870527ec026c @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 +Manufacturer#1 almond antique burnished rose metallic 2 +Manufacturer#1 almond antique chartreuse lavender yellow 3 +Manufacturer#1 almond antique salmon chartreuse burlywood 4 +Manufacturer#1 almond aquamarine burnished black steel 5 +Manufacturer#1 almond aquamarine pink moccasin thistle 6 +Manufacturer#2 almond antique violet chocolate turquoise 1 +Manufacturer#2 almond antique violet turquoise frosted 2 +Manufacturer#2 almond aquamarine midnight light salmon 3 +Manufacturer#2 almond aquamarine rose maroon antique 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 5 +Manufacturer#3 almond antique chartreuse khaki white 1 +Manufacturer#3 almond antique forest lavender goldenrod 2 +Manufacturer#3 almond antique metallic orange dim 3 +Manufacturer#3 almond antique misty red olive 4 +Manufacturer#3 almond antique olive coral navajo 5 +Manufacturer#4 almond antique gainsboro frosted violet 1 +Manufacturer#4 almond antique violet mint lemon 2 +Manufacturer#4 almond aquamarine floral ivory bisque 3 +Manufacturer#4 almond aquamarine yellow dodger mint 4 +Manufacturer#4 almond azure aquamarine papaya violet 5 +Manufacturer#5 almond antique blue firebrick mint 1 +Manufacturer#5 almond antique medium spring khaki 2 +Manufacturer#5 almond antique sky peru orange 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 4 +Manufacturer#5 almond azure blanched chiffon midnight 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 40. testNoBetweenForRows-0-99007f45b6406869e048b0e4eb9213f1 b/sql/hive/src/test/resources/golden/windowing.q -- 40. testNoBetweenForRows-0-99007f45b6406869e048b0e4eb9213f1 new file mode 100644 index 000000000000..1e29df62901d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 40. testNoBetweenForRows-0-99007f45b6406869e048b0e4eb9213f1 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 2346.3 +Manufacturer#1 almond antique chartreuse lavender yellow 34 4100.06 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 5702.650000000001 +Manufacturer#1 almond aquamarine burnished black steel 28 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 8749.730000000001 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 3491.38 +Manufacturer#2 almond aquamarine midnight light salmon 2 5523.360000000001 +Manufacturer#2 almond aquamarine rose maroon antique 25 7222.02 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 8923.62 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 2861.95 +Manufacturer#3 almond antique metallic orange dim 19 4272.34 +Manufacturer#3 almond antique misty red olive 1 6195.32 +Manufacturer#3 almond antique olive coral navajo 45 7532.61 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 2996.09 +Manufacturer#4 almond aquamarine floral ivory bisque 27 4202.35 +Manufacturer#4 almond aquamarine yellow dodger mint 7 6047.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 7337.620000000001 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 3401.3500000000004 +Manufacturer#5 almond antique sky peru orange 2 5190.08 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 6208.18 +Manufacturer#5 almond azure blanched chiffon midnight 23 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 41. testNoBetweenForRange-0-d81a591e90950de291d2f133793e9283 b/sql/hive/src/test/resources/golden/windowing.q -- 41. testNoBetweenForRange-0-d81a591e90950de291d2f133793e9283 new file mode 100644 index 000000000000..a620479fe406 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 41. testNoBetweenForRange-0-d81a591e90950de291d2f133793e9283 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 2346.3 +Manufacturer#1 almond antique burnished rose metallic 2 2346.3 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 3948.8900000000003 +Manufacturer#1 almond aquamarine burnished black steel 28 5363.31 +Manufacturer#1 almond antique chartreuse lavender yellow 34 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 8749.730000000001 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 +Manufacturer#2 almond antique violet chocolate turquoise 14 3722.66 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5424.26 +Manufacturer#2 almond aquamarine rose maroon antique 25 7122.92 +Manufacturer#2 almond antique violet turquoise frosted 40 8923.62 +Manufacturer#3 almond antique misty red olive 1 1922.98 +Manufacturer#3 almond antique forest lavender goldenrod 14 3113.25 +Manufacturer#3 almond antique chartreuse khaki white 17 4784.93 +Manufacturer#3 almond antique metallic orange dim 19 6195.320000000001 +Manufacturer#3 almond antique olive coral navajo 45 7532.610000000001 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 +Manufacturer#4 almond antique gainsboro frosted violet 10 3465.59 +Manufacturer#4 almond azure aquamarine papaya violet 12 4755.9400000000005 +Manufacturer#4 almond aquamarine floral ivory bisque 27 5962.200000000001 +Manufacturer#4 almond antique violet mint lemon 39 7337.620000000001 +Manufacturer#5 almond antique sky peru orange 2 1788.73 +Manufacturer#5 almond antique medium spring khaki 6 3400.3900000000003 +Manufacturer#5 almond azure blanched chiffon midnight 23 4864.870000000001 +Manufacturer#5 almond antique blue firebrick mint 31 6654.560000000001 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 7672.660000000002 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 42. testUnboundedFollowingForRows-0-fb8648e82e4dd56d6bdcfd739dd1edf0 b/sql/hive/src/test/resources/golden/windowing.q -- 42. testUnboundedFollowingForRows-0-fb8648e82e4dd56d6bdcfd739dd1edf0 new file mode 100644 index 000000000000..74147d2571a1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 42. testUnboundedFollowingForRows-0-fb8648e82e4dd56d6bdcfd739dd1edf0 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 8749.730000000001 +Manufacturer#1 almond antique burnished rose metallic 2 7576.58 +Manufacturer#1 almond antique chartreuse lavender yellow 34 6403.43 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4649.67 +Manufacturer#1 almond aquamarine burnished black steel 28 3047.08 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 +Manufacturer#2 almond antique violet chocolate turquoise 14 8923.62 +Manufacturer#2 almond antique violet turquoise frosted 40 7232.9400000000005 +Manufacturer#2 almond aquamarine midnight light salmon 2 5432.24 +Manufacturer#2 almond aquamarine rose maroon antique 25 3400.26 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 +Manufacturer#3 almond antique chartreuse khaki white 17 7532.61 +Manufacturer#3 almond antique forest lavender goldenrod 14 5860.929999999999 +Manufacturer#3 almond antique metallic orange dim 19 4670.66 +Manufacturer#3 almond antique misty red olive 1 3260.27 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 +Manufacturer#4 almond antique gainsboro frosted violet 10 7337.620000000001 +Manufacturer#4 almond antique violet mint lemon 39 5716.950000000001 +Manufacturer#4 almond aquamarine floral ivory bisque 27 4341.530000000001 +Manufacturer#4 almond aquamarine yellow dodger mint 7 3135.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 +Manufacturer#5 almond antique blue firebrick mint 31 7672.66 +Manufacturer#5 almond antique medium spring khaki 6 5882.970000000001 +Manufacturer#5 almond antique sky peru orange 2 4271.3099999999995 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 2482.58 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 43. testUnboundedFollowingForRange-0-3cd04e5f2398853c4850f4f86142bb39 b/sql/hive/src/test/resources/golden/windowing.q -- 43. testUnboundedFollowingForRange-0-3cd04e5f2398853c4850f4f86142bb39 new file mode 100644 index 000000000000..49d003b5de13 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 43. testUnboundedFollowingForRange-0-3cd04e5f2398853c4850f4f86142bb39 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 8749.730000000001 +Manufacturer#1 almond antique burnished rose metallic 2 8749.730000000001 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 6403.43 +Manufacturer#1 almond aquamarine burnished black steel 28 4800.84 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3386.42 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 +Manufacturer#2 almond aquamarine midnight light salmon 2 8923.62 +Manufacturer#2 almond antique violet chocolate turquoise 14 6891.639999999999 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5200.96 +Manufacturer#2 almond aquamarine rose maroon antique 25 3499.36 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 +Manufacturer#3 almond antique misty red olive 1 7532.610000000001 +Manufacturer#3 almond antique forest lavender goldenrod 14 5609.63 +Manufacturer#3 almond antique chartreuse khaki white 17 4419.360000000001 +Manufacturer#3 almond antique metallic orange dim 19 2747.6800000000003 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 +Manufacturer#4 almond aquamarine yellow dodger mint 7 7337.620000000001 +Manufacturer#4 almond antique gainsboro frosted violet 10 5492.7 +Manufacturer#4 almond azure aquamarine papaya violet 12 3872.0299999999997 +Manufacturer#4 almond aquamarine floral ivory bisque 27 2581.6800000000003 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 +Manufacturer#5 almond antique sky peru orange 2 7672.660000000002 +Manufacturer#5 almond antique medium spring khaki 6 5883.93 +Manufacturer#5 almond azure blanched chiffon midnight 23 4272.27 +Manufacturer#5 almond antique blue firebrick mint 31 2807.79 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 44. testOverNoPartitionSingleAggregate-0-cb3d2f8c1296044dc2658876bb6103ae b/sql/hive/src/test/resources/golden/windowing.q -- 44. testOverNoPartitionSingleAggregate-0-cb3d2f8c1296044dc2658876bb6103ae new file mode 100644 index 000000000000..5982c9ee2a4d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 44. testOverNoPartitionSingleAggregate-0-cb3d2f8c1296044dc2658876bb6103ae @@ -0,0 +1,26 @@ +almond antique blue firebrick mint 1789.69 1546.78 +almond antique burnished rose metallic 1173.15 1546.78 +almond antique burnished rose metallic 1173.15 1546.78 +almond antique chartreuse khaki white 1671.68 1546.78 +almond antique chartreuse lavender yellow 1753.76 1546.78 +almond antique forest lavender goldenrod 1190.27 1546.78 +almond antique gainsboro frosted violet 1620.67 1546.78 +almond antique medium spring khaki 1611.66 1546.78 +almond antique metallic orange dim 1410.39 1546.78 +almond antique misty red olive 1922.98 1546.78 +almond antique olive coral navajo 1337.29 1546.78 +almond antique salmon chartreuse burlywood 1602.59 1546.78 +almond antique sky peru orange 1788.73 1546.78 +almond antique violet chocolate turquoise 1690.68 1546.78 +almond antique violet mint lemon 1375.42 1546.78 +almond antique violet turquoise frosted 1800.7 1546.78 +almond aquamarine burnished black steel 1414.42 1546.78 +almond aquamarine dodger light gainsboro 1018.1 1546.78 +almond aquamarine floral ivory bisque 1206.26 1546.78 +almond aquamarine midnight light salmon 2031.98 1546.78 +almond aquamarine pink moccasin thistle 1632.66 1546.78 +almond aquamarine rose maroon antique 1698.66 1546.78 +almond aquamarine sandy cyan gainsboro 1701.6 1546.78 +almond aquamarine yellow dodger mint 1844.92 1546.78 +almond azure aquamarine papaya violet 1290.35 1546.78 +almond azure blanched chiffon midnight 1464.48 1546.78 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 5. testCountWithWindowingUDAF-0-3bde93728761b780a745c2ce0398aa0f b/sql/hive/src/test/resources/golden/windowing.q -- 5. testCountWithWindowingUDAF-0-3bde93728761b780a745c2ce0398aa0f new file mode 100644 index 000000000000..00d41fc0bcd9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 5. testCountWithWindowingUDAF-0-3bde93728761b780a745c2ce0398aa0f @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 1 1 2 1173.15 1173.15 2 0 +Manufacturer#1 almond antique burnished rose metallic 1 1 2 1173.15 2346.3 2 0 +Manufacturer#1 almond antique chartreuse lavender yellow 3 2 3 1753.76 4100.06 34 32 +Manufacturer#1 almond antique salmon chartreuse burlywood 4 3 4 1602.59 5702.650000000001 6 -28 +Manufacturer#1 almond aquamarine burnished black steel 5 4 5 1414.42 7117.070000000001 28 22 +Manufacturer#1 almond aquamarine pink moccasin thistle 6 5 6 1632.66 8749.730000000001 42 14 +Manufacturer#2 almond antique violet chocolate turquoise 1 1 1 1690.68 1690.68 14 0 +Manufacturer#2 almond antique violet turquoise frosted 2 2 2 1800.7 3491.38 40 26 +Manufacturer#2 almond aquamarine midnight light salmon 3 3 3 2031.98 5523.360000000001 2 -38 +Manufacturer#2 almond aquamarine rose maroon antique 4 4 4 1698.66 7222.02 25 23 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 5 5 5 1701.6 8923.62 18 -7 +Manufacturer#3 almond antique chartreuse khaki white 1 1 1 1671.68 1671.68 17 0 +Manufacturer#3 almond antique forest lavender goldenrod 2 2 2 1190.27 2861.95 14 -3 +Manufacturer#3 almond antique metallic orange dim 3 3 3 1410.39 4272.34 19 5 +Manufacturer#3 almond antique misty red olive 4 4 4 1922.98 6195.32 1 -18 +Manufacturer#3 almond antique olive coral navajo 5 5 5 1337.29 7532.61 45 44 +Manufacturer#4 almond antique gainsboro frosted violet 1 1 1 1620.67 1620.67 10 0 +Manufacturer#4 almond antique violet mint lemon 2 2 2 1375.42 2996.09 39 29 +Manufacturer#4 almond aquamarine floral ivory bisque 3 3 3 1206.26 4202.35 27 -12 +Manufacturer#4 almond aquamarine yellow dodger mint 4 4 4 1844.92 6047.27 7 -20 +Manufacturer#4 almond azure aquamarine papaya violet 5 5 5 1290.35 7337.620000000001 12 5 +Manufacturer#5 almond antique blue firebrick mint 1 1 1 1789.69 1789.69 31 0 +Manufacturer#5 almond antique medium spring khaki 2 2 2 1611.66 3401.3500000000004 6 -25 +Manufacturer#5 almond antique sky peru orange 3 3 3 1788.73 5190.08 2 -4 +Manufacturer#5 almond aquamarine dodger light gainsboro 4 4 4 1018.1 6208.18 46 44 +Manufacturer#5 almond azure blanched chiffon midnight 5 5 5 1464.48 7672.66 23 -23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 6. testCountInSubQ-0-73d5274a21d4f4fd51d2a0f1d98516ce b/sql/hive/src/test/resources/golden/windowing.q -- 6. testCountInSubQ-0-73d5274a21d4f4fd51d2a0f1d98516ce new file mode 100644 index 000000000000..98c09e4fe15c --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 6. testCountInSubQ-0-73d5274a21d4f4fd51d2a0f1d98516ce @@ -0,0 +1,26 @@ +1 1 2 1173.15 0 +1 1 2 2346.3 0 +3 2 3 4100.06 32 +4 3 4 5702.650000000001 -28 +5 4 5 7117.070000000001 22 +6 5 6 8749.730000000001 14 +1 1 1 1690.68 0 +2 2 2 3491.38 26 +3 3 3 5523.360000000001 -38 +4 4 4 7222.02 23 +5 5 5 8923.62 -7 +1 1 1 1671.68 0 +2 2 2 2861.95 -3 +3 3 3 4272.34 5 +4 4 4 6195.32 -18 +5 5 5 7532.61 44 +1 1 1 1620.67 0 +2 2 2 2996.09 29 +3 3 3 4202.35 -12 +4 4 4 6047.27 -20 +5 5 5 7337.620000000001 5 +1 1 1 1789.69 0 +2 2 2 3401.3500000000004 -25 +3 3 3 5190.08 -4 +4 4 4 6208.18 44 +5 5 5 7672.66 -23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 8. testMixedCaseAlias-0-4b1ad2515fb079012467e987f484a722 b/sql/hive/src/test/resources/golden/windowing.q -- 8. testMixedCaseAlias-0-4b1ad2515fb079012467e987f484a722 new file mode 100644 index 000000000000..e35257d98382 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 8. testMixedCaseAlias-0-4b1ad2515fb079012467e987f484a722 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 +Manufacturer#1 almond antique burnished rose metallic 2 1 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 +Manufacturer#1 almond aquamarine burnished black steel 28 5 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 +Manufacturer#2 almond antique violet turquoise frosted 40 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 +Manufacturer#3 almond antique chartreuse khaki white 17 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 +Manufacturer#3 almond antique metallic orange dim 19 3 +Manufacturer#3 almond antique misty red olive 1 4 +Manufacturer#3 almond antique olive coral navajo 45 5 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 +Manufacturer#4 almond antique violet mint lemon 39 2 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1 +Manufacturer#5 almond antique medium spring khaki 6 2 +Manufacturer#5 almond antique sky peru orange 2 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 9. testHavingWithWindowingNoGBY-0-70cdc0555a61ef08534a9ebebb95ebbf b/sql/hive/src/test/resources/golden/windowing.q -- 9. testHavingWithWindowingNoGBY-0-70cdc0555a61ef08534a9ebebb95ebbf new file mode 100644 index 000000000000..850c41c8115d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 9. testHavingWithWindowingNoGBY-0-70cdc0555a61ef08534a9ebebb95ebbf @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 2346.3 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 4100.06 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 5702.650000000001 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 8749.730000000001 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 3491.38 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 5523.360000000001 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 7222.02 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 8923.62 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 2861.95 +Manufacturer#3 almond antique metallic orange dim 19 3 3 4272.34 +Manufacturer#3 almond antique misty red olive 1 4 4 6195.32 +Manufacturer#3 almond antique olive coral navajo 45 5 5 7532.61 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 2 2 2996.09 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 4202.35 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 6047.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 7337.620000000001 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 2 2 3401.3500000000004 +Manufacturer#5 almond antique sky peru orange 2 3 3 5190.08 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 6208.18 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-0-d3f50875bd5dff172cf813fdb7d738eb b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-0-d3f50875bd5dff172cf813fdb7d738eb new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-1-dda16565b98926fc3587de937b9401c7 b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-1-dda16565b98926fc3587de937b9401c7 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-2-374e39786feb745cd70f25be58bfa24 b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-2-374e39786feb745cd70f25be58bfa24 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-3-d2b5e23edec42a62e61750b110ecbaac b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-3-d2b5e23edec42a62e61750b110ecbaac new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-3-d2b5e23edec42a62e61750b110ecbaac @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-4-50d0c630159068b5b8ccdeb76493f1f7 b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-4-50d0c630159068b5b8ccdeb76493f1f7 new file mode 100644 index 000000000000..850c41c8115d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-4-50d0c630159068b5b8ccdeb76493f1f7 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 2346.3 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 4100.06 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 5702.650000000001 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 8749.730000000001 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 3491.38 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 5523.360000000001 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 7222.02 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 8923.62 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 2861.95 +Manufacturer#3 almond antique metallic orange dim 19 3 3 4272.34 +Manufacturer#3 almond antique misty red olive 1 4 4 6195.32 +Manufacturer#3 almond antique olive coral navajo 45 5 5 7532.61 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 2 2 2996.09 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 4202.35 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 6047.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 7337.620000000001 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 2 2 3401.3500000000004 +Manufacturer#5 almond antique sky peru orange 2 3 3 5190.08 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 6208.18 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-5-3f95cd6f4add7a2d0101fe3dd97e5082 b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-5-3f95cd6f4add7a2d0101fe3dd97e5082 new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_adjust_rowcontainer_sz-5-3f95cd6f4add7a2d0101fe3dd97e5082 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/windowing_columnPruning-0-d3f50875bd5dff172cf813fdb7d738eb b/sql/hive/src/test/resources/golden/windowing_columnPruning-0-d3f50875bd5dff172cf813fdb7d738eb new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_columnPruning-1-dda16565b98926fc3587de937b9401c7 b/sql/hive/src/test/resources/golden/windowing_columnPruning-1-dda16565b98926fc3587de937b9401c7 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_columnPruning-2-374e39786feb745cd70f25be58bfa24 b/sql/hive/src/test/resources/golden/windowing_columnPruning-2-374e39786feb745cd70f25be58bfa24 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_columnPruning-3-9294b4a22bc396ff2accabd53c5da98b b/sql/hive/src/test/resources/golden/windowing_columnPruning-3-9294b4a22bc396ff2accabd53c5da98b new file mode 100644 index 000000000000..1b5ae55383a4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_columnPruning-3-9294b4a22bc396ff2accabd53c5da98b @@ -0,0 +1,26 @@ +2 1173.15 +2 2346.3 +34 4100.06 +6 5702.650000000001 +28 7117.070000000001 +42 8749.730000000001 +14 1690.68 +40 3491.38 +2 5523.360000000001 +25 7222.02 +18 8923.62 +17 1671.68 +14 2861.95 +19 4272.34 +1 6195.32 +45 7532.61 +10 1620.67 +39 2996.09 +27 4202.35 +7 6047.27 +12 7337.620000000001 +31 1789.69 +6 3401.3500000000004 +2 5190.08 +46 6208.18 +23 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing_columnPruning-4-445cab062581c449ceffcb368cdf133 b/sql/hive/src/test/resources/golden/windowing_columnPruning-4-445cab062581c449ceffcb368cdf133 new file mode 100644 index 000000000000..1b5ae55383a4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_columnPruning-4-445cab062581c449ceffcb368cdf133 @@ -0,0 +1,26 @@ +2 1173.15 +2 2346.3 +34 4100.06 +6 5702.650000000001 +28 7117.070000000001 +42 8749.730000000001 +14 1690.68 +40 3491.38 +2 5523.360000000001 +25 7222.02 +18 8923.62 +17 1671.68 +14 2861.95 +19 4272.34 +1 6195.32 +45 7532.61 +10 1620.67 +39 2996.09 +27 4202.35 +7 6047.27 +12 7337.620000000001 +31 1789.69 +6 3401.3500000000004 +2 5190.08 +46 6208.18 +23 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing_columnPruning-5-89110070c761eafb992eb9315128b53f b/sql/hive/src/test/resources/golden/windowing_columnPruning-5-89110070c761eafb992eb9315128b53f new file mode 100644 index 000000000000..e426c725b0e3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_columnPruning-5-89110070c761eafb992eb9315128b53f @@ -0,0 +1,26 @@ +Manufacturer#1 1173.15 +Manufacturer#1 2346.3 +Manufacturer#1 4100.06 +Manufacturer#1 5702.650000000001 +Manufacturer#1 7117.070000000001 +Manufacturer#1 8749.730000000001 +Manufacturer#2 1690.68 +Manufacturer#2 3491.38 +Manufacturer#2 5523.360000000001 +Manufacturer#2 7222.02 +Manufacturer#2 8923.62 +Manufacturer#3 1671.68 +Manufacturer#3 2861.95 +Manufacturer#3 4272.34 +Manufacturer#3 6195.32 +Manufacturer#3 7532.61 +Manufacturer#4 1620.67 +Manufacturer#4 2996.09 +Manufacturer#4 4202.35 +Manufacturer#4 6047.27 +Manufacturer#4 7337.620000000001 +Manufacturer#5 1789.69 +Manufacturer#5 3401.3500000000004 +Manufacturer#5 5190.08 +Manufacturer#5 6208.18 +Manufacturer#5 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 1-0-12a92d8800e0da8b515ba3eaf6a7fd0f b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 1-0-12a92d8800e0da8b515ba3eaf6a7fd0f new file mode 100644 index 000000000000..acc4f3bc2a2d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 1-0-12a92d8800e0da8b515ba3eaf6a7fd0f @@ -0,0 +1,1049 @@ + 1 4294967354 + 2 8589934811 + 3 12884902227 +alice allen 1 4294967487 +alice allen 2 8589935012 +alice allen 3 12884902543 +alice brown 1 4294967355 +alice carson 1 4294967370 +alice davidson 1 4294967517 +alice falkner 1 4294967316 +alice garcia 1 4294967369 +alice hernandez 1 4294967314 +alice hernandez 2 8589934613 +alice johnson 1 4294967424 +alice king 1 4294967387 +alice king 2 8589934903 +alice king 3 12884902449 +alice laertes 1 4294967519 +alice laertes 2 8589935038 +alice miller 1 4294967324 +alice nixon 1 4294967484 +alice nixon 2 8589934894 +alice nixon 3 12884902307 +alice ovid 1 4294967412 +alice polk 1 4294967366 +alice quirinius 1 4294967505 +alice quirinius 2 8589935054 +alice robinson 1 4294967502 +alice robinson 2 8589934947 +alice steinbeck 1 4294967549 +alice steinbeck 2 8589934913 +alice steinbeck 3 12884902387 +alice underhill 1 4294967441 +alice van buren 1 4294967428 +alice xylophone 1 4294967519 +alice xylophone 2 8589934877 +alice xylophone 3 12884902240 +alice zipper 1 4294967380 +alice zipper 2 8589934919 +alice zipper 3 12884902439 +bob brown 1 4294967431 +bob brown 2 8589934853 +bob brown 3 12884902280 +bob carson 1 4294967408 +bob davidson 1 4294967435 +bob davidson 2 8589934939 +bob davidson 3 12884902293 +bob ellison 1 4294967530 +bob ellison 2 8589934966 +bob ellison 3 12884902328 +bob ellison 4 17179869672 +bob falkner 1 4294967464 +bob garcia 1 4294967435 +bob garcia 2 8589934804 +bob garcia 3 12884902148 +bob garcia 4 17179869587 +bob garcia 5 21474836905 +bob hernandez 1 4294967500 +bob ichabod 1 4294967424 +bob king 1 4294967443 +bob king 2 8589934740 +bob king 3 12884902279 +bob laertes 1 4294967472 +bob laertes 2 8589934852 +bob miller 1 4294967349 +bob ovid 1 4294967401 +bob ovid 2 8589934801 +bob ovid 3 12884902313 +bob ovid 4 17179869708 +bob polk 1 4294967337 +bob quirinius 1 4294967346 +bob steinbeck 1 4294967342 +bob van buren 1 4294967422 +bob white 1 4294967493 +bob white 2 8589934855 +bob xylophone 1 4294967407 +bob xylophone 2 8589934872 +bob young 1 4294967413 +bob zipper 1 4294967416 +bob zipper 2 8589934769 +bob zipper 3 12884902068 +calvin allen 1 4294967373 +calvin brown 1 4294967530 +calvin brown 2 8589934967 +calvin brown 3 12884902378 +calvin carson 1 4294967534 +calvin davidson 1 4294967437 +calvin davidson 2 8589934905 +calvin ellison 1 4294967480 +calvin falkner 1 4294967305 +calvin falkner 2 8589934723 +calvin falkner 3 12884902155 +calvin falkner 4 17179869455 +calvin falkner 5 21474836800 +calvin falkner 6 25769804250 +calvin garcia 1 4294967492 +calvin hernandez 1 4294967341 +calvin johnson 1 4294967546 +calvin laertes 1 4294967499 +calvin laertes 2 8589934930 +calvin nixon 1 4294967488 +calvin nixon 2 8589934788 +calvin nixon 3 12884902200 +calvin ovid 1 4294967343 +calvin ovid 2 8589934881 +calvin ovid 3 12884902210 +calvin ovid 4 17179869559 +calvin polk 1 4294967416 +calvin quirinius 1 4294967532 +calvin quirinius 2 8589935053 +calvin robinson 1 4294967326 +calvin steinbeck 1 4294967417 +calvin steinbeck 2 8589934891 +calvin steinbeck 3 12884902396 +calvin thompson 1 4294967346 +calvin thompson 2 8589934859 +calvin underhill 1 4294967478 +calvin van buren 1 4294967300 +calvin van buren 2 8589934808 +calvin white 1 4294967304 +calvin white 2 8589934848 +calvin xylophone 1 4294967299 +calvin xylophone 2 8589934675 +calvin xylophone 3 12884902133 +calvin young 1 4294967410 +calvin young 2 8589934752 +calvin zipper 1 4294967441 +calvin zipper 2 8589934960 +david allen 1 4294967381 +david allen 2 8589934752 +david brown 1 4294967544 +david brown 2 8589934870 +david davidson 1 4294967487 +david davidson 2 8589934952 +david davidson 3 12884902474 +david davidson 4 17179869819 +david ellison 1 4294967477 +david ellison 2 8589934963 +david ellison 3 12884902426 +david hernandez 1 4294967324 +david ichabod 1 4294967385 +david ichabod 2 8589934872 +david laertes 1 4294967385 +david nixon 1 4294967381 +david ovid 1 4294967396 +david ovid 2 8589934839 +david quirinius 1 4294967375 +david quirinius 2 8589934905 +david quirinius 3 12884902362 +david robinson 1 4294967465 +david robinson 2 8589934933 +david thompson 1 4294967361 +david underhill 1 4294967331 +david underhill 2 8589934715 +david underhill 3 12884902185 +david van buren 1 4294967438 +david van buren 2 8589934747 +david white 1 4294967428 +david xylophone 1 8589934898 +david xylophone 1 8589934898 +david xylophone 3 12884902378 +david young 1 4294967296 +david young 2 8589934601 +ethan allen 1 4294967351 +ethan brown 1 4294967477 +ethan brown 2 8589934897 +ethan brown 3 12884902217 +ethan brown 4 17179869548 +ethan brown 5 21474836951 +ethan brown 6 25769804375 +ethan carson 1 4294967352 +ethan ellison 1 4294967514 +ethan ellison 2 8589934923 +ethan falkner 1 4294967318 +ethan falkner 2 8589934779 +ethan garcia 1 4294967310 +ethan hernandez 1 4294967349 +ethan johnson 1 4294967357 +ethan king 1 4294967413 +ethan laertes 1 4294967402 +ethan laertes 2 8589934859 +ethan laertes 3 12884902390 +ethan laertes 4 17179869880 +ethan laertes 5 21474837302 +ethan laertes 6 25769804603 +ethan laertes 7 30064771974 +ethan miller 1 4294967352 +ethan nixon 1 4294967499 +ethan ovid 1 4294967452 +ethan polk 1 4294967329 +ethan polk 2 8589934711 +ethan polk 3 12884902253 +ethan polk 4 17179869732 +ethan quirinius 1 4294967501 +ethan quirinius 2 8589934852 +ethan quirinius 3 12884902200 +ethan robinson 1 4294967353 +ethan robinson 2 8589934855 +ethan underhill 1 4294967466 +ethan van buren 1 4294967511 +ethan white 1 4294967445 +ethan white 2 8589934872 +ethan xylophone 1 4294967543 +ethan zipper 1 4294967462 +ethan zipper 2 8589934815 +fred davidson 1 4294967512 +fred davidson 2 8589934936 +fred davidson 3 12884902424 +fred ellison 1 4294967470 +fred ellison 2 8589934901 +fred ellison 3 12884902294 +fred falkner 1 4294967340 +fred falkner 2 8589934887 +fred falkner 3 12884902187 +fred hernandez 1 4294967365 +fred ichabod 1 4294967342 +fred ichabod 2 8589934831 +fred johnson 1 4294967373 +fred king 1 4294967346 +fred king 2 8589934766 +fred laertes 1 4294967351 +fred miller 1 4294967490 +fred nixon 1 4294967514 +fred nixon 2 8589934811 +fred nixon 3 12884902293 +fred nixon 4 17179869668 +fred polk 1 4294967332 +fred polk 2 8589934775 +fred polk 3 12884902233 +fred polk 4 17179869740 +fred quirinius 1 4294967426 +fred quirinius 2 8589934951 +fred robinson 1 4294967461 +fred steinbeck 1 4294967411 +fred steinbeck 2 8589934740 +fred steinbeck 3 12884902212 +fred underhill 1 4294967387 +fred van buren 1 4294967431 +fred van buren 2 8589934812 +fred van buren 3 12884902338 +fred van buren 4 17179869801 +fred white 1 4294967434 +fred young 1 4294967495 +fred young 2 8589934980 +fred zipper 1 4294967447 +gabriella allen 1 4294967405 +gabriella allen 2 8589934939 +gabriella brown 1 4294967543 +gabriella brown 2 8589934946 +gabriella carson 1 4294967540 +gabriella davidson 1 4294967507 +gabriella ellison 1 4294967393 +gabriella ellison 2 8589934733 +gabriella falkner 1 4294967378 +gabriella falkner 2 8589934901 +gabriella falkner 3 12884902335 +gabriella garcia 1 4294967419 +gabriella hernandez 1 4294967481 +gabriella hernandez 2 8589934943 +gabriella ichabod 1 4294967337 +gabriella ichabod 2 8589934725 +gabriella ichabod 3 12884902062 +gabriella ichabod 4 17179869382 +gabriella ichabod 5 21474836880 +gabriella king 1 4294967434 +gabriella king 2 8589934827 +gabriella laertes 1 4294967410 +gabriella miller 1 4294967363 +gabriella ovid 1 4294967482 +gabriella ovid 2 8589935004 +gabriella polk 1 4294967410 +gabriella polk 2 8589934712 +gabriella steinbeck 1 4294967500 +gabriella steinbeck 2 8589934935 +gabriella thompson 1 4294967299 +gabriella thompson 2 8589934711 +gabriella thompson 3 12884902196 +gabriella van buren 1 4294967457 +gabriella van buren 2 8589934927 +gabriella white 1 4294967335 +gabriella young 1 4294967493 +gabriella young 2 8589934924 +gabriella zipper 1 4294967357 +gabriella zipper 2 8589934867 +holly allen 1 4294967327 +holly brown 1 4294967321 +holly brown 2 8589934659 +holly falkner 1 4294967324 +holly hernandez 1 4294967378 +holly hernandez 2 8589934921 +holly hernandez 3 12884902465 +holly hernandez 4 17179869773 +holly ichabod 1 4294967342 +holly ichabod 2 8589934800 +holly ichabod 3 12884902129 +holly johnson 1 4294967517 +holly johnson 2 8589934897 +holly johnson 3 12884902432 +holly king 1 4294967392 +holly king 2 8589934753 +holly laertes 1 4294967406 +holly miller 1 4294967388 +holly nixon 1 4294967383 +holly nixon 2 8589934707 +holly polk 1 4294967398 +holly polk 2 8589934832 +holly robinson 1 4294967532 +holly thompson 1 4294967529 +holly thompson 2 8589934868 +holly thompson 3 12884902242 +holly underhill 1 4294967383 +holly underhill 2 8589934894 +holly underhill 3 12884902330 +holly underhill 4 17179869856 +holly van buren 1 4294967539 +holly white 1 4294967320 +holly white 2 8589934735 +holly xylophone 1 4294967435 +holly young 1 4294967487 +holly young 2 8589934987 +holly zipper 1 4294967337 +holly zipper 2 8589934846 +irene allen 1 4294967518 +irene brown 1 4294967434 +irene brown 2 8589934862 +irene brown 3 12884902272 +irene carson 1 4294967473 +irene ellison 1 4294967379 +irene ellison 2 8589934797 +irene falkner 1 4294967404 +irene falkner 2 8589934952 +irene garcia 1 4294967498 +irene garcia 2 8589934869 +irene garcia 3 12884902192 +irene ichabod 1 4294967529 +irene ichabod 2 8589935038 +irene johnson 1 4294967468 +irene laertes 1 4294967481 +irene laertes 2 8589934780 +irene laertes 3 12884902116 +irene miller 1 4294967387 +irene nixon 1 4294967323 +irene nixon 2 8589934824 +irene nixon 3 12884902362 +irene ovid 1 4294967499 +irene ovid 2 8589934870 +irene ovid 3 12884902230 +irene polk 1 4294967521 +irene polk 2 8589934930 +irene polk 3 12884902395 +irene polk 4 17179869941 +irene polk 5 21474837237 +irene quirinius 1 4294967365 +irene quirinius 2 8589934751 +irene quirinius 3 12884902141 +irene robinson 1 4294967347 +irene steinbeck 1 4294967549 +irene thompson 1 4294967479 +irene underhill 1 4294967371 +irene underhill 2 8589934753 +irene van buren 1 4294967439 +irene van buren 2 8589934906 +irene xylophone 1 4294967436 +jessica brown 1 4294967496 +jessica carson 1 4294967389 +jessica carson 2 8589934897 +jessica carson 3 12884902345 +jessica davidson 1 4294967325 +jessica davidson 2 8589934709 +jessica davidson 3 12884902098 +jessica davidson 4 17179869569 +jessica ellison 1 4294967316 +jessica ellison 2 8589934721 +jessica falkner 1 4294967549 +jessica garcia 1 4294967540 +jessica garcia 2 8589935041 +jessica ichabod 1 4294967413 +jessica johnson 1 4294967497 +jessica johnson 2 8589934870 +jessica miller 1 4294967495 +jessica nixon 1 4294967311 +jessica nixon 2 8589934754 +jessica ovid 1 4294967330 +jessica ovid 2 8589934855 +jessica polk 1 4294967403 +jessica quirinius 1 4294967523 +jessica quirinius 2 8589934942 +jessica quirinius 3 12884902388 +jessica quirinius 4 17179869696 +jessica robinson 1 4294967542 +jessica thompson 1 4294967449 +jessica thompson 2 8589934763 +jessica underhill 1 4294967541 +jessica underhill 2 8589934844 +jessica underhill 3 12884902153 +jessica van buren 1 4294967344 +jessica white 1 4294967482 +jessica white 2 8589934929 +jessica white 3 12884902378 +jessica white 4 17179869687 +jessica white 5 21474837086 +jessica xylophone 1 4294967421 +jessica young 1 4294967382 +jessica young 2 8589934903 +jessica zipper 1 4294967334 +jessica zipper 2 8589934785 +jessica zipper 3 12884902157 +katie allen 1 4294967443 +katie brown 1 4294967420 +katie davidson 1 4294967459 +katie ellison 1 4294967486 +katie ellison 2 8589934861 +katie falkner 1 4294967362 +katie garcia 1 4294967306 +katie garcia 2 8589934680 +katie hernandez 1 4294967451 +katie ichabod 1 4294967330 +katie ichabod 2 8589934742 +katie ichabod 3 12884902209 +katie king 1 4294967339 +katie king 2 8589934760 +katie king 3 12884902199 +katie miller 1 4294967425 +katie miller 2 8589934859 +katie nixon 1 4294967500 +katie ovid 1 4294967519 +katie polk 1 4294967384 +katie polk 2 8589934926 +katie robinson 1 4294967310 +katie van buren 1 4294967335 +katie van buren 2 8589934647 +katie white 1 4294967337 +katie white 2 8589934643 +katie xylophone 1 4294967486 +katie young 1 4294967349 +katie young 2 8589934681 +katie young 3 12884902225 +katie zipper 1 4294967354 +katie zipper 2 8589934766 +luke allen 1 4294967533 +luke allen 2 8589934836 +luke allen 3 12884902346 +luke allen 4 17179869863 +luke allen 5 21474837208 +luke brown 1 4294967473 +luke davidson 1 4294967550 +luke davidson 2 8589934904 +luke ellison 1 4294967322 +luke ellison 2 8589934675 +luke ellison 3 12884902103 +luke falkner 1 4294967359 +luke falkner 2 8589934782 +luke garcia 1 4294967304 +luke garcia 2 8589934683 +luke ichabod 1 4294967324 +luke ichabod 2 8589934766 +luke johnson 1 4294967527 +luke johnson 2 8589934987 +luke johnson 3 12884902342 +luke laertes 1 4294967505 +luke laertes 2 8589935011 +luke laertes 3 12884902497 +luke laertes 4 17179869806 +luke laertes 5 21474837193 +luke miller 1 4294967497 +luke ovid 1 4294967492 +luke ovid 2 8589934901 +luke polk 1 4294967545 +luke polk 2 8589934873 +luke quirinius 1 4294967320 +luke robinson 1 4294967299 +luke robinson 2 8589934606 +luke thompson 1 4294967521 +luke underhill 1 4294967393 +luke underhill 2 8589934803 +luke underhill 3 12884902122 +luke van buren 1 4294967424 +luke white 1 4294967505 +luke xylophone 1 4294967382 +luke zipper 1 4294967353 +mike allen 1 4294967466 +mike brown 1 4294967369 +mike carson 1 4294967477 +mike carson 2 8589934934 +mike carson 3 12884902482 +mike davidson 1 4294967501 +mike davidson 2 8589934965 +mike ellison 1 4294967353 +mike ellison 2 8589934747 +mike ellison 3 12884902282 +mike ellison 4 17179869806 +mike ellison 5 21474837309 +mike falkner 1 4294967301 +mike garcia 1 4294967428 +mike garcia 2 8589934826 +mike garcia 3 12884902289 +mike hernandez 1 4294967316 +mike hernandez 2 8589934800 +mike ichabod 1 4294967494 +mike king 1 4294967323 +mike king 2 8589934848 +mike king 3 12884902248 +mike king 4 17179869595 +mike king 5 21474837046 +mike king 6 25769804478 +mike miller 1 4294967449 +mike nixon 1 4294967527 +mike nixon 2 8589935004 +mike polk 1 4294967389 +mike polk 2 8589934848 +mike polk 3 12884902351 +mike quirinius 1 4294967422 +mike steinbeck 1 4294967519 +mike steinbeck 2 8589934827 +mike steinbeck 3 12884902316 +mike steinbeck 4 17179869850 +mike van buren 1 4294967544 +mike van buren 2 8589935061 +mike white 1 4294967336 +mike white 2 8589934882 +mike white 3 12884902374 +mike white 4 17179869843 +mike young 1 4294967453 +mike young 2 8589934804 +mike young 3 12884902198 +mike zipper 1 4294967402 +mike zipper 2 8589934727 +mike zipper 3 12884902228 +nick allen 1 4294967507 +nick allen 2 8589934807 +nick brown 1 4294967334 +nick davidson 1 4294967357 +nick ellison 1 4294967397 +nick ellison 2 8589934699 +nick falkner 1 4294967480 +nick falkner 2 8589934923 +nick garcia 1 4294967384 +nick garcia 2 8589934797 +nick garcia 3 12884902319 +nick ichabod 1 4294967388 +nick ichabod 2 8589934758 +nick ichabod 3 12884902225 +nick johnson 1 4294967398 +nick johnson 2 8589934809 +nick laertes 1 4294967389 +nick miller 1 4294967550 +nick nixon 1 4294967482 +nick ovid 1 4294967488 +nick polk 1 4294967551 +nick quirinius 1 4294967316 +nick quirinius 2 8589934612 +nick robinson 1 4294967409 +nick robinson 2 8589934731 +nick steinbeck 1 4294967355 +nick thompson 1 4294967401 +nick underhill 1 4294967527 +nick van buren 1 4294967303 +nick xylophone 1 4294967460 +nick young 1 4294967405 +nick young 2 8589934917 +nick zipper 1 4294967430 +nick zipper 2 8589934796 +oscar allen 1 4294967500 +oscar brown 1 4294967331 +oscar carson 1 4294967460 +oscar carson 2 8589934904 +oscar carson 3 12884902286 +oscar carson 4 17179869599 +oscar carson 5 21474836960 +oscar davidson 1 4294967482 +oscar ellison 1 8589934740 +oscar ellison 1 8589934740 +oscar falkner 1 4294967526 +oscar garcia 1 4294967301 +oscar hernandez 1 4294967343 +oscar hernandez 2 8589934843 +oscar ichabod 1 4294967513 +oscar ichabod 2 8589934837 +oscar ichabod 3 12884902165 +oscar ichabod 4 17179869569 +oscar johnson 1 4294967418 +oscar johnson 2 8589934763 +oscar king 1 4294967465 +oscar king 2 8589934936 +oscar king 3 12884902469 +oscar laertes 1 4294967425 +oscar laertes 2 8589934876 +oscar laertes 3 12884902426 +oscar laertes 4 17179869786 +oscar nixon 1 4294967532 +oscar ovid 1 4294967508 +oscar ovid 2 8589934910 +oscar ovid 3 12884902418 +oscar polk 1 4294967325 +oscar polk 2 8589934713 +oscar quirinius 1 4294967416 +oscar quirinius 2 8589934932 +oscar quirinius 3 12884902390 +oscar quirinius 4 17179869763 +oscar robinson 1 4294967355 +oscar robinson 2 8589934681 +oscar robinson 3 12884902031 +oscar robinson 4 17179869383 +oscar steinbeck 1 4294967548 +oscar thompson 1 4294967453 +oscar thompson 2 8589934824 +oscar thompson 3 12884902197 +oscar thompson 4 17179869496 +oscar underhill 1 4294967374 +oscar van buren 1 4294967520 +oscar van buren 2 8589934990 +oscar van buren 3 12884902490 +oscar white 1 4294967454 +oscar white 2 8589934761 +oscar white 3 12884902163 +oscar white 4 17179869512 +oscar xylophone 1 4294967400 +oscar xylophone 2 8589934806 +oscar xylophone 3 12884902124 +oscar zipper 1 4294967449 +oscar zipper 2 8589934969 +oscar zipper 3 12884902458 +priscilla brown 1 4294967369 +priscilla brown 2 8589934897 +priscilla brown 3 12884902360 +priscilla carson 1 4294967489 +priscilla carson 2 8589934838 +priscilla carson 3 12884902270 +priscilla ichabod 1 4294967379 +priscilla ichabod 2 8589934926 +priscilla johnson 1 4294967535 +priscilla johnson 2 8589935003 +priscilla johnson 3 12884902308 +priscilla johnson 4 17179869707 +priscilla johnson 5 21474837167 +priscilla king 1 4294967385 +priscilla nixon 1 4294967388 +priscilla nixon 2 8589934849 +priscilla ovid 1 4294967528 +priscilla ovid 2 8589935035 +priscilla polk 1 4294967434 +priscilla quirinius 1 4294967347 +priscilla thompson 1 4294967497 +priscilla underhill 1 4294967520 +priscilla underhill 2 8589934853 +priscilla van buren 1 4294967318 +priscilla van buren 2 8589934809 +priscilla van buren 3 12884902351 +priscilla white 1 4294967419 +priscilla xylophone 1 4294967503 +priscilla xylophone 2 8589934956 +priscilla xylophone 3 12884902406 +priscilla young 1 4294967401 +priscilla young 2 8589934931 +priscilla zipper 1 4294967516 +priscilla zipper 2 8589934950 +quinn allen 1 4294967339 +quinn allen 2 8589934881 +quinn brown 1 4294967335 +quinn brown 2 8589934651 +quinn brown 3 12884902099 +quinn davidson 1 4294967478 +quinn davidson 2 8589934849 +quinn davidson 3 12884902238 +quinn davidson 4 17179869565 +quinn ellison 1 4294967392 +quinn ellison 2 8589934907 +quinn garcia 1 4294967344 +quinn garcia 2 8589934882 +quinn garcia 3 12884902395 +quinn garcia 4 17179869868 +quinn ichabod 1 4294967405 +quinn king 1 4294967538 +quinn king 2 8589934996 +quinn laertes 1 4294967533 +quinn laertes 2 8589934977 +quinn laertes 3 12884902524 +quinn nixon 1 4294967432 +quinn ovid 1 4294967340 +quinn quirinius 1 4294967347 +quinn robinson 1 4294967365 +quinn steinbeck 1 4294967358 +quinn steinbeck 2 8589934810 +quinn thompson 1 4294967488 +quinn thompson 2 8589934888 +quinn underhill 1 4294967307 +quinn underhill 2 8589934744 +quinn underhill 3 12884902278 +quinn van buren 1 4294967362 +quinn young 1 4294967392 +quinn zipper 1 4294967521 +quinn zipper 2 8589934944 +rachel allen 1 4294967334 +rachel allen 2 8589934713 +rachel brown 1 4294967451 +rachel brown 2 8589934886 +rachel brown 3 12884902325 +rachel brown 4 17179869632 +rachel brown 5 21474836938 +rachel carson 1 4294967461 +rachel carson 2 8589934777 +rachel davidson 1 4294967387 +rachel ellison 1 4294967423 +rachel falkner 1 4294967544 +rachel falkner 2 8589934892 +rachel falkner 3 12884902350 +rachel falkner 4 17179869809 +rachel johnson 1 4294967541 +rachel king 1 4294967442 +rachel king 2 8589934771 +rachel laertes 1 4294967446 +rachel laertes 2 8589934804 +rachel ovid 1 4294967481 +rachel ovid 2 8589934832 +rachel polk 1 4294967335 +rachel quirinius 1 4294967297 +rachel robinson 1 4294967344 +rachel robinson 2 8589934807 +rachel robinson 3 12884902135 +rachel thompson 1 4294967518 +rachel thompson 2 8589934881 +rachel thompson 3 12884902306 +rachel underhill 1 4294967382 +rachel white 1 4294967457 +rachel white 2 8589934793 +rachel young 1 4294967391 +rachel zipper 1 4294967434 +rachel zipper 2 8589934813 +sarah carson 1 4294967503 +sarah carson 2 8589934822 +sarah carson 3 12884902167 +sarah ellison 1 4294967542 +sarah falkner 1 4294967525 +sarah falkner 2 8589934974 +sarah garcia 1 4294967391 +sarah garcia 2 8589934849 +sarah garcia 3 12884902247 +sarah ichabod 1 4294967370 +sarah ichabod 2 8589934909 +sarah johnson 1 4294967433 +sarah johnson 2 8589934926 +sarah johnson 3 12884902235 +sarah johnson 4 17179869559 +sarah king 1 4294967496 +sarah king 2 8589935039 +sarah miller 1 4294967458 +sarah ovid 1 4294967350 +sarah robinson 1 4294967419 +sarah robinson 2 8589934917 +sarah steinbeck 1 4294967456 +sarah white 1 4294967514 +sarah white 2 8589934882 +sarah xylophone 1 4294967355 +sarah young 1 4294967442 +sarah zipper 1 4294967432 +tom brown 1 4294967432 +tom brown 2 8589934950 +tom carson 1 4294967388 +tom carson 2 8589934693 +tom carson 3 12884902227 +tom davidson 1 4294967507 +tom ellison 1 4294967487 +tom ellison 2 8589934844 +tom ellison 3 12884902188 +tom falkner 1 4294967382 +tom falkner 2 8589934837 +tom hernandez 1 8589934733 +tom hernandez 1 8589934733 +tom ichabod 1 4294967445 +tom johnson 1 4294967492 +tom johnson 2 8589934923 +tom king 1 4294967331 +tom laertes 1 4294967431 +tom laertes 2 8589934744 +tom miller 1 4294967366 +tom miller 2 8589934723 +tom miller 3 12884902078 +tom nixon 1 4294967506 +tom ovid 1 4294967512 +tom polk 1 4294967329 +tom polk 2 8589934869 +tom quirinius 1 4294967507 +tom quirinius 2 8589934823 +tom robinson 1 4294967457 +tom robinson 2 8589935008 +tom robinson 3 12884902462 +tom robinson 4 17179869770 +tom steinbeck 1 4294967447 +tom van buren 1 4294967374 +tom van buren 2 8589934703 +tom van buren 3 12884902195 +tom white 1 4294967413 +tom young 1 4294967539 +tom young 2 8589935074 +tom zipper 1 4294967526 +ulysses brown 1 4294967537 +ulysses carson 1 4294967323 +ulysses carson 2 8589934815 +ulysses carson 3 12884902127 +ulysses carson 4 17179869485 +ulysses davidson 1 4294967467 +ulysses ellison 1 4294967442 +ulysses garcia 1 4294967470 +ulysses hernandez 1 4294967449 +ulysses hernandez 2 8589934995 +ulysses hernandez 3 12884902393 +ulysses ichabod 1 4294967353 +ulysses ichabod 2 8589934728 +ulysses johnson 1 4294967432 +ulysses king 1 4294967537 +ulysses laertes 1 4294967391 +ulysses laertes 2 8589934938 +ulysses laertes 3 12884902431 +ulysses miller 1 4294967373 +ulysses miller 2 8589934808 +ulysses nixon 1 4294967296 +ulysses ovid 1 4294967394 +ulysses polk 1 4294967509 +ulysses polk 2 8589934960 +ulysses polk 3 12884902440 +ulysses polk 4 17179869745 +ulysses quirinius 1 4294967449 +ulysses robinson 1 4294967531 +ulysses steinbeck 1 4294967303 +ulysses steinbeck 2 8589934788 +ulysses thompson 1 4294967389 +ulysses underhill 1 4294967544 +ulysses underhill 2 8589934949 +ulysses underhill 3 12884902275 +ulysses underhill 4 17179869726 +ulysses underhill 5 21474837190 +ulysses underhill 6 25769804570 +ulysses underhill 7 30064771927 +ulysses van buren 1 4294967439 +ulysses white 1 4294967429 +ulysses white 2 8589934878 +ulysses xylophone 1 4294967524 +ulysses xylophone 2 8589935025 +ulysses xylophone 3 12884902473 +ulysses young 1 4294967427 +ulysses young 2 8589934763 +ulysses young 3 12884902154 +victor allen 1 4294967450 +victor allen 2 8589934776 +victor brown 1 4294967521 +victor brown 2 8589934864 +victor brown 3 12884902170 +victor brown 4 17179869625 +victor davidson 1 4294967419 +victor davidson 2 8589934720 +victor davidson 3 12884902156 +victor ellison 1 4294967362 +victor ellison 2 8589934831 +victor hernandez 1 4294967428 +victor hernandez 2 8589934733 +victor hernandez 3 12884902062 +victor hernandez 4 17179869402 +victor hernandez 5 21474836874 +victor johnson 1 4294967496 +victor johnson 2 8589934824 +victor johnson 3 12884902246 +victor king 1 4294967401 +victor king 2 8589934884 +victor laertes 1 4294967407 +victor laertes 2 8589934862 +victor miller 1 4294967410 +victor nixon 1 4294967424 +victor nixon 2 8589934803 +victor ovid 1 4294967355 +victor polk 1 4294967333 +victor quirinius 1 4294967520 +victor quirinius 2 8589934846 +victor robinson 1 4294967440 +victor robinson 2 8589934930 +victor steinbeck 1 4294967390 +victor steinbeck 2 8589934707 +victor steinbeck 3 12884902037 +victor thompson 1 4294967319 +victor van buren 1 4294967365 +victor van buren 2 8589934906 +victor white 1 4294967403 +victor white 2 8589934862 +victor xylophone 1 4294967331 +victor xylophone 2 8589934864 +victor xylophone 3 12884902262 +victor xylophone 4 17179869633 +victor xylophone 5 21474837062 +victor young 1 4294967337 +victor zipper 1 4294967428 +wendy allen 1 4294967473 +wendy allen 2 8589934989 +wendy allen 3 12884902367 +wendy brown 1 4294967337 +wendy brown 2 8589934817 +wendy ellison 1 4294967475 +wendy ellison 2 8589934989 +wendy falkner 1 4294967313 +wendy falkner 2 8589934810 +wendy falkner 3 12884902236 +wendy garcia 1 4294967394 +wendy garcia 2 8589934775 +wendy garcia 3 12884902088 +wendy garcia 4 17179869400 +wendy hernandez 1 4294967299 +wendy ichabod 1 4294967516 +wendy king 1 4294967420 +wendy king 2 8589934811 +wendy king 3 12884902252 +wendy laertes 1 4294967519 +wendy laertes 2 8589934939 +wendy laertes 3 12884902315 +wendy miller 1 4294967478 +wendy miller 2 8589934957 +wendy nixon 1 4294967407 +wendy nixon 2 8589934901 +wendy ovid 1 4294967464 +wendy ovid 2 8589934894 +wendy polk 1 4294967434 +wendy polk 2 8589934824 +wendy quirinius 1 4294967334 +wendy quirinius 2 8589934782 +wendy robinson 1 4294967302 +wendy robinson 2 8589934613 +wendy robinson 3 12884901977 +wendy steinbeck 1 4294967444 +wendy thompson 1 4294967301 +wendy thompson 2 8589934621 +wendy underhill 1 4294967540 +wendy underhill 2 8589934993 +wendy underhill 3 12884902410 +wendy van buren 1 4294967488 +wendy van buren 2 8589934835 +wendy white 1 4294967490 +wendy xylophone 1 4294967488 +wendy xylophone 2 8589934939 +wendy young 1 4294967395 +wendy young 2 8589934708 +xavier allen 1 4294967304 +xavier allen 2 8589934743 +xavier allen 3 12884902129 +xavier brown 1 4294967546 +xavier brown 2 8589935074 +xavier brown 3 12884902532 +xavier carson 1 4294967547 +xavier carson 2 8589934862 +xavier davidson 1 4294967361 +xavier davidson 2 8589934760 +xavier davidson 3 12884902204 +xavier ellison 1 4294967441 +xavier ellison 2 8589934914 +xavier garcia 1 4294967465 +xavier hernandez 1 4294967383 +xavier hernandez 2 8589934743 +xavier hernandez 3 12884902274 +xavier ichabod 1 4294967511 +xavier ichabod 2 8589934950 +xavier johnson 1 4294967507 +xavier johnson 2 8589934898 +xavier king 1 4294967456 +xavier king 2 8589934758 +xavier laertes 1 4294967450 +xavier ovid 1 4294967403 +xavier polk 1 4294967506 +xavier polk 2 8589934925 +xavier polk 3 12884902406 +xavier polk 4 17179869906 +xavier quirinius 1 4294967383 +xavier quirinius 2 8589934748 +xavier quirinius 3 12884902060 +xavier quirinius 4 17179869562 +xavier thompson 1 4294967444 +xavier underhill 1 4294967332 +xavier white 1 4294967473 +xavier white 2 8589934952 +xavier xylophone 1 4294967499 +xavier zipper 1 4294967547 +yuri allen 1 4294967528 +yuri allen 2 8589935079 +yuri brown 1 4294967433 +yuri brown 2 8589934960 +yuri carson 1 4294967317 +yuri carson 2 8589934851 +yuri ellison 1 4294967299 +yuri ellison 2 8589934697 +yuri falkner 1 4294967368 +yuri falkner 2 8589934891 +yuri garcia 1 4294967362 +yuri hernandez 1 4294967367 +yuri johnson 1 4294967421 +yuri johnson 2 8589934877 +yuri johnson 3 12884902361 +yuri king 1 4294967376 +yuri laertes 1 4294967402 +yuri laertes 2 8589934924 +yuri nixon 1 4294967400 +yuri nixon 2 8589934706 +yuri polk 1 4294967391 +yuri polk 2 8589934861 +yuri polk 3 12884902167 +yuri quirinius 1 4294967398 +yuri quirinius 2 8589934768 +yuri quirinius 3 12884902081 +yuri steinbeck 1 4294967535 +yuri steinbeck 2 8589934873 +yuri thompson 1 4294967447 +yuri underhill 1 4294967499 +yuri underhill 2 8589934900 +yuri white 1 4294967341 +yuri xylophone 1 4294967420 +zach allen 1 4294967507 +zach brown 1 4294967316 +zach brown 2 8589934728 +zach brown 3 12884902099 +zach brown 4 17179869452 +zach brown 5 21474836769 +zach carson 1 4294967463 +zach ellison 1 4294967471 +zach falkner 1 4294967362 +zach falkner 2 8589934717 +zach garcia 1 4294967481 +zach garcia 2 8589934854 +zach garcia 3 12884902240 +zach garcia 4 17179869723 +zach ichabod 1 4294967539 +zach ichabod 2 8589934912 +zach king 1 4294967424 +zach king 2 8589934956 +zach king 3 12884902458 +zach miller 1 4294967442 +zach miller 2 8589934772 +zach miller 3 12884902163 +zach ovid 1 4294967412 +zach ovid 2 8589934775 +zach ovid 3 12884902244 +zach ovid 4 17179869574 +zach quirinius 1 4294967299 +zach robinson 1 4294967325 +zach steinbeck 1 4294967469 +zach steinbeck 2 8589934834 +zach thompson 1 4294967405 +zach thompson 2 8589934730 +zach underhill 1 4294967496 +zach white 1 4294967501 +zach xylophone 1 4294967452 +zach xylophone 2 8589934755 +zach young 1 4294967297 +zach zipper 1 4294967497 +zach zipper 2 8589934855 +zach zipper 3 12884902222 diff --git a/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 3-0-455e41d9949a2d22bab634fd8e42f2b1 b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 3-0-455e41d9949a2d22bab634fd8e42f2b1 new file mode 100644 index 000000000000..f47923618a1a --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 3-0-455e41d9949a2d22bab634fd8e42f2b1 @@ -0,0 +1 @@ +bob steinbeck 65637 9.699999809265137 diff --git a/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 4-0-cfad06ae8eba6b047d32a6a61dd59392 b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 4-0-cfad06ae8eba6b047d32a6a61dd59392 new file mode 100644 index 000000000000..f41eaa259cec --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 4-0-cfad06ae8eba6b047d32a6a61dd59392 @@ -0,0 +1 @@ +bob steinbeck 1 1 diff --git a/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 5-0-d7ca7a61377cef3a9f721a28afdae012 b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 5-0-d7ca7a61377cef3a9f721a28afdae012 new file mode 100644 index 000000000000..5308b2eb457e --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 5-0-d7ca7a61377cef3a9f721a28afdae012 @@ -0,0 +1 @@ +bob steinbeck 9.699999809265137 1 diff --git a/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 6-0-287bcc7679822bc7b684532b267bf11f b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 6-0-287bcc7679822bc7b684532b267bf11f new file mode 100644 index 000000000000..f41eaa259cec --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_multipartitioning.q (deterministic) 6-0-287bcc7679822bc7b684532b267bf11f @@ -0,0 +1 @@ +bob steinbeck 1 1 diff --git a/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-0-36217f6074daaacddb9fcb50a3f4fb5b b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-0-36217f6074daaacddb9fcb50a3f4fb5b new file mode 100644 index 000000000000..8150409e62d3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-0-36217f6074daaacddb9fcb50a3f4fb5b @@ -0,0 +1,1049 @@ + 1 + 1 + 1 +alice allen 2 +alice allen 1 +alice allen 1 +alice brown 1 +alice carson 1 +alice davidson 1 +alice falkner 1 +alice garcia 1 +alice hernandez 2 +alice hernandez 1 +alice johnson 2 +alice king 1 +alice king 1 +alice king 1 +alice laertes 1 +alice laertes 1 +alice miller 1 +alice nixon 1 +alice nixon 1 +alice nixon 1 +alice ovid 1 +alice polk 3 +alice quirinius 1 +alice quirinius 1 +alice robinson 1 +alice robinson 1 +alice steinbeck 1 +alice steinbeck 1 +alice steinbeck 1 +alice underhill 2 +alice van buren 1 +alice xylophone 1 +alice xylophone 1 +alice xylophone 1 +alice zipper 1 +alice zipper 1 +alice zipper 1 +bob brown 1 +bob brown 1 +bob brown 1 +bob carson 1 +bob davidson 1 +bob davidson 1 +bob davidson 1 +bob ellison 2 +bob ellison 1 +bob ellison 1 +bob ellison 1 +bob falkner 1 +bob garcia 1 +bob garcia 1 +bob garcia 1 +bob garcia 1 +bob garcia 1 +bob hernandez 1 +bob ichabod 1 +bob king 2 +bob king 1 +bob king 1 +bob laertes 2 +bob laertes 1 +bob miller 1 +bob ovid 1 +bob ovid 1 +bob ovid 1 +bob ovid 1 +bob polk 1 +bob quirinius 1 +bob steinbeck 1 +bob van buren 1 +bob white 1 +bob white 1 +bob xylophone 1 +bob xylophone 1 +bob young 1 +bob zipper 2 +bob zipper 1 +bob zipper 1 +calvin allen 1 +calvin brown 2 +calvin brown 1 +calvin brown 1 +calvin carson 2 +calvin davidson 2 +calvin davidson 1 +calvin ellison 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin garcia 1 +calvin hernandez 2 +calvin johnson 1 +calvin laertes 1 +calvin laertes 1 +calvin nixon 3 +calvin nixon 1 +calvin nixon 1 +calvin ovid 1 +calvin ovid 1 +calvin ovid 1 +calvin ovid 1 +calvin polk 1 +calvin quirinius 1 +calvin quirinius 1 +calvin robinson 2 +calvin steinbeck 1 +calvin steinbeck 1 +calvin steinbeck 1 +calvin thompson 1 +calvin thompson 1 +calvin underhill 1 +calvin van buren 1 +calvin van buren 1 +calvin white 1 +calvin white 1 +calvin xylophone 2 +calvin xylophone 1 +calvin xylophone 1 +calvin young 1 +calvin young 1 +calvin zipper 1 +calvin zipper 1 +david allen 1 +david allen 1 +david brown 1 +david brown 1 +david davidson 1 +david davidson 1 +david davidson 1 +david davidson 1 +david ellison 1 +david ellison 1 +david ellison 1 +david hernandez 1 +david ichabod 1 +david ichabod 1 +david laertes 1 +david nixon 1 +david ovid 1 +david ovid 1 +david quirinius 1 +david quirinius 1 +david quirinius 1 +david robinson 1 +david robinson 1 +david thompson 3 +david underhill 1 +david underhill 1 +david underhill 1 +david van buren 1 +david van buren 1 +david white 1 +david xylophone 1 +david xylophone 1 +david xylophone 1 +david young 1 +david young 1 +ethan allen 1 +ethan brown 2 +ethan brown 2 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan carson 1 +ethan ellison 1 +ethan ellison 1 +ethan falkner 3 +ethan falkner 1 +ethan garcia 1 +ethan hernandez 1 +ethan johnson 1 +ethan king 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan miller 2 +ethan nixon 1 +ethan ovid 1 +ethan polk 2 +ethan polk 1 +ethan polk 1 +ethan polk 1 +ethan quirinius 1 +ethan quirinius 1 +ethan quirinius 1 +ethan robinson 1 +ethan robinson 1 +ethan underhill 1 +ethan van buren 2 +ethan white 1 +ethan white 1 +ethan xylophone 1 +ethan zipper 1 +ethan zipper 1 +fred davidson 1 +fred davidson 1 +fred davidson 1 +fred ellison 1 +fred ellison 1 +fred ellison 1 +fred falkner 2 +fred falkner 1 +fred falkner 1 +fred hernandez 1 +fred ichabod 1 +fred ichabod 1 +fred johnson 1 +fred king 2 +fred king 1 +fred laertes 1 +fred miller 1 +fred nixon 2 +fred nixon 1 +fred nixon 1 +fred nixon 1 +fred polk 1 +fred polk 1 +fred polk 1 +fred polk 1 +fred quirinius 1 +fred quirinius 1 +fred robinson 1 +fred steinbeck 1 +fred steinbeck 1 +fred steinbeck 1 +fred underhill 1 +fred van buren 1 +fred van buren 1 +fred van buren 1 +fred van buren 1 +fred white 1 +fred young 3 +fred young 1 +fred zipper 1 +gabriella allen 1 +gabriella allen 1 +gabriella brown 2 +gabriella brown 1 +gabriella carson 1 +gabriella davidson 1 +gabriella ellison 1 +gabriella ellison 1 +gabriella falkner 1 +gabriella falkner 1 +gabriella falkner 1 +gabriella garcia 1 +gabriella hernandez 1 +gabriella hernandez 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella king 1 +gabriella king 1 +gabriella laertes 1 +gabriella miller 1 +gabriella ovid 1 +gabriella ovid 1 +gabriella polk 1 +gabriella polk 1 +gabriella steinbeck 1 +gabriella steinbeck 1 +gabriella thompson 1 +gabriella thompson 1 +gabriella thompson 1 +gabriella van buren 2 +gabriella van buren 1 +gabriella white 1 +gabriella young 1 +gabriella young 1 +gabriella zipper 1 +gabriella zipper 1 +holly allen 1 +holly brown 1 +holly brown 1 +holly falkner 1 +holly hernandez 1 +holly hernandez 1 +holly hernandez 1 +holly hernandez 1 +holly ichabod 1 +holly ichabod 1 +holly ichabod 1 +holly johnson 2 +holly johnson 1 +holly johnson 1 +holly king 1 +holly king 1 +holly laertes 1 +holly miller 1 +holly nixon 1 +holly nixon 1 +holly polk 2 +holly polk 1 +holly robinson 2 +holly thompson 1 +holly thompson 1 +holly thompson 1 +holly underhill 1 +holly underhill 1 +holly underhill 1 +holly underhill 1 +holly van buren 1 +holly white 1 +holly white 1 +holly xylophone 1 +holly young 1 +holly young 1 +holly zipper 1 +holly zipper 1 +irene allen 1 +irene brown 1 +irene brown 1 +irene brown 1 +irene carson 1 +irene ellison 1 +irene ellison 1 +irene falkner 1 +irene falkner 1 +irene garcia 1 +irene garcia 1 +irene garcia 1 +irene ichabod 1 +irene ichabod 1 +irene johnson 2 +irene laertes 1 +irene laertes 1 +irene laertes 1 +irene miller 2 +irene nixon 1 +irene nixon 1 +irene nixon 1 +irene ovid 1 +irene ovid 1 +irene ovid 1 +irene polk 1 +irene polk 1 +irene polk 1 +irene polk 1 +irene polk 1 +irene quirinius 1 +irene quirinius 1 +irene quirinius 1 +irene robinson 1 +irene steinbeck 1 +irene thompson 2 +irene underhill 2 +irene underhill 1 +irene van buren 1 +irene van buren 1 +irene xylophone 1 +jessica brown 2 +jessica carson 1 +jessica carson 1 +jessica carson 1 +jessica davidson 2 +jessica davidson 1 +jessica davidson 1 +jessica davidson 1 +jessica ellison 2 +jessica ellison 1 +jessica falkner 1 +jessica garcia 1 +jessica garcia 1 +jessica ichabod 1 +jessica johnson 1 +jessica johnson 1 +jessica miller 1 +jessica nixon 1 +jessica nixon 1 +jessica ovid 2 +jessica ovid 1 +jessica polk 1 +jessica quirinius 1 +jessica quirinius 1 +jessica quirinius 1 +jessica quirinius 1 +jessica robinson 1 +jessica thompson 1 +jessica thompson 1 +jessica underhill 1 +jessica underhill 1 +jessica underhill 1 +jessica van buren 1 +jessica white 2 +jessica white 1 +jessica white 1 +jessica white 1 +jessica white 1 +jessica xylophone 1 +jessica young 1 +jessica young 1 +jessica zipper 1 +jessica zipper 1 +jessica zipper 1 +katie allen 1 +katie brown 1 +katie davidson 1 +katie ellison 1 +katie ellison 1 +katie falkner 1 +katie garcia 1 +katie garcia 1 +katie hernandez 1 +katie ichabod 2 +katie ichabod 1 +katie ichabod 1 +katie king 1 +katie king 1 +katie king 1 +katie miller 1 +katie miller 1 +katie nixon 1 +katie ovid 1 +katie polk 1 +katie polk 1 +katie robinson 1 +katie van buren 2 +katie van buren 1 +katie white 1 +katie white 1 +katie xylophone 1 +katie young 1 +katie young 1 +katie young 1 +katie zipper 1 +katie zipper 1 +luke allen 2 +luke allen 1 +luke allen 1 +luke allen 1 +luke allen 1 +luke brown 2 +luke davidson 1 +luke davidson 1 +luke ellison 1 +luke ellison 1 +luke ellison 1 +luke falkner 2 +luke falkner 1 +luke garcia 1 +luke garcia 1 +luke ichabod 1 +luke ichabod 1 +luke johnson 1 +luke johnson 1 +luke johnson 1 +luke laertes 1 +luke laertes 1 +luke laertes 1 +luke laertes 1 +luke laertes 1 +luke miller 2 +luke ovid 2 +luke ovid 1 +luke polk 1 +luke polk 1 +luke quirinius 1 +luke robinson 1 +luke robinson 1 +luke thompson 1 +luke underhill 1 +luke underhill 1 +luke underhill 1 +luke van buren 2 +luke white 1 +luke xylophone 1 +luke zipper 1 +mike allen 2 +mike brown 1 +mike carson 1 +mike carson 1 +mike carson 1 +mike davidson 1 +mike davidson 1 +mike ellison 2 +mike ellison 1 +mike ellison 1 +mike ellison 1 +mike ellison 1 +mike falkner 1 +mike garcia 1 +mike garcia 1 +mike garcia 1 +mike hernandez 1 +mike hernandez 1 +mike ichabod 1 +mike king 2 +mike king 1 +mike king 1 +mike king 1 +mike king 1 +mike king 1 +mike miller 1 +mike nixon 2 +mike nixon 1 +mike polk 2 +mike polk 1 +mike polk 1 +mike quirinius 1 +mike steinbeck 1 +mike steinbeck 1 +mike steinbeck 1 +mike steinbeck 1 +mike van buren 2 +mike van buren 1 +mike white 1 +mike white 1 +mike white 1 +mike white 1 +mike young 1 +mike young 1 +mike young 1 +mike zipper 1 +mike zipper 1 +mike zipper 1 +nick allen 1 +nick allen 1 +nick brown 1 +nick davidson 1 +nick ellison 2 +nick ellison 1 +nick falkner 1 +nick falkner 1 +nick garcia 1 +nick garcia 1 +nick garcia 1 +nick ichabod 1 +nick ichabod 1 +nick ichabod 1 +nick johnson 1 +nick johnson 1 +nick laertes 1 +nick miller 1 +nick nixon 1 +nick ovid 1 +nick polk 2 +nick quirinius 2 +nick quirinius 1 +nick robinson 1 +nick robinson 1 +nick steinbeck 1 +nick thompson 2 +nick underhill 1 +nick van buren 1 +nick xylophone 1 +nick young 1 +nick young 1 +nick zipper 2 +nick zipper 1 +oscar allen 1 +oscar brown 1 +oscar carson 1 +oscar carson 1 +oscar carson 1 +oscar carson 1 +oscar carson 1 +oscar davidson 1 +oscar ellison 2 +oscar ellison 2 +oscar falkner 1 +oscar garcia 1 +oscar hernandez 1 +oscar hernandez 1 +oscar ichabod 2 +oscar ichabod 1 +oscar ichabod 1 +oscar ichabod 1 +oscar johnson 2 +oscar johnson 1 +oscar king 1 +oscar king 1 +oscar king 1 +oscar laertes 1 +oscar laertes 1 +oscar laertes 1 +oscar laertes 1 +oscar nixon 1 +oscar ovid 1 +oscar ovid 1 +oscar ovid 1 +oscar polk 1 +oscar polk 1 +oscar quirinius 2 +oscar quirinius 2 +oscar quirinius 1 +oscar quirinius 1 +oscar robinson 2 +oscar robinson 1 +oscar robinson 1 +oscar robinson 1 +oscar steinbeck 1 +oscar thompson 1 +oscar thompson 1 +oscar thompson 1 +oscar thompson 1 +oscar underhill 1 +oscar van buren 1 +oscar van buren 1 +oscar van buren 1 +oscar white 1 +oscar white 1 +oscar white 1 +oscar white 1 +oscar xylophone 2 +oscar xylophone 1 +oscar xylophone 1 +oscar zipper 2 +oscar zipper 1 +oscar zipper 1 +priscilla brown 2 +priscilla brown 1 +priscilla brown 1 +priscilla carson 2 +priscilla carson 1 +priscilla carson 1 +priscilla ichabod 2 +priscilla ichabod 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla king 1 +priscilla nixon 2 +priscilla nixon 2 +priscilla ovid 2 +priscilla ovid 1 +priscilla polk 1 +priscilla quirinius 1 +priscilla thompson 1 +priscilla underhill 1 +priscilla underhill 1 +priscilla van buren 2 +priscilla van buren 1 +priscilla van buren 1 +priscilla white 1 +priscilla xylophone 1 +priscilla xylophone 1 +priscilla xylophone 1 +priscilla young 1 +priscilla young 1 +priscilla zipper 1 +priscilla zipper 1 +quinn allen 1 +quinn allen 1 +quinn brown 1 +quinn brown 1 +quinn brown 1 +quinn davidson 2 +quinn davidson 1 +quinn davidson 1 +quinn davidson 1 +quinn ellison 1 +quinn ellison 1 +quinn garcia 1 +quinn garcia 1 +quinn garcia 1 +quinn garcia 1 +quinn ichabod 1 +quinn king 1 +quinn king 1 +quinn laertes 1 +quinn laertes 1 +quinn laertes 1 +quinn nixon 2 +quinn ovid 1 +quinn quirinius 1 +quinn robinson 2 +quinn steinbeck 1 +quinn steinbeck 1 +quinn thompson 1 +quinn thompson 1 +quinn underhill 2 +quinn underhill 1 +quinn underhill 1 +quinn van buren 1 +quinn young 2 +quinn zipper 1 +quinn zipper 1 +rachel allen 1 +rachel allen 1 +rachel brown 3 +rachel brown 1 +rachel brown 1 +rachel brown 1 +rachel brown 1 +rachel carson 2 +rachel carson 1 +rachel davidson 1 +rachel ellison 1 +rachel falkner 1 +rachel falkner 1 +rachel falkner 1 +rachel falkner 1 +rachel johnson 1 +rachel king 2 +rachel king 1 +rachel laertes 1 +rachel laertes 1 +rachel ovid 1 +rachel ovid 1 +rachel polk 1 +rachel quirinius 1 +rachel robinson 1 +rachel robinson 1 +rachel robinson 1 +rachel thompson 1 +rachel thompson 1 +rachel thompson 1 +rachel underhill 1 +rachel white 1 +rachel white 1 +rachel young 1 +rachel zipper 1 +rachel zipper 1 +sarah carson 1 +sarah carson 1 +sarah carson 1 +sarah ellison 1 +sarah falkner 1 +sarah falkner 1 +sarah garcia 1 +sarah garcia 1 +sarah garcia 1 +sarah ichabod 1 +sarah ichabod 1 +sarah johnson 1 +sarah johnson 1 +sarah johnson 1 +sarah johnson 1 +sarah king 1 +sarah king 1 +sarah miller 1 +sarah ovid 1 +sarah robinson 1 +sarah robinson 1 +sarah steinbeck 1 +sarah white 1 +sarah white 1 +sarah xylophone 1 +sarah young 1 +sarah zipper 1 +tom brown 1 +tom brown 1 +tom carson 1 +tom carson 1 +tom carson 1 +tom davidson 1 +tom ellison 1 +tom ellison 1 +tom ellison 1 +tom falkner 1 +tom falkner 1 +tom hernandez 1 +tom hernandez 1 +tom ichabod 1 +tom johnson 1 +tom johnson 1 +tom king 1 +tom laertes 2 +tom laertes 1 +tom miller 2 +tom miller 1 +tom miller 1 +tom nixon 2 +tom ovid 1 +tom polk 1 +tom polk 1 +tom quirinius 1 +tom quirinius 1 +tom robinson 1 +tom robinson 1 +tom robinson 1 +tom robinson 1 +tom steinbeck 3 +tom van buren 1 +tom van buren 1 +tom van buren 1 +tom white 2 +tom young 1 +tom young 1 +tom zipper 3 +ulysses brown 1 +ulysses carson 4 +ulysses carson 1 +ulysses carson 1 +ulysses carson 1 +ulysses davidson 2 +ulysses ellison 1 +ulysses garcia 1 +ulysses hernandez 1 +ulysses hernandez 1 +ulysses hernandez 1 +ulysses ichabod 1 +ulysses ichabod 1 +ulysses johnson 2 +ulysses king 1 +ulysses laertes 2 +ulysses laertes 1 +ulysses laertes 1 +ulysses miller 1 +ulysses miller 1 +ulysses nixon 1 +ulysses ovid 1 +ulysses polk 1 +ulysses polk 1 +ulysses polk 1 +ulysses polk 1 +ulysses quirinius 1 +ulysses robinson 1 +ulysses steinbeck 1 +ulysses steinbeck 1 +ulysses thompson 1 +ulysses underhill 2 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses van buren 1 +ulysses white 1 +ulysses white 1 +ulysses xylophone 2 +ulysses xylophone 1 +ulysses xylophone 1 +ulysses young 2 +ulysses young 1 +ulysses young 1 +victor allen 1 +victor allen 1 +victor brown 1 +victor brown 1 +victor brown 1 +victor brown 1 +victor davidson 1 +victor davidson 1 +victor davidson 1 +victor ellison 2 +victor ellison 1 +victor hernandez 1 +victor hernandez 1 +victor hernandez 1 +victor hernandez 1 +victor hernandez 1 +victor johnson 2 +victor johnson 1 +victor johnson 1 +victor king 1 +victor king 1 +victor laertes 1 +victor laertes 1 +victor miller 1 +victor nixon 1 +victor nixon 1 +victor ovid 2 +victor polk 2 +victor quirinius 1 +victor quirinius 1 +victor robinson 2 +victor robinson 1 +victor steinbeck 2 +victor steinbeck 1 +victor steinbeck 1 +victor thompson 1 +victor van buren 1 +victor van buren 1 +victor white 2 +victor white 1 +victor xylophone 1 +victor xylophone 1 +victor xylophone 1 +victor xylophone 1 +victor xylophone 1 +victor young 1 +victor zipper 1 +wendy allen 1 +wendy allen 1 +wendy allen 1 +wendy brown 1 +wendy brown 1 +wendy ellison 1 +wendy ellison 1 +wendy falkner 2 +wendy falkner 1 +wendy falkner 1 +wendy garcia 2 +wendy garcia 1 +wendy garcia 1 +wendy garcia 1 +wendy hernandez 1 +wendy ichabod 1 +wendy king 1 +wendy king 1 +wendy king 1 +wendy laertes 1 +wendy laertes 1 +wendy laertes 1 +wendy miller 1 +wendy miller 1 +wendy nixon 1 +wendy nixon 1 +wendy ovid 1 +wendy ovid 1 +wendy polk 2 +wendy polk 1 +wendy quirinius 1 +wendy quirinius 1 +wendy robinson 2 +wendy robinson 1 +wendy robinson 1 +wendy steinbeck 1 +wendy thompson 2 +wendy thompson 1 +wendy underhill 2 +wendy underhill 1 +wendy underhill 1 +wendy van buren 1 +wendy van buren 1 +wendy white 1 +wendy xylophone 1 +wendy xylophone 1 +wendy young 1 +wendy young 1 +xavier allen 1 +xavier allen 1 +xavier allen 1 +xavier brown 1 +xavier brown 1 +xavier brown 1 +xavier carson 1 +xavier carson 1 +xavier davidson 1 +xavier davidson 1 +xavier davidson 1 +xavier ellison 1 +xavier ellison 1 +xavier garcia 1 +xavier hernandez 1 +xavier hernandez 1 +xavier hernandez 1 +xavier ichabod 1 +xavier ichabod 1 +xavier johnson 1 +xavier johnson 1 +xavier king 1 +xavier king 1 +xavier laertes 1 +xavier ovid 2 +xavier polk 1 +xavier polk 1 +xavier polk 1 +xavier polk 1 +xavier quirinius 2 +xavier quirinius 2 +xavier quirinius 1 +xavier quirinius 1 +xavier thompson 1 +xavier underhill 1 +xavier white 2 +xavier white 1 +xavier xylophone 1 +xavier zipper 1 +yuri allen 1 +yuri allen 1 +yuri brown 1 +yuri brown 1 +yuri carson 1 +yuri carson 1 +yuri ellison 1 +yuri ellison 1 +yuri falkner 1 +yuri falkner 1 +yuri garcia 1 +yuri hernandez 1 +yuri johnson 1 +yuri johnson 1 +yuri johnson 1 +yuri king 2 +yuri laertes 1 +yuri laertes 1 +yuri nixon 1 +yuri nixon 1 +yuri polk 1 +yuri polk 1 +yuri polk 1 +yuri quirinius 1 +yuri quirinius 1 +yuri quirinius 1 +yuri steinbeck 1 +yuri steinbeck 1 +yuri thompson 1 +yuri underhill 1 +yuri underhill 1 +yuri white 1 +yuri xylophone 1 +zach allen 1 +zach brown 2 +zach brown 1 +zach brown 1 +zach brown 1 +zach brown 1 +zach carson 2 +zach ellison 1 +zach falkner 1 +zach falkner 1 +zach garcia 2 +zach garcia 1 +zach garcia 1 +zach garcia 1 +zach ichabod 1 +zach ichabod 1 +zach king 2 +zach king 1 +zach king 1 +zach miller 1 +zach miller 1 +zach miller 1 +zach ovid 1 +zach ovid 1 +zach ovid 1 +zach ovid 1 +zach quirinius 1 +zach robinson 1 +zach steinbeck 1 +zach steinbeck 1 +zach thompson 2 +zach thompson 1 +zach underhill 1 +zach white 1 +zach xylophone 2 +zach xylophone 1 +zach young 1 +zach zipper 1 +zach zipper 1 +zach zipper 1 diff --git a/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-1-9ee79e711248dd6e0a6ce27e439e55f4 b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-1-9ee79e711248dd6e0a6ce27e439e55f4 new file mode 100644 index 000000000000..275772e1f643 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-1-9ee79e711248dd6e0a6ce27e439e55f4 @@ -0,0 +1,1049 @@ +65791 calvin nixon +65791 katie garcia +65790 fred nixon +65790 victor polk +65790 yuri ellison +65789 NULL +65789 alice laertes +65789 gabriella king +65789 katie zipper +65789 oscar zipper +65789 quinn davidson +65789 wendy brown +65788 irene brown +65788 oscar zipper +65788 rachel king +65788 xavier thompson +65787 alice laertes +65787 david laertes +65787 katie ichabod +65787 ulysses king +65786 bob carson +65786 quinn king +65786 victor xylophone +65786 xavier allen +65786 xavier davidson +65785 sarah carson +65785 sarah johnson +65784 holly young +65784 jessica ellison +65784 jessica ovid +65784 jessica zipper +65784 quinn van buren +65783 david allen +65783 fred ellison +65783 irene nixon +65783 mike ichabod +65783 wendy miller +65783 zach garcia +65782 tom miller +65782 ulysses underhill +65782 victor nixon +65782 yuri white +65782 zach ovid +65781 ethan carson +65781 luke laertes +65781 quinn brown +65780 holly hernandez +65780 holly zipper +65780 wendy falkner +65779 gabriella ichabod +65779 irene brown +65779 irene underhill +65779 ulysses xylophone +65779 wendy hernandez +65779 yuri nixon +65779 zach ellison +65778 jessica davidson +65778 katie allen +65778 oscar van buren +65778 oscar white +65778 rachel johnson +65778 victor quirinius +65778 yuri polk +65778 yuri quirinius +65777 ethan garcia +65777 irene garcia +65777 katie polk +65777 rachel thompson +65776 NULL +65776 luke garcia +65776 luke quirinius +65776 priscilla ovid +65776 sarah king +65775 bob ellison +65775 calvin steinbeck +65775 ethan laertes +65775 luke robinson +65775 oscar polk +65774 calvin young +65774 irene quirinius +65774 katie brown +65774 oscar hernandez +65774 sarah ichabod +65774 zach king +65773 NULL +65773 calvin garcia +65773 irene polk +65773 jessica davidson +65773 nick ichabod +65773 nick zipper +65773 quinn ovid +65773 zach falkner +65772 oscar garcia +65771 ethan underhill +65771 yuri hernandez +65771 zach brown +65770 alice nixon +65770 gabriella ellison +65769 mike king +65769 nick steinbeck +65769 oscar nixon +65769 rachel zipper +65769 xavier quirinius +65769 zach ichabod +65768 jessica davidson +65767 tom robinson +65767 victor davidson +65767 xavier king +65766 fred davidson +65766 luke underhill +65766 ulysses carson +65766 wendy nixon +65766 xavier brown +65765 bob polk +65765 gabriella brown +65765 wendy allen +65764 alice hernandez +65764 alice robinson +65764 luke ellison +65763 calvin van buren +65763 ethan brown +65763 priscilla johnson +65763 tom ichabod +65763 tom miller +65762 david xylophone +65762 priscilla xylophone +65762 quinn allen +65762 sarah johnson +65762 tom carson +65761 david ovid +65760 NULL +65760 ethan king +65760 priscilla johnson +65760 wendy van buren +65759 alice xylophone +65759 ethan quirinius +65759 fred steinbeck +65759 sarah robinson +65759 xavier johnson +65758 alice van buren +65758 ethan white +65758 irene johnson +65758 jessica underhill +65758 rachel laertes +65757 bob quirinius +65757 nick quirinius +65757 tom van buren +65757 xavier carson +65757 zach brown +65756 gabriella falkner +65756 mike polk +65756 nick allen +65756 yuri xylophone +65755 alice johnson +65755 david ellison +65755 fred polk +65755 gabriella hernandez +65755 irene van buren +65755 jessica nixon +65755 mike ellison +65755 nick van buren +65755 zach miller +65754 luke johnson +65754 xavier quirinius +65753 rachel robinson +65753 yuri garcia +65752 bob miller +65752 oscar zipper +65751 calvin quirinius +65751 katie king +65751 mike allen +65751 mike quirinius +65751 mike white +65751 zach zipper +65750 mike white +65750 nick thompson +65750 oscar quirinius +65750 priscilla polk +65750 rachel brown +65749 david allen +65749 gabriella polk +65749 mike ellison +65749 sarah carson +65749 wendy thompson +65748 david davidson +65748 oscar laertes +65747 calvin falkner +65747 fred steinbeck +65747 priscilla zipper +65747 quinn underhill +65747 rachel falkner +65747 wendy falkner +65747 yuri falkner +65746 +65746 katie robinson +65746 luke garcia +65746 tom brown +65746 zach steinbeck +65745 oscar steinbeck +65745 oscar xylophone +65745 priscilla nixon +65745 victor laertes +65744 alice allen +65744 priscilla ichabod +65743 jessica carson +65743 oscar ichabod +65743 sarah falkner +65743 tom brown +65742 jessica brown +65742 jessica ellison +65742 wendy polk +65741 calvin thompson +65740 irene van buren +65740 mike zipper +65740 rachel quirinius +65739 gabriella van buren +65739 victor robinson +65739 wendy nixon +65738 sarah ichabod +65738 wendy robinson +65738 wendy young +65738 xavier king +65737 holly brown +65737 irene falkner +65737 jessica white +65737 quinn garcia +65737 wendy laertes +65736 fred young +65736 quinn laertes +65736 rachel allen +65736 victor ellison +65735 oscar laertes +65735 tom falkner +65735 ulysses thompson +65734 calvin falkner +65734 mike garcia +65733 gabriella van buren +65733 holly nixon +65733 luke underhill +65733 oscar allen +65733 xavier polk +65733 yuri ellison +65732 mike steinbeck +65732 tom carson +65732 tom ellison +65731 alice robinson +65731 priscilla carson +65731 tom johnson +65731 ulysses ovid +65730 bob king +65730 calvin polk +65730 gabriella ichabod +65730 rachel brown +65729 fred quirinius +65728 priscilla king +65728 victor brown +65727 ethan brown +65727 fred polk +65727 jessica white +65727 priscilla underhill +65727 sarah white +65726 mike garcia +65726 ulysses hernandez +65726 wendy quirinius +65725 zach thompson +65724 NULL +65724 alice king +65724 alice steinbeck +65724 mike ellison +65724 nick ovid +65724 priscilla zipper +65723 nick allen +65723 rachel white +65722 fred falkner +65722 jessica underhill +65722 luke ichabod +65721 ethan falkner +65721 jessica zipper +65721 luke laertes +65721 luke underhill +65721 mike carson +65721 oscar laertes +65721 ulysses ichabod +65720 calvin nixon +65720 calvin thompson +65720 gabriella young +65720 holly polk +65720 mike king +65719 bob brown +65719 holly ichabod +65719 ulysses ellison +65719 ulysses young +65718 jessica ichabod +65718 oscar johnson +65718 victor hernandez +65718 zach ovid +65717 holly hernandez +65717 mike nixon +65717 quinn garcia +65717 ulysses davidson +65717 ulysses polk +65716 ethan zipper +65716 holly xylophone +65716 jessica garcia +65716 nick falkner +65716 sarah king +65716 victor xylophone +65715 alice hernandez +65715 david young +65715 ethan polk +65715 oscar falkner +65715 priscilla brown +65714 NULL +65714 ethan laertes +65714 luke allen +65713 bob ellison +65713 nick nixon +65713 ulysses carson +65713 ulysses ichabod +65713 zach garcia +65712 NULL +65712 katie king +65712 luke davidson +65712 quinn garcia +65712 zach carson +65711 bob zipper +65711 fred miller +65711 holly nixon +65711 katie ellison +65711 wendy ovid +65711 zach zipper +65710 luke polk +65710 mike young +65709 bob laertes +65709 calvin laertes +65709 oscar laertes +65708 gabriella laertes +65708 priscilla van buren +65707 bob garcia +65707 mike steinbeck +65706 bob ellison +65706 bob xylophone +65706 luke allen +65706 ulysses xylophone +65706 wendy ichabod +65705 david ellison +65705 mike white +65705 priscilla johnson +65705 tom ellison +65704 ethan brown +65704 holly king +65704 jessica young +65704 tom steinbeck +65704 victor steinbeck +65704 zach falkner +65703 holly underhill +65703 mike ellison +65703 victor davidson +65703 xavier underhill +65702 NULL +65702 bob brown +65702 bob zipper +65702 ulysses quirinius +65701 alice allen +65701 mike steinbeck +65701 oscar thompson +65700 luke allen +65700 quinn laertes +65700 tom quirinius +65699 ethan brown +65699 ethan van buren +65699 irene laertes +65699 mike young +65699 nick garcia +65699 sarah white +65698 gabriella thompson +65698 nick polk +65697 NULL +65697 NULL +65697 holly miller +65697 oscar quirinius +65697 oscar thompson +65697 xavier davidson +65696 alice xylophone +65696 jessica davidson +65696 luke allen +65696 oscar xylophone +65695 oscar king +65695 rachel young +65695 wendy ellison +65695 yuri quirinius +65694 david brown +65694 holly underhill +65694 victor quirinius +65694 zach brown +65693 bob hernandez +65693 bob young +65693 david brown +65693 holly hernandez +65693 tom polk +65693 ulysses polk +65693 victor brown +65692 holly johnson +65692 tom robinson +65691 calvin ovid +65691 ethan nixon +65691 ethan robinson +65691 fred underhill +65691 holly white +65691 irene polk +65691 oscar white +65691 rachel brown +65690 fred van buren +65690 jessica quirinius +65689 oscar ovid +65689 wendy thompson +65688 bob steinbeck +65688 victor steinbeck +65687 gabriella ichabod +65687 jessica underhill +65687 mike zipper +65687 quinn thompson +65686 bob king +65686 bob zipper +65686 david quirinius +65686 luke ichabod +65685 ethan robinson +65685 gabriella hernandez +65685 katie garcia +65685 sarah ellison +65685 victor hernandez +65685 victor nixon +65684 priscilla brown +65684 victor laertes +65684 wendy van buren +65683 NULL +65683 mike king +65683 tom laertes +65682 calvin quirinius +65682 ethan brown +65682 katie ellison +65681 gabriella allen +65681 luke laertes +65681 oscar quirinius +65681 ulysses laertes +65681 wendy ellison +65681 xavier polk +65680 NULL +65680 alice nixon +65680 gabriella ovid +65680 jessica carson +65680 ulysses nixon +65680 zach zipper +65679 bob garcia +65679 wendy underhill +65678 bob falkner +65678 victor xylophone +65678 wendy king +65677 alice allen +65677 fred van buren +65677 mike brown +65677 nick xylophone +65677 ulysses underhill +65677 zach robinson +65676 bob davidson +65676 bob laertes +65676 tom ovid +65676 xavier johnson +65675 david hernandez +65675 david nixon +65675 holly falkner +65675 quinn steinbeck +65675 rachel robinson +65675 sarah zipper +65675 tom polk +65675 victor allen +65674 gabriella falkner +65673 nick johnson +65673 quinn brown +65673 quinn underhill +65673 rachel ovid +65673 wendy brown +65672 nick laertes +65672 nick underhill +65672 rachel zipper +65672 tom white +65672 victor king +65671 fred ellison +65671 fred falkner +65671 zach white +65670 david robinson +65670 jessica zipper +65670 luke van buren +65670 oscar ovid +65670 quinn steinbeck +65669 NULL +65669 alice king +65669 calvin hernandez +65669 katie polk +65669 nick miller +65669 oscar van buren +65668 luke ellison +65667 bob brown +65667 irene nixon +65667 oscar brown +65667 tom falkner +65666 +65666 david underhill +65666 fred van buren +65665 rachel brown +65664 NULL +65664 bob davidson +65664 david ichabod +65664 ethan laertes +65664 irene robinson +65664 mike carson +65664 priscilla young +65664 victor king +65663 calvin underhill +65663 jessica johnson +65663 priscilla carson +65663 zach ichabod +65662 ethan allen +65662 katie ovid +65662 oscar johnson +65662 ulysses carson +65662 ulysses polk +65662 victor ovid +65661 david van buren +65661 luke xylophone +65661 mike falkner +65661 priscilla van buren +65661 victor johnson +65660 holly ichabod +65660 priscilla johnson +65660 victor thompson +65659 david robinson +65659 gabriella king +65659 luke davidson +65659 mike king +65659 mike zipper +65659 nick brown +65659 nick zipper +65659 yuri underhill +65658 NULL +65658 alice zipper +65658 calvin allen +65658 calvin johnson +65658 jessica garcia +65658 quinn davidson +65658 sarah ovid +65658 ulysses brown +65658 ulysses miller +65658 yuri king +65657 ethan falkner +65657 holly zipper +65657 irene ovid +65657 luke ovid +65657 priscilla white +65656 david davidson +65656 irene ovid +65656 jessica xylophone +65656 luke laertes +65656 oscar ichabod +65656 xavier ellison +65655 calvin falkner +65655 yuri laertes +65654 alice carson +65654 alice quirinius +65654 gabriella falkner +65654 nick young +65654 oscar robinson +65654 quinn robinson +65654 rachel falkner +65654 tom laertes +65654 yuri johnson +65653 calvin ellison +65653 holly underhill +65653 ulysses polk +65653 wendy xylophone +65652 NULL +65652 tom ellison +65652 victor johnson +65651 NULL +65651 ethan laertes +65651 ethan laertes +65651 irene garcia +65651 mike young +65650 irene ellison +65650 oscar white +65650 sarah steinbeck +65650 ulysses underhill +65650 ulysses xylophone +65650 victor xylophone +65649 irene underhill +65649 priscilla quirinius +65649 quinn ellison +65649 tom quirinius +65648 alice nixon +65648 calvin brown +65648 sarah carson +65648 xavier ellison +65647 irene allen +65647 mike ellison +65646 bob ovid +65646 xavier brown +65646 xavier ovid +65645 jessica quirinius +65645 katie miller +65645 ulysses hernandez +65644 alice king +65644 calvin ovid +65644 jessica white +65644 katie van buren +65644 sarah young +65644 ulysses hernandez +65644 yuri carson +65643 david davidson +65643 ethan polk +65643 ethan zipper +65643 gabriella ichabod +65643 mike davidson +65643 mike hernandez +65643 oscar robinson +65643 priscilla underhill +65643 zach king +65642 gabriella thompson +65641 fred laertes +65641 sarah garcia +65641 tom miller +65641 xavier hernandez +65640 david underhill +65639 wendy garcia +65638 fred nixon +65638 luke polk +65638 rachel carson +65637 alice underhill +65637 david davidson +65637 fred davidson +65637 gabriella davidson +65637 oscar carson +65637 rachel laertes +65637 sarah garcia +65637 wendy garcia +65636 +65636 irene polk +65636 wendy allen +65635 alice steinbeck +65635 alice zipper +65635 ulysses white +65634 NULL +65634 calvin white +65634 holly underhill +65634 sarah falkner +65633 NULL +65633 holly polk +65633 jessica nixon +65633 oscar quirinius +65632 alice falkner +65632 zach quirinius +65631 fred ellison +65630 rachel brown +65630 xavier hernandez +65629 jessica quirinius +65629 priscilla carson +65629 victor young +65629 xavier carson +65628 bob ovid +65628 ethan ovid +65628 irene ichabod +65628 oscar hernandez +65628 oscar robinson +65628 xavier quirinius +65627 alice miller +65627 holly johnson +65627 luke falkner +65627 yuri polk +65626 ethan polk +65626 holly robinson +65626 tom young +65626 yuri johnson +65625 david xylophone +65625 fred ichabod +65625 katie white +65625 ulysses garcia +65624 calvin steinbeck +65624 calvin xylophone +65624 rachel carson +65624 tom van buren +65624 yuri brown +65623 alice quirinius +65623 jessica miller +65623 oscar ichabod +65623 quinn zipper +65623 tom van buren +65623 victor brown +65623 wendy young +65622 nick davidson +65622 rachel king +65622 wendy robinson +65622 xavier ichabod +65622 zach xylophone +65622 zach young +65621 quinn underhill +65621 ulysses young +65620 nick garcia +65620 oscar thompson +65620 quinn quirinius +65620 victor white +65620 victor xylophone +65620 wendy quirinius +65619 calvin brown +65619 gabriella polk +65619 oscar king +65619 ulysses miller +65619 ulysses robinson +65619 ulysses steinbeck +65618 gabriella ovid +65618 irene laertes +65618 katie king +65618 oscar ovid +65618 quinn thompson +65617 fred van buren +65617 gabriella carson +65617 sarah johnson +65617 ulysses underhill +65616 calvin steinbeck +65616 xavier ichabod +65615 alice ovid +65615 david quirinius +65615 irene quirinius +65615 katie nixon +65614 wendy king +65614 xavier quirinius +65614 xavier white +65613 xavier zipper +65612 irene miller +65612 victor hernandez +65612 wendy white +65612 yuri polk +65611 ethan johnson +65611 fred zipper +65611 irene carson +65611 nick quirinius +65610 tom king +65610 victor steinbeck +65610 wendy garcia +65610 yuri carson +65610 zach ovid +65609 sarah robinson +65608 katie van buren +65608 mike van buren +65608 quinn ichabod +65608 zach underhill +65607 katie miller +65607 luke falkner +65607 mike polk +65607 priscilla xylophone +65607 yuri allen +65607 yuri allen +65606 bob white +65606 gabriella white +65606 oscar carson +65606 victor white +65606 xavier allen +65606 zach allen +65605 holly king +65604 katie zipper +65604 oscar davidson +65604 wendy laertes +65604 zach brown +65603 alice davidson +65603 ethan miller +65603 katie davidson +65603 katie young +65603 mike garcia +65602 NULL +65602 calvin laertes +65602 ethan laertes +65602 fred steinbeck +65602 jessica young +65602 xavier brown +65601 priscilla ovid +65601 sarah xylophone +65601 tom robinson +65600 gabriella thompson +65600 jessica polk +65600 nick robinson +65600 rachel allen +65599 fred quirinius +65599 luke johnson +65599 nick garcia +65599 oscar xylophone +65599 ulysses underhill +65598 ulysses van buren +65598 victor zipper +65597 ethan ellison +65597 nick ellison +65597 quinn davidson +65596 NULL +65596 calvin zipper +65596 david ellison +65596 irene ichabod +65596 wendy laertes +65595 bob white +65595 holly hernandez +65595 luke brown +65595 oscar ellison +65595 oscar ichabod +65595 quinn ellison +65594 gabriella ellison +65594 oscar robinson +65594 ulysses underhill +65594 victor robinson +65593 oscar white +65593 zach xylophone +65592 calvin xylophone +65591 alice zipper +65591 nick ichabod +65591 priscilla ichabod +65591 rachel underhill +65590 NULL +65590 katie falkner +65590 oscar van buren +65590 xavier garcia +65590 yuri underhill +65589 ethan white +65589 gabriella zipper +65589 irene ovid +65589 oscar king +65589 wendy xylophone +65588 bob van buren +65588 david ichabod +65588 mike miller +65588 tom hernandez +65588 victor van buren +65587 bob garcia +65587 luke johnson +65587 mike king +65587 victor allen +65587 xavier white +65586 david young +65586 irene brown +65586 priscilla brown +65586 wendy allen +65586 xavier laertes +65585 alice garcia +65585 bob garcia +65585 ethan ellison +65585 nick ellison +65585 priscilla thompson +65584 jessica carson +65584 jessica van buren +65584 jessica white +65583 bob xylophone +65583 nick ichabod +65583 yuri brown +65583 yuri steinbeck +65582 holly johnson +65582 mike carson +65582 victor van buren +65582 zach miller +65581 gabriella steinbeck +65581 irene quirinius +65581 luke allen +65581 nick robinson +65581 nick young +65581 wendy robinson +65580 alice steinbeck +65580 alice xylophone +65580 irene xylophone +65579 irene polk +65579 luke ovid +65579 quinn nixon +65579 sarah garcia +65579 wendy ovid +65578 calvin robinson +65578 fred king +65578 holly thompson +65578 katie ichabod +65578 quinn king +65578 rachel davidson +65578 victor hernandez +65577 holly white +65576 calvin falkner +65576 calvin ovid +65576 fred polk +65576 luke robinson +65575 calvin falkner +65575 irene steinbeck +65575 luke zipper +65575 zach king +65574 gabriella steinbeck +65574 priscilla nixon +65574 rachel thompson +65573 victor ellison +65573 victor hernandez +65573 yuri nixon +65572 calvin davidson +65572 calvin young +65572 katie young +65572 oscar ellison +65572 quinn garcia +65571 bob king +65571 irene polk +65571 katie ichabod +65571 mike steinbeck +65570 NULL +65570 bob ovid +65570 fred polk +65570 luke ellison +65570 mike hernandez +65570 yuri quirinius +65569 nick falkner +65568 bob ichabod +65568 holly thompson +65568 jessica thompson +65567 katie xylophone +65566 gabriella garcia +65566 rachel white +65565 katie young +65565 quinn young +65564 alice polk +65564 calvin carson +65564 calvin white +65564 ethan hernandez +65564 ethan quirinius +65564 jessica thompson +65564 katie hernandez +65563 calvin zipper +65563 priscilla young +65563 xavier davidson +65563 yuri steinbeck +65562 calvin falkner +65562 ethan xylophone +65562 luke white +65562 quinn allen +65562 rachel polk +65562 wendy polk +65561 bob davidson +65561 ethan polk +65561 jessica robinson +65560 fred white +65560 jessica johnson +65560 oscar thompson +65560 ulysses steinbeck +65560 zach brown +65559 NULL +65559 ethan laertes +65559 gabriella ichabod +65559 gabriella zipper +65559 irene garcia +65558 fred robinson +65557 fred hernandez +65557 nick johnson +65556 oscar underhill +65556 xavier hernandez +65556 yuri falkner +65556 zach garcia +65556 zach steinbeck +65555 fred nixon +65554 gabriella miller +65554 rachel falkner +65553 calvin van buren +65553 david van buren +65553 irene nixon +65553 luke laertes +65553 oscar carson +65552 NULL +65552 irene ellison +65552 oscar polk +65552 wendy falkner +65552 zach miller +65551 fred young +65551 ulysses underhill +65551 wendy underhill +65550 ethan quirinius +65550 fred davidson +65550 holly young +65550 jessica ovid +65550 quinn brown +65550 quinn laertes +65550 tom johnson +65549 bob garcia +65549 bob ovid +65549 fred ichabod +65549 fred king +65549 jessica white +65549 ulysses laertes +65549 victor davidson +65549 victor miller +65548 calvin ovid +65548 gabriella allen +65548 holly ichabod +65548 priscilla johnson +65548 quinn zipper +65548 tom hernandez +65548 wendy king +65547 bob ellison +65547 jessica quirinius +65547 mike davidson +65547 xavier allen +65546 katie white +65545 mike king +65545 tom carson +65545 victor brown +65544 calvin davidson +65544 calvin nixon +65544 david ovid +65544 irene thompson +65544 ulysses young +65544 xavier polk +65544 xavier xylophone +65544 zach ovid +65543 fred johnson +65543 sarah johnson +65542 fred falkner +65542 holly thompson +65542 luke miller +65542 mike white +65542 tom davidson +65541 calvin brown +65541 ethan brown +65541 holly brown +65541 jessica falkner +65541 rachel thompson +65541 tom zipper +65541 wendy underhill +65541 xavier polk +65541 yuri johnson +65540 rachel falkner +65539 gabriella young +65539 holly laertes +65539 oscar carson +65538 irene laertes +65538 mike polk +65538 tom robinson +65537 NULL +65537 david quirinius +65537 rachel ovid +65537 ulysses laertes +65537 zach garcia +65536 calvin xylophone +65536 david thompson +65536 irene falkner +65536 ulysses johnson +65536 victor johnson +65536 wendy miller +65536 yuri thompson diff --git a/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-2-1e88e0ba414a00195f7ebf6b8600ac04 b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-2-1e88e0ba414a00195f7ebf6b8600ac04 new file mode 100644 index 000000000000..62d71abc6fc7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-2-1e88e0ba414a00195f7ebf6b8600ac04 @@ -0,0 +1,1049 @@ +65536 NULL +65536 16.85 +65536 32.68 +65536 33.45 +65536 58.86 +65536 75.7 +65536 83.48 +65537 NULL +65537 4.49 +65537 11.87 +65537 51.91 +65537 99.34 +65538 NULL +65538 9.53 +65538 30.27 +65539 NULL +65539 58.85 +65539 96.64 +65540 NULL +65541 NULL +65541 9.04 +65541 14.94 +65541 15.85 +65541 27.89 +65541 35.38 +65541 72.33 +65541 89.14 +65541 98.87 +65542 NULL +65542 42.85 +65542 73.72 +65542 73.93 +65542 84.63 +65543 NULL +65543 21.59 +65544 NULL +65544 26.94 +65544 32.03 +65544 48.84 +65544 65.77 +65544 65.85 +65544 73.08 +65544 79.21 +65545 NULL +65545 34.65 +65545 65.81 +65546 NULL +65547 NULL +65547 17.71 +65547 62.31 +65547 83.21 +65548 NULL +65548 52.94 +65548 53.52 +65548 57.61 +65548 58.51 +65548 75.39 +65548 77.24 +65549 NULL +65549 13.3 +65549 28.93 +65549 50.6 +65549 55.04 +65549 64.91 +65549 76.06 +65549 80.09 +65550 NULL +65550 1.75 +65550 8.46 +65550 33.01 +65550 57.63 +65550 91.38 +65550 96.9 +65551 NULL +65551 39.43 +65551 73.93 +65552 NULL +65552 3.43 +65552 18.11 +65552 48.96 +65552 82.21 +65553 NULL +65553 25.31 +65553 29.62 +65553 71.07 +65553 72.16 +65554 NULL +65554 0.24 +65555 NULL +65556 NULL +65556 9.96 +65556 45.81 +65556 50.42 +65556 95.12 +65557 NULL +65557 21.14 +65558 NULL +65559 NULL +65559 29.55 +65559 56.06 +65559 73.94 +65559 83.5 +65560 NULL +65560 16.86 +65560 21.81 +65560 64.38 +65560 71.59 +65561 NULL +65561 32.86 +65561 47.71 +65562 NULL +65562 26.86 +65562 45.94 +65562 54.94 +65562 69.41 +65562 72.04 +65563 NULL +65563 14.36 +65563 33.29 +65563 39.96 +65564 NULL +65564 9.47 +65564 15.82 +65564 26.97 +65564 30.78 +65564 90.34 +65564 98.36 +65565 NULL +65565 81.72 +65566 NULL +65566 7.8 +65567 NULL +65568 NULL +65568 21.79 +65568 58.66 +65569 NULL +65570 NULL +65570 17.09 +65570 18.2 +65570 25.57 +65570 45.23 +65570 76.8 +65571 NULL +65571 26.64 +65571 40.68 +65571 82.5 +65572 NULL +65572 22.64 +65572 43.49 +65572 70.91 +65572 93.48 +65573 NULL +65573 53.56 +65573 96.32 +65574 NULL +65574 31.28 +65574 38.54 +65575 NULL +65575 17 +65575 32.85 +65575 83.4 +65576 NULL +65576 2.04 +65576 4.88 +65576 66.55 +65577 NULL +65578 NULL +65578 16.01 +65578 41.1 +65578 51.36 +65578 54.35 +65578 58.78 +65578 75.62 +65579 NULL +65579 21.36 +65579 33.37 +65579 73.48 +65579 91.42 +65580 NULL +65580 56.33 +65580 81.42 +65581 NULL +65581 29.74 +65581 45.48 +65581 56.59 +65581 60.88 +65581 88.09 +65582 NULL +65582 1.23 +65582 9.35 +65582 96.6 +65583 NULL +65583 28.07 +65583 50.57 +65583 57.67 +65584 NULL +65584 6.88 +65584 65.78 +65585 NULL +65585 31.23 +65585 37.34 +65585 39.32 +65585 50.38 +65586 NULL +65586 1.15 +65586 18.65 +65586 87.57 +65586 94.25 +65587 NULL +65587 5.83 +65587 11.86 +65587 53.84 +65587 94.47 +65588 NULL +65588 64.63 +65588 69.48 +65588 76.44 +65588 98.33 +65589 NULL +65589 49.49 +65589 72.3 +65589 74.83 +65589 94.73 +65590 NULL +65590 13.51 +65590 22.13 +65590 42.53 +65590 92.71 +65591 NULL +65591 9.85 +65591 11.43 +65591 60.78 +65592 NULL +65593 NULL +65593 35.15 +65594 NULL +65594 35.28 +65594 60.57 +65594 65.61 +65595 NULL +65595 8.76 +65595 67.56 +65595 72.7 +65595 89.6 +65595 90.24 +65596 NULL +65596 12.72 +65596 54.75 +65596 85.74 +65596 94.35 +65597 NULL +65597 37.41 +65597 69.05 +65598 NULL +65598 63.3 +65599 NULL +65599 0.56 +65599 4.93 +65599 41.61 +65599 76.29 +65600 NULL +65600 36.57 +65600 43.03 +65600 92.92 +65601 NULL +65601 26.54 +65601 37.93 +65602 NULL +65602 8.13 +65602 47.16 +65602 83.67 +65602 86.23 +65602 95.58 +65603 NULL +65603 41.44 +65603 45.63 +65603 69.26 +65603 80.24 +65604 NULL +65604 63.36 +65604 63.65 +65604 85.91 +65605 NULL +65606 NULL +65606 7.51 +65606 24.8 +65606 57.69 +65606 67.94 +65606 87.16 +65607 NULL +65607 9.67 +65607 36.58 +65607 71.75 +65607 75.86 +65607 91.52 +65608 NULL +65608 48.9 +65608 69.42 +65608 87.9 +65609 NULL +65610 NULL +65610 7.59 +65610 11.99 +65610 36.77 +65610 39.74 +65611 NULL +65611 21.21 +65611 25.92 +65611 64.89 +65612 NULL +65612 16.05 +65612 25.1 +65612 52.64 +65613 NULL +65614 NULL +65614 1.42 +65614 94.47 +65615 NULL +65615 10.79 +65615 39.4 +65615 99.88 +65616 NULL +65616 75.2 +65617 NULL +65617 18.51 +65617 47.45 +65617 64.9 +65618 NULL +65618 10.06 +65618 16.6 +65618 81.99 +65618 88.38 +65619 NULL +65619 27.32 +65619 32.64 +65619 34.72 +65619 36.48 +65619 36.59 +65620 NULL +65620 6.85 +65620 8.16 +65620 29.14 +65620 64.65 +65620 81.28 +65621 NULL +65621 95.14 +65622 NULL +65622 28.37 +65622 50.08 +65622 74.31 +65622 88.6 +65622 93.7 +65623 NULL +65623 30.83 +65623 31.22 +65623 39.74 +65623 48.51 +65623 95.58 +65623 97.2 +65624 NULL +65624 58.02 +65624 65.31 +65624 70.08 +65624 93.3 +65625 NULL +65625 20.61 +65625 42.86 +65625 55.06 +65626 NULL +65626 63.54 +65626 64.61 +65626 75.15 +65627 NULL +65627 19.65 +65627 61.89 +65627 93.29 +65628 NULL +65628 14.83 +65628 30.43 +65628 37.8 +65628 74.31 +65628 83.26 +65629 NULL +65629 19.33 +65629 58.81 +65629 72.9 +65630 NULL +65630 72.13 +65631 NULL +65632 NULL +65632 88.51 +65633 NULL +65633 59.56 +65633 72.54 +65633 81.02 +65634 NULL +65634 57.09 +65634 64.36 +65634 99.34 +65635 NULL +65635 64.99 +65635 82.29 +65636 NULL +65636 21.15 +65636 86.29 +65637 NULL +65637 16.89 +65637 26.78 +65637 29.34 +65637 35.51 +65637 44.32 +65637 48.88 +65637 93.41 +65638 NULL +65638 11.2 +65638 19.13 +65639 NULL +65640 NULL +65641 NULL +65641 26.02 +65641 84.27 +65641 91.46 +65642 NULL +65643 NULL +65643 22.05 +65643 50.79 +65643 52.56 +65643 61.29 +65643 71.29 +65643 80.96 +65643 92.24 +65643 93.11 +65644 NULL +65644 1.97 +65644 30.25 +65644 58.05 +65644 87.31 +65644 89.95 +65644 96.45 +65645 NULL +65645 3.95 +65645 63.22 +65646 NULL +65646 17.92 +65646 27.34 +65647 NULL +65647 58.03 +65648 NULL +65648 0.08 +65648 17.66 +65648 64.06 +65649 NULL +65649 8.69 +65649 43.92 +65649 91.03 +65650 NULL +65650 23.55 +65650 59.55 +65650 85.89 +65650 89.12 +65650 90.77 +65651 NULL +65651 24.25 +65651 58.25 +65651 74.13 +65651 84.42 +65652 NULL +65652 55.04 +65652 73.61 +65653 NULL +65653 3.81 +65653 52.23 +65653 85.09 +65654 NULL +65654 8.91 +65654 11.64 +65654 26.73 +65654 29.85 +65654 37.74 +65654 37.8 +65654 53.55 +65654 88.23 +65655 NULL +65655 77.41 +65656 NULL +65656 14 +65656 14.96 +65656 53.27 +65656 64.44 +65656 82.67 +65657 NULL +65657 11.93 +65657 26.4 +65657 64.39 +65657 65.01 +65658 NULL +65658 2.63 +65658 20.69 +65658 42.93 +65658 46.61 +65658 60.94 +65658 66.53 +65658 68.85 +65658 77.66 +65658 92.67 +65659 NULL +65659 8.95 +65659 46.57 +65659 53.8 +65659 94.3 +65659 94.69 +65659 95.71 +65659 99.87 +65660 NULL +65660 28.05 +65660 62.82 +65661 NULL +65661 5.24 +65661 8.06 +65661 26.8 +65661 68.98 +65662 NULL +65662 59.92 +65662 76.11 +65662 76.51 +65662 88.64 +65662 99.18 +65663 NULL +65663 5.42 +65663 78.56 +65663 94.16 +65664 NULL +65664 11.46 +65664 27.6 +65664 34.71 +65664 38.42 +65664 45.4 +65664 55.82 +65664 97.64 +65665 NULL +65666 NULL +65666 32.73 +65666 83.95 +65667 NULL +65667 13.96 +65667 63.9 +65667 97.87 +65668 NULL +65669 NULL +65669 1.76 +65669 16.95 +65669 38.6 +65669 54.25 +65669 93.79 +65670 NULL +65670 5.37 +65670 61.06 +65670 61.54 +65670 92.97 +65671 NULL +65671 8.65 +65671 52.05 +65672 NULL +65672 52.6 +65672 58.1 +65672 64.09 +65672 75.27 +65673 NULL +65673 0.9 +65673 33.27 +65673 43.81 +65673 87.78 +65674 NULL +65675 NULL +65675 4.19 +65675 24.19 +65675 35.33 +65675 35.78 +65675 79.9 +65675 83.09 +65675 87.36 +65676 NULL +65676 8.77 +65676 58.12 +65676 80.13 +65677 NULL +65677 5.06 +65677 25.37 +65677 44.47 +65677 48.79 +65677 87.67 +65678 NULL +65678 8.72 +65678 33.9 +65679 NULL +65679 64.15 +65680 NULL +65680 1.01 +65680 34.08 +65680 54.11 +65680 55.3 +65680 65.88 +65681 NULL +65681 35.45 +65681 41.57 +65681 61.3 +65681 71.17 +65681 75.85 +65682 NULL +65682 67.17 +65682 92.95 +65683 NULL +65683 17.62 +65683 99.56 +65684 NULL +65684 3.51 +65684 67.34 +65685 NULL +65685 38.71 +65685 43.48 +65685 63.27 +65685 87.84 +65685 90.69 +65686 NULL +65686 31.75 +65686 58.87 +65686 98.68 +65687 NULL +65687 3.37 +65687 21.79 +65687 48.73 +65688 NULL +65688 76.21 +65689 NULL +65689 9.12 +65690 NULL +65690 3.43 +65691 NULL +65691 5.01 +65691 6.93 +65691 28.47 +65691 56.02 +65691 58.01 +65691 69.8 +65691 76.98 +65692 NULL +65692 54.76 +65693 NULL +65693 8.38 +65693 32.33 +65693 45.69 +65693 69.32 +65693 71.72 +65693 84.88 +65694 NULL +65694 58.23 +65694 82.24 +65694 88.5 +65695 NULL +65695 57.33 +65695 59.96 +65695 77.09 +65696 NULL +65696 17.35 +65696 40.3 +65696 54.02 +65697 NULL +65697 3.18 +65697 50.01 +65697 67.9 +65697 86.79 +65697 90.16 +65698 NULL +65698 42.98 +65699 NULL +65699 13.29 +65699 38.71 +65699 68.94 +65699 84.79 +65699 88.09 +65700 NULL +65700 2.83 +65700 37.61 +65701 NULL +65701 1.81 +65701 6.35 +65702 NULL +65702 37.6 +65702 55.68 +65702 79.5 +65703 NULL +65703 37.18 +65703 40.81 +65703 90.89 +65704 NULL +65704 16.22 +65704 37.12 +65704 48.48 +65704 54.76 +65704 93.21 +65705 NULL +65705 20.57 +65705 25.89 +65705 65.13 +65706 NULL +65706 3.91 +65706 9.74 +65706 55.94 +65706 72.87 +65707 NULL +65707 76.2 +65708 NULL +65708 1.29 +65709 NULL +65709 5.64 +65709 49.79 +65710 NULL +65710 86.7 +65711 NULL +65711 8.66 +65711 50.26 +65711 71.89 +65711 78.69 +65711 96.1 +65712 NULL +65712 30.27 +65712 34.7 +65712 49.69 +65712 53.65 +65713 NULL +65713 10.94 +65713 39.47 +65713 72.37 +65713 90.91 +65714 NULL +65714 14.85 +65714 47.42 +65715 NULL +65715 39.62 +65715 54.79 +65715 81.28 +65715 89.4 +65716 NULL +65716 9 +65716 10.07 +65716 33.4 +65716 71.53 +65716 85.93 +65717 NULL +65717 1.23 +65717 5.81 +65717 57.61 +65717 80.05 +65718 NULL +65718 63.06 +65718 84.35 +65718 89.67 +65719 NULL +65719 51.13 +65719 66.85 +65719 82.1 +65720 NULL +65720 2.72 +65720 18.8 +65720 22.34 +65720 62.04 +65721 NULL +65721 23.78 +65721 39.19 +65721 55.75 +65721 72.82 +65721 95.12 +65721 95.38 +65722 NULL +65722 1.76 +65722 38.82 +65723 NULL +65723 39.9 +65724 NULL +65724 10.52 +65724 36.05 +65724 50.96 +65724 71.66 +65724 85.52 +65725 NULL +65726 NULL +65726 6 +65726 60.46 +65727 NULL +65727 19.81 +65727 49.19 +65727 87.37 +65727 88.11 +65728 NULL +65728 55.37 +65729 NULL +65730 NULL +65730 1.35 +65730 30.6 +65730 81.44 +65731 NULL +65731 24.48 +65731 61.52 +65731 97.18 +65732 NULL +65732 30.06 +65732 91.15 +65733 NULL +65733 11.44 +65733 20.72 +65733 88.46 +65733 93.45 +65733 99.8 +65734 NULL +65734 31.71 +65735 NULL +65735 12.67 +65735 61.16 +65736 NULL +65736 28.9 +65736 48.54 +65736 86.51 +65737 NULL +65737 3.98 +65737 20.85 +65737 29.92 +65737 80.97 +65738 NULL +65738 30.94 +65738 82.32 +65738 95.1 +65739 NULL +65739 74.77 +65739 92.4 +65740 NULL +65740 7.49 +65740 58.65 +65741 NULL +65742 NULL +65742 6.61 +65742 43.84 +65743 NULL +65743 26.6 +65743 52.65 +65743 62 +65744 NULL +65744 46.98 +65745 NULL +65745 25.19 +65745 66.36 +65745 80.12 +65746 NULL +65746 36.74 +65746 93.21 +65746 97.52 +65746 98.1 +65747 NULL +65747 11.16 +65747 15.07 +65747 21.8 +65747 39.77 +65747 52.77 +65747 71.87 +65748 NULL +65748 29.49 +65749 NULL +65749 15.14 +65749 45 +65749 65.49 +65749 73.24 +65750 NULL +65750 20.91 +65750 83.44 +65750 85.44 +65750 96.85 +65751 NULL +65751 2.96 +65751 9.02 +65751 30.68 +65751 47.81 +65751 78.75 +65752 NULL +65752 47.82 +65753 NULL +65753 86.97 +65754 NULL +65754 54.35 +65755 NULL +65755 11.23 +65755 22.44 +65755 64 +65755 67.54 +65755 76.75 +65755 81.44 +65755 90.08 +65755 96.8 +65756 NULL +65756 1.45 +65756 11.81 +65756 63.51 +65757 NULL +65757 1.86 +65757 9.24 +65757 34.84 +65757 90.09 +65758 NULL +65758 25.62 +65758 56.56 +65758 60.88 +65758 94.9 +65759 NULL +65759 10.63 +65759 14.1 +65759 47.54 +65759 92.81 +65760 NULL +65760 21.14 +65760 27.52 +65760 95.45 +65761 NULL +65762 NULL +65762 5.49 +65762 45.7 +65762 77.96 +65762 87.5 +65763 NULL +65763 0.72 +65763 43.8 +65763 86.43 +65763 87.99 +65764 NULL +65764 31.41 +65764 57.1 +65765 NULL +65765 88.52 +65765 88.56 +65766 NULL +65766 37.06 +65766 66.34 +65766 86.53 +65766 98.9 +65767 NULL +65767 90.88 +65767 95.57 +65768 NULL +65769 NULL +65769 11.45 +65769 38.98 +65769 58.05 +65769 70.52 +65769 91.49 +65770 NULL +65770 51.9 +65771 NULL +65771 6.15 +65771 7.5 +65772 NULL +65773 NULL +65773 3.81 +65773 18.2 +65773 30.49 +65773 47.09 +65773 53.09 +65773 63.26 +65773 76.46 +65774 NULL +65774 45.74 +65774 45.97 +65774 48.8 +65774 56.84 +65774 94.77 +65775 NULL +65775 7.88 +65775 66.56 +65775 66.68 +65775 98.43 +65776 NULL +65776 18.7 +65776 28.47 +65776 49.73 +65776 98.87 +65777 NULL +65777 54.39 +65777 73.79 +65777 82.62 +65778 NULL +65778 7.37 +65778 51.64 +65778 59.03 +65778 62.17 +65778 64.69 +65778 89.51 +65778 95.69 +65779 NULL +65779 11.87 +65779 28.2 +65779 39.48 +65779 45.61 +65779 64.41 +65779 65.24 +65780 NULL +65780 10.95 +65780 38.58 +65781 NULL +65781 70.59 +65781 95.52 +65782 NULL +65782 30.24 +65782 34.31 +65782 76.14 +65782 81.9 +65783 NULL +65783 46.34 +65783 51.08 +65783 52.43 +65783 62.58 +65783 77.4 +65784 NULL +65784 15.7 +65784 31.35 +65784 68.18 +65784 93.95 +65785 NULL +65785 29.61 +65786 NULL +65786 8.99 +65786 29.32 +65786 66.89 +65786 80.94 +65787 NULL +65787 18.78 +65787 31.19 +65787 64.88 +65788 NULL +65788 16.1 +65788 21.81 +65788 25.77 +65789 NULL +65789 20.44 +65789 43.53 +65789 52.49 +65789 83.18 +65789 92.74 +65789 96.9 +65790 NULL +65790 46.91 +65790 84.87 +65791 NULL +65791 4.24 diff --git a/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-3-34d9ee4120f21d0d0ae914fba0acc60c b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-3-34d9ee4120f21d0d0ae914fba0acc60c new file mode 100644 index 000000000000..569c1d4e5f7b --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-3-34d9ee4120f21d0d0ae914fba0acc60c @@ -0,0 +1,1049 @@ + 7 + 11 + 38 +alice allen 48 +alice allen 78 +alice allen 79 +alice brown 91 +alice carson 28 +alice davidson 88 +alice falkner 117 +alice garcia 106 +alice hernandez 37 +alice hernandez 85 +alice johnson 65 +alice king 109 +alice king 113 +alice king 118 +alice laertes 6 +alice laertes 65 +alice miller 97 +alice nixon 15 +alice nixon 31 +alice nixon 56 +alice ovid 15 +alice polk 90 +alice quirinius 27 +alice quirinius 89 +alice robinson 5 +alice robinson 68 +alice steinbeck 45 +alice steinbeck 50 +alice steinbeck 77 +alice underhill 34 +alice van buren 31 +alice xylophone 0 +alice xylophone 5 +alice xylophone 113 +alice zipper 19 +alice zipper 51 +alice zipper 74 +bob brown 35 +bob brown 61 +bob brown 71 +bob carson 40 +bob davidson 44 +bob davidson 103 +bob davidson 124 +bob ellison 7 +bob ellison 46 +bob ellison 85 +bob ellison 103 +bob falkner 2 +bob garcia 16 +bob garcia 42 +bob garcia 71 +bob garcia 77 +bob garcia 83 +bob hernandez 93 +bob ichabod 72 +bob king 9 +bob king 22 +bob king 81 +bob laertes -1 +bob laertes 105 +bob miller 31 +bob ovid 26 +bob ovid 27 +bob ovid 81 +bob ovid 86 +bob polk 55 +bob quirinius 26 +bob steinbeck 95 +bob van buren 88 +bob white 1 +bob white 16 +bob xylophone -2 +bob xylophone 49 +bob young -1 +bob zipper 36 +bob zipper 78 +bob zipper 92 +calvin allen 98 +calvin brown 81 +calvin brown 87 +calvin brown 121 +calvin carson 105 +calvin davidson 34 +calvin davidson 89 +calvin ellison 34 +calvin falkner -1 +calvin falkner 12 +calvin falkner 46 +calvin falkner 50 +calvin falkner 53 +calvin falkner 101 +calvin garcia 79 +calvin hernandez 22 +calvin johnson 34 +calvin laertes 37 +calvin laertes 100 +calvin nixon 50 +calvin nixon 71 +calvin nixon 72 +calvin ovid -1 +calvin ovid 50 +calvin ovid 65 +calvin ovid 71 +calvin polk 23 +calvin quirinius 5 +calvin quirinius 96 +calvin robinson 0 +calvin steinbeck 13 +calvin steinbeck 14 +calvin steinbeck 33 +calvin thompson 12 +calvin thompson 40 +calvin underhill 19 +calvin van buren 57 +calvin van buren 78 +calvin white 26 +calvin white 80 +calvin xylophone 17 +calvin xylophone 48 +calvin xylophone 78 +calvin young 8 +calvin young 99 +calvin zipper 31 +calvin zipper 46 +david allen 27 +david allen 80 +david brown 75 +david brown 117 +david davidson 11 +david davidson 38 +david davidson 54 +david davidson 74 +david ellison 50 +david ellison 54 +david ellison 120 +david hernandez 72 +david ichabod 6 +david ichabod 55 +david laertes 88 +david nixon 118 +david ovid 61 +david ovid 95 +david quirinius 43 +david quirinius 85 +david quirinius 121 +david robinson 47 +david robinson 59 +david thompson 89 +david underhill 69 +david underhill 87 +david underhill 98 +david van buren 1 +david van buren 38 +david white 93 +david xylophone 0 +david xylophone 22 +david xylophone 82 +david young 2 +david young 79 +ethan allen 24 +ethan brown 3 +ethan brown 29 +ethan brown 55 +ethan brown 64 +ethan brown 84 +ethan brown 108 +ethan carson 83 +ethan ellison 28 +ethan ellison 65 +ethan falkner 45 +ethan falkner 49 +ethan garcia 16 +ethan hernandez 71 +ethan johnson 108 +ethan king 44 +ethan laertes -1 +ethan laertes 27 +ethan laertes 46 +ethan laertes 68 +ethan laertes 81 +ethan laertes 103 +ethan laertes 114 +ethan miller 40 +ethan nixon 18 +ethan ovid 66 +ethan polk 46 +ethan polk 71 +ethan polk 114 +ethan polk 117 +ethan quirinius 16 +ethan quirinius 41 +ethan quirinius 85 +ethan robinson 32 +ethan robinson 34 +ethan underhill 89 +ethan van buren 43 +ethan white 38 +ethan white 51 +ethan xylophone 107 +ethan zipper 37 +ethan zipper 89 +fred davidson 18 +fred davidson 18 +fred davidson 77 +fred ellison -3 +fred ellison 44 +fred ellison 116 +fred falkner 18 +fred falkner 33 +fred falkner 100 +fred hernandez 36 +fred ichabod 1 +fred ichabod 77 +fred johnson 62 +fred king 33 +fred king 92 +fred laertes 17 +fred miller NULL +fred nixon 33 +fred nixon 36 +fred nixon 102 +fred nixon 111 +fred polk -2 +fred polk 39 +fred polk 60 +fred polk 85 +fred quirinius 25 +fred quirinius 124 +fred robinson 89 +fred steinbeck 79 +fred steinbeck 118 +fred steinbeck 119 +fred underhill 122 +fred van buren 4 +fred van buren 24 +fred van buren 63 +fred van buren 106 +fred white 97 +fred young 33 +fred young 103 +fred zipper 66 +gabriella allen 114 +gabriella allen 119 +gabriella brown 25 +gabriella brown 92 +gabriella carson 112 +gabriella davidson 45 +gabriella ellison 21 +gabriella ellison 101 +gabriella falkner 14 +gabriella falkner 66 +gabriella falkner 77 +gabriella garcia 110 +gabriella hernandez 20 +gabriella hernandez 36 +gabriella ichabod 17 +gabriella ichabod 66 +gabriella ichabod 71 +gabriella ichabod 90 +gabriella ichabod 91 +gabriella king 100 +gabriella king 115 +gabriella laertes 50 +gabriella miller 35 +gabriella ovid 38 +gabriella ovid 89 +gabriella polk 42 +gabriella polk 105 +gabriella steinbeck 18 +gabriella steinbeck 115 +gabriella thompson 45 +gabriella thompson 70 +gabriella thompson 88 +gabriella van buren 5 +gabriella van buren 117 +gabriella white 37 +gabriella young 48 +gabriella young 107 +gabriella zipper 57 +gabriella zipper 71 +holly allen 63 +holly brown 50 +holly brown 117 +holly falkner NULL +holly hernandez 31 +holly hernandez 43 +holly hernandez 48 +holly hernandez 100 +holly ichabod 28 +holly ichabod 53 +holly ichabod 83 +holly johnson 60 +holly johnson 112 +holly johnson 121 +holly king 90 +holly king 114 +holly laertes 13 +holly miller 28 +holly nixon -2 +holly nixon 120 +holly polk 54 +holly polk 124 +holly robinson 1 +holly thompson 28 +holly thompson 33 +holly thompson 83 +holly underhill 5 +holly underhill 31 +holly underhill 101 +holly underhill 113 +holly van buren 19 +holly white 18 +holly white 90 +holly xylophone 18 +holly young 32 +holly young 72 +holly zipper 78 +holly zipper 88 +irene allen 56 +irene brown 78 +irene brown 93 +irene brown 108 +irene carson 0 +irene ellison -3 +irene ellison 40 +irene falkner 13 +irene falkner 104 +irene garcia 34 +irene garcia 62 +irene garcia 124 +irene ichabod 83 +irene ichabod 112 +irene johnson 88 +irene laertes 9 +irene laertes 54 +irene laertes 60 +irene miller 108 +irene nixon -1 +irene nixon 12 +irene nixon 101 +irene ovid 26 +irene ovid 32 +irene ovid 53 +irene polk 18 +irene polk 92 +irene polk 99 +irene polk 109 +irene polk 116 +irene quirinius 7 +irene quirinius 76 +irene quirinius 97 +irene robinson 51 +irene steinbeck 46 +irene thompson 10 +irene underhill 27 +irene underhill 63 +irene van buren 17 +irene van buren 104 +irene xylophone 18 +jessica brown 117 +jessica carson 3 +jessica carson 13 +jessica carson 88 +jessica davidson 11 +jessica davidson 28 +jessica davidson 89 +jessica davidson 124 +jessica ellison 38 +jessica ellison 50 +jessica falkner 71 +jessica garcia 25 +jessica garcia 43 +jessica ichabod 104 +jessica johnson 31 +jessica johnson 69 +jessica miller 74 +jessica nixon 22 +jessica nixon 120 +jessica ovid 47 +jessica ovid 73 +jessica polk 118 +jessica quirinius 0 +jessica quirinius 87 +jessica quirinius 105 +jessica quirinius 114 +jessica robinson 15 +jessica thompson 1 +jessica thompson 77 +jessica underhill 32 +jessica underhill 46 +jessica underhill 83 +jessica van buren 54 +jessica white 5 +jessica white 30 +jessica white 45 +jessica white 65 +jessica white 98 +jessica xylophone 67 +jessica young 61 +jessica young 123 +jessica zipper 27 +jessica zipper 33 +jessica zipper 54 +katie allen 114 +katie brown 39 +katie davidson 35 +katie ellison 5 +katie ellison 58 +katie falkner 15 +katie garcia 49 +katie garcia 65 +katie hernandez 83 +katie ichabod 9 +katie ichabod 75 +katie ichabod 104 +katie king 44 +katie king 59 +katie king 93 +katie miller 23 +katie miller 117 +katie nixon 43 +katie ovid 81 +katie polk 17 +katie polk 85 +katie robinson 92 +katie van buren 25 +katie van buren 88 +katie white 34 +katie white 86 +katie xylophone 84 +katie young 2 +katie young 24 +katie young 70 +katie zipper 25 +katie zipper 87 +luke allen 7 +luke allen 44 +luke allen 62 +luke allen 100 +luke allen 114 +luke brown 112 +luke davidson 51 +luke davidson 84 +luke ellison 35 +luke ellison 40 +luke ellison 86 +luke falkner 59 +luke falkner 97 +luke garcia 51 +luke garcia 100 +luke ichabod 42 +luke ichabod 123 +luke johnson 9 +luke johnson 17 +luke johnson 53 +luke laertes 66 +luke laertes 73 +luke laertes 76 +luke laertes 101 +luke laertes 118 +luke miller 93 +luke ovid 43 +luke ovid 70 +luke polk 53 +luke polk 88 +luke quirinius 82 +luke robinson 0 +luke robinson 114 +luke thompson 51 +luke underhill 2 +luke underhill 109 +luke underhill 119 +luke van buren 43 +luke white 110 +luke xylophone 15 +luke zipper 10 +mike allen 0 +mike brown 88 +mike carson 12 +mike carson 17 +mike carson 122 +mike davidson 9 +mike davidson 110 +mike ellison 5 +mike ellison 50 +mike ellison 70 +mike ellison 94 +mike ellison 95 +mike falkner 61 +mike garcia 2 +mike garcia 68 +mike garcia 110 +mike hernandez 91 +mike hernandez 106 +mike ichabod 18 +mike king 4 +mike king 58 +mike king 83 +mike king 96 +mike king 103 +mike king 118 +mike miller 51 +mike nixon 97 +mike nixon 106 +mike polk 6 +mike polk 65 +mike polk 119 +mike quirinius 22 +mike steinbeck 75 +mike steinbeck 85 +mike steinbeck 101 +mike steinbeck 116 +mike van buren 16 +mike van buren 111 +mike white -1 +mike white 22 +mike white 45 +mike white 61 +mike young 37 +mike young 53 +mike young 72 +mike zipper 27 +mike zipper 76 +mike zipper 106 +nick allen 8 +nick allen 57 +nick brown 114 +nick davidson 84 +nick ellison 10 +nick ellison 107 +nick falkner 83 +nick falkner 86 +nick garcia 53 +nick garcia 69 +nick garcia 108 +nick ichabod 59 +nick ichabod 71 +nick ichabod 84 +nick johnson 47 +nick johnson 88 +nick laertes 17 +nick miller 101 +nick nixon 43 +nick ovid 42 +nick polk 1 +nick quirinius 22 +nick quirinius 36 +nick robinson 48 +nick robinson 54 +nick steinbeck 33 +nick thompson 73 +nick underhill 122 +nick van buren 53 +nick xylophone 80 +nick young 6 +nick young 60 +nick zipper 3 +nick zipper 21 +oscar allen 58 +oscar brown 80 +oscar carson 10 +oscar carson 27 +oscar carson 36 +oscar carson 72 +oscar carson 88 +oscar davidson 14 +oscar ellison 50 +oscar ellison 74 +oscar falkner 96 +oscar garcia 44 +oscar hernandez 1 +oscar hernandez 93 +oscar ichabod 20 +oscar ichabod 28 +oscar ichabod 69 +oscar ichabod 120 +oscar johnson 44 +oscar johnson 53 +oscar king 67 +oscar king 71 +oscar king 81 +oscar laertes 4 +oscar laertes 28 +oscar laertes 53 +oscar laertes 63 +oscar nixon 58 +oscar ovid 3 +oscar ovid 27 +oscar ovid 47 +oscar polk 8 +oscar polk 112 +oscar quirinius NULL +oscar quirinius 0 +oscar quirinius 17 +oscar quirinius 114 +oscar robinson 16 +oscar robinson 42 +oscar robinson 59 +oscar robinson 93 +oscar steinbeck 51 +oscar thompson 44 +oscar thompson 44 +oscar thompson 60 +oscar thompson 66 +oscar underhill 86 +oscar van buren 40 +oscar van buren 51 +oscar van buren 114 +oscar white 2 +oscar white 20 +oscar white 49 +oscar white 58 +oscar xylophone 18 +oscar xylophone 73 +oscar xylophone 74 +oscar zipper 0 +oscar zipper 23 +oscar zipper 95 +priscilla brown 51 +priscilla brown 75 +priscilla brown 97 +priscilla carson 16 +priscilla carson 52 +priscilla carson 124 +priscilla ichabod 117 +priscilla ichabod 122 +priscilla johnson 5 +priscilla johnson 17 +priscilla johnson 62 +priscilla johnson 77 +priscilla johnson 117 +priscilla king 43 +priscilla nixon 61 +priscilla nixon 66 +priscilla ovid 46 +priscilla ovid 118 +priscilla polk 45 +priscilla quirinius 83 +priscilla thompson 82 +priscilla underhill 117 +priscilla underhill 122 +priscilla van buren 0 +priscilla van buren 22 +priscilla van buren 102 +priscilla white 88 +priscilla xylophone 8 +priscilla xylophone 90 +priscilla xylophone 109 +priscilla young 17 +priscilla young 113 +priscilla zipper 27 +priscilla zipper 35 +quinn allen 27 +quinn allen 114 +quinn brown 70 +quinn brown 88 +quinn brown 117 +quinn davidson 93 +quinn davidson 93 +quinn davidson 109 +quinn davidson 121 +quinn ellison 83 +quinn ellison 116 +quinn garcia 78 +quinn garcia 104 +quinn garcia 110 +quinn garcia 120 +quinn ichabod 60 +quinn king 14 +quinn king 46 +quinn laertes -2 +quinn laertes 65 +quinn laertes 95 +quinn nixon 11 +quinn ovid 123 +quinn quirinius 94 +quinn robinson 60 +quinn steinbeck 82 +quinn steinbeck 122 +quinn thompson 41 +quinn thompson 60 +quinn underhill 19 +quinn underhill 28 +quinn underhill 34 +quinn van buren 18 +quinn young 15 +quinn zipper 44 +quinn zipper 103 +rachel allen 76 +rachel allen 122 +rachel brown 23 +rachel brown 56 +rachel brown 71 +rachel brown 101 +rachel brown 108 +rachel carson 27 +rachel carson 74 +rachel davidson 84 +rachel ellison 51 +rachel falkner -2 +rachel falkner 43 +rachel falkner 72 +rachel falkner 104 +rachel johnson 32 +rachel king 84 +rachel king 95 +rachel laertes 37 +rachel laertes 106 +rachel ovid 5 +rachel ovid 31 +rachel polk 79 +rachel quirinius 108 +rachel robinson 24 +rachel robinson 41 +rachel robinson 91 +rachel thompson -3 +rachel thompson -2 +rachel thompson 74 +rachel underhill 11 +rachel white 108 +rachel white 119 +rachel young 77 +rachel zipper 16 +rachel zipper 116 +sarah carson 41 +sarah carson 58 +sarah carson 119 +sarah ellison 14 +sarah falkner 112 +sarah falkner 123 +sarah garcia 72 +sarah garcia 91 +sarah garcia 98 +sarah ichabod 38 +sarah ichabod 80 +sarah johnson 5 +sarah johnson 51 +sarah johnson 69 +sarah johnson 116 +sarah king 13 +sarah king 120 +sarah miller 31 +sarah ovid 122 +sarah robinson 26 +sarah robinson 35 +sarah steinbeck 30 +sarah white 11 +sarah white 32 +sarah xylophone 28 +sarah young 120 +sarah zipper 107 +tom brown 27 +tom brown 89 +tom carson 11 +tom carson 70 +tom carson 123 +tom davidson 72 +tom ellison 28 +tom ellison 118 +tom ellison 120 +tom falkner 11 +tom falkner 35 +tom hernandez -3 +tom hernandez 118 +tom ichabod 19 +tom johnson 42 +tom johnson 82 +tom king 59 +tom laertes 33 +tom laertes 54 +tom miller 9 +tom miller 48 +tom miller 94 +tom nixon 45 +tom ovid 68 +tom polk 70 +tom polk 107 +tom quirinius 10 +tom quirinius 38 +tom robinson 52 +tom robinson 104 +tom robinson 109 +tom robinson 115 +tom steinbeck 113 +tom van buren 5 +tom van buren 48 +tom van buren 63 +tom white 81 +tom young 13 +tom young 99 +tom zipper 31 +ulysses brown 46 +ulysses carson 5 +ulysses carson 26 +ulysses carson 55 +ulysses carson 109 +ulysses davidson 18 +ulysses ellison 61 +ulysses garcia 12 +ulysses hernandez 9 +ulysses hernandez 22 +ulysses hernandez 53 +ulysses ichabod 32 +ulysses ichabod 99 +ulysses johnson 41 +ulysses king 2 +ulysses laertes 40 +ulysses laertes 51 +ulysses laertes 95 +ulysses miller 23 +ulysses miller 85 +ulysses nixon 92 +ulysses ovid 31 +ulysses polk 28 +ulysses polk 74 +ulysses polk 86 +ulysses polk 89 +ulysses quirinius 7 +ulysses robinson 79 +ulysses steinbeck 6 +ulysses steinbeck 45 +ulysses thompson 24 +ulysses underhill 6 +ulysses underhill 27 +ulysses underhill 42 +ulysses underhill 51 +ulysses underhill 93 +ulysses underhill 98 +ulysses underhill 111 +ulysses van buren 58 +ulysses white 67 +ulysses white 109 +ulysses xylophone 47 +ulysses xylophone 105 +ulysses xylophone 123 +ulysses young 61 +ulysses young 86 +ulysses young 89 +victor allen 2 +victor allen 17 +victor brown 0 +victor brown 23 +victor brown 60 +victor brown 64 +victor davidson 42 +victor davidson 89 +victor davidson 123 +victor ellison 35 +victor ellison 84 +victor hernandez 1 +victor hernandez 17 +victor hernandez 91 +victor hernandez 94 +victor hernandez 116 +victor johnson 34 +victor johnson 53 +victor johnson 57 +victor king 59 +victor king 112 +victor laertes 18 +victor laertes 118 +victor miller 79 +victor nixon 50 +victor nixon 104 +victor ovid 120 +victor polk 106 +victor quirinius 77 +victor quirinius 85 +victor robinson 29 +victor robinson 105 +victor steinbeck 20 +victor steinbeck 92 +victor steinbeck 100 +victor thompson 124 +victor van buren 41 +victor van buren 71 +victor white 15 +victor white 49 +victor xylophone -3 +victor xylophone 41 +victor xylophone 43 +victor xylophone 54 +victor xylophone 91 +victor young 24 +victor zipper 3 +wendy allen 25 +wendy allen 38 +wendy allen 95 +wendy brown 92 +wendy brown 119 +wendy ellison 53 +wendy ellison 103 +wendy falkner 23 +wendy falkner 28 +wendy falkner 58 +wendy garcia 3 +wendy garcia 48 +wendy garcia 60 +wendy garcia 99 +wendy hernandez 53 +wendy ichabod 87 +wendy king -2 +wendy king 45 +wendy king 124 +wendy laertes 26 +wendy laertes 51 +wendy laertes 72 +wendy miller 51 +wendy miller 105 +wendy nixon 5 +wendy nixon 25 +wendy ovid 17 +wendy ovid 85 +wendy polk 44 +wendy polk 99 +wendy quirinius 77 +wendy quirinius 88 +wendy robinson -3 +wendy robinson 71 +wendy robinson 97 +wendy steinbeck 37 +wendy thompson 28 +wendy thompson 31 +wendy underhill 58 +wendy underhill 82 +wendy underhill 120 +wendy van buren 27 +wendy van buren 82 +wendy white 63 +wendy xylophone 53 +wendy xylophone 119 +wendy young 66 +wendy young 112 +xavier allen 18 +xavier allen 41 +xavier allen 106 +xavier brown 10 +xavier brown 63 +xavier brown 108 +xavier carson 20 +xavier carson 57 +xavier davidson 21 +xavier davidson 24 +xavier davidson 106 +xavier ellison 0 +xavier ellison 53 +xavier garcia 42 +xavier hernandez 9 +xavier hernandez 80 +xavier hernandez 114 +xavier ichabod 20 +xavier ichabod 58 +xavier johnson 44 +xavier johnson 85 +xavier king 26 +xavier king 107 +xavier laertes 60 +xavier ovid 3 +xavier polk 29 +xavier polk 83 +xavier polk 91 +xavier polk 122 +xavier quirinius 27 +xavier quirinius 35 +xavier quirinius 39 +xavier quirinius 111 +xavier thompson 2 +xavier underhill 102 +xavier white 8 +xavier white 56 +xavier xylophone 24 +xavier zipper 48 +yuri allen 31 +yuri allen 121 +yuri brown 101 +yuri brown 106 +yuri carson 1 +yuri carson 36 +yuri ellison -1 +yuri ellison 43 +yuri falkner 31 +yuri falkner 96 +yuri garcia 49 +yuri hernandez 92 +yuri johnson 1 +yuri johnson 2 +yuri johnson 111 +yuri king 44 +yuri laertes 84 +yuri laertes 115 +yuri nixon 5 +yuri nixon 111 +yuri polk 13 +yuri polk 49 +yuri polk 115 +yuri quirinius 24 +yuri quirinius 28 +yuri quirinius 90 +yuri steinbeck 8 +yuri steinbeck 65 +yuri thompson 42 +yuri underhill 10 +yuri underhill 66 +yuri white 73 +yuri xylophone 63 +zach allen 35 +zach brown 7 +zach brown 15 +zach brown 37 +zach brown 61 +zach brown 94 +zach carson 114 +zach ellison 16 +zach falkner 70 +zach falkner 115 +zach garcia -2 +zach garcia 59 +zach garcia 68 +zach garcia 97 +zach ichabod 14 +zach ichabod 73 +zach king 66 +zach king 70 +zach king 81 +zach miller 4 +zach miller 9 +zach miller 73 +zach ovid 61 +zach ovid 68 +zach ovid 77 +zach ovid 114 +zach quirinius 79 +zach robinson 69 +zach steinbeck 6 +zach steinbeck 122 +zach thompson 75 +zach thompson 95 +zach underhill 123 +zach white 58 +zach xylophone 19 +zach xylophone 85 +zach young 11 +zach zipper 68 +zach zipper 100 +zach zipper 101 diff --git a/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-4-dfd39236756a3951bc1ec354799d69e4 b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-4-dfd39236756a3951bc1ec354799d69e4 new file mode 100644 index 000000000000..86ca4e49d21b --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-4-dfd39236756a3951bc1ec354799d69e4 @@ -0,0 +1,1049 @@ + + + +alice allen +alice allen +alice allen +alice brown +alice carson +alice davidson +alice falkner +alice garcia +alice hernandez +alice hernandez +alice johnson +alice king +alice king +alice king +alice laertes +alice laertes +alice miller +alice nixon +alice nixon +alice nixon +alice ovid +alice polk +alice quirinius +alice quirinius +alice robinson +alice robinson +alice steinbeck +alice steinbeck +alice steinbeck +alice underhill +alice van buren +alice xylophone +alice xylophone +alice xylophone +alice zipper +alice zipper +alice zipper +bob brown +bob brown +bob brown +bob carson +bob davidson +bob davidson +bob davidson +bob ellison +bob ellison +bob ellison +bob ellison +bob falkner +bob garcia +bob garcia +bob garcia +bob garcia +bob garcia +bob hernandez +bob ichabod +bob king +bob king +bob king +bob laertes +bob laertes +bob miller +bob ovid +bob ovid +bob ovid +bob ovid +bob polk +bob quirinius +bob steinbeck +bob van buren +bob white +bob white +bob xylophone +bob xylophone +bob young +bob zipper +bob zipper +bob zipper +calvin allen +calvin brown +calvin brown +calvin brown +calvin carson +calvin davidson +calvin davidson +calvin ellison +calvin falkner +calvin falkner +calvin falkner +calvin falkner +calvin falkner +calvin falkner +calvin garcia +calvin hernandez +calvin johnson +calvin laertes +calvin laertes +calvin nixon +calvin nixon +calvin nixon +calvin ovid +calvin ovid +calvin ovid +calvin ovid +calvin polk +calvin quirinius +calvin quirinius +calvin robinson +calvin steinbeck +calvin steinbeck +calvin steinbeck +calvin thompson +calvin thompson +calvin underhill +calvin van buren +calvin van buren +calvin white +calvin white +calvin xylophone +calvin xylophone +calvin xylophone +calvin young +calvin young +calvin zipper +calvin zipper +david allen +david allen +david brown +david brown +david davidson +david davidson +david davidson +david davidson +david ellison +david ellison +david ellison +david hernandez +david ichabod +david ichabod +david laertes +david nixon +david ovid +david ovid +david quirinius +david quirinius +david quirinius +david robinson +david robinson +david thompson +david underhill +david underhill +david underhill +david van buren +david van buren +david white +david xylophone +david xylophone +david xylophone +david young +david young +ethan allen +ethan brown +ethan brown +ethan brown +ethan brown +ethan brown +ethan brown +ethan carson +ethan ellison +ethan ellison +ethan falkner +ethan falkner +ethan garcia +ethan hernandez +ethan johnson +ethan king +ethan laertes +ethan laertes +ethan laertes +ethan laertes +ethan laertes +ethan laertes +ethan laertes +ethan miller +ethan nixon +ethan ovid +ethan polk +ethan polk +ethan polk +ethan polk +ethan quirinius +ethan quirinius +ethan quirinius +ethan robinson +ethan robinson +ethan underhill +ethan van buren +ethan white +ethan white +ethan xylophone +ethan zipper +ethan zipper +fred davidson +fred davidson +fred davidson +fred ellison +fred ellison +fred ellison +fred falkner +fred falkner +fred falkner +fred hernandez +fred ichabod +fred ichabod +fred johnson +fred king +fred king +fred laertes +fred miller +fred nixon +fred nixon +fred nixon +fred nixon +fred polk +fred polk +fred polk +fred polk +fred quirinius +fred quirinius +fred robinson +fred steinbeck +fred steinbeck +fred steinbeck +fred underhill +fred van buren +fred van buren +fred van buren +fred van buren +fred white +fred young +fred young +fred zipper +gabriella allen +gabriella allen +gabriella brown +gabriella brown +gabriella carson +gabriella davidson +gabriella ellison +gabriella ellison +gabriella falkner +gabriella falkner +gabriella falkner +gabriella garcia +gabriella hernandez +gabriella hernandez +gabriella ichabod +gabriella ichabod +gabriella ichabod +gabriella ichabod +gabriella ichabod +gabriella king +gabriella king +gabriella laertes +gabriella miller +gabriella ovid +gabriella ovid +gabriella polk +gabriella polk +gabriella steinbeck +gabriella steinbeck +gabriella thompson +gabriella thompson +gabriella thompson +gabriella van buren +gabriella van buren +gabriella white +gabriella young +gabriella young +gabriella zipper +gabriella zipper +holly allen +holly brown +holly brown +holly falkner +holly hernandez +holly hernandez +holly hernandez +holly hernandez +holly ichabod +holly ichabod +holly ichabod +holly johnson +holly johnson +holly johnson +holly king +holly king +holly laertes +holly miller +holly nixon +holly nixon +holly polk +holly polk +holly robinson +holly thompson +holly thompson +holly thompson +holly underhill +holly underhill +holly underhill +holly underhill +holly van buren +holly white +holly white +holly xylophone +holly young +holly young +holly zipper +holly zipper +irene allen +irene brown +irene brown +irene brown +irene carson +irene ellison +irene ellison +irene falkner +irene falkner +irene garcia +irene garcia +irene garcia +irene ichabod +irene ichabod +irene johnson +irene laertes +irene laertes +irene laertes +irene miller +irene nixon +irene nixon +irene nixon +irene ovid +irene ovid +irene ovid +irene polk +irene polk +irene polk +irene polk +irene polk +irene quirinius +irene quirinius +irene quirinius +irene robinson +irene steinbeck +irene thompson +irene underhill +irene underhill +irene van buren +irene van buren +irene xylophone +jessica brown +jessica carson +jessica carson +jessica carson +jessica davidson +jessica davidson +jessica davidson +jessica davidson +jessica ellison +jessica ellison +jessica falkner +jessica garcia +jessica garcia +jessica ichabod +jessica johnson +jessica johnson +jessica miller +jessica nixon +jessica nixon +jessica ovid +jessica ovid +jessica polk +jessica quirinius +jessica quirinius +jessica quirinius +jessica quirinius +jessica robinson +jessica thompson +jessica thompson +jessica underhill +jessica underhill +jessica underhill +jessica van buren +jessica white +jessica white +jessica white +jessica white +jessica white +jessica xylophone +jessica young +jessica young +jessica zipper +jessica zipper +jessica zipper +katie allen +katie brown +katie davidson +katie ellison +katie ellison +katie falkner +katie garcia +katie garcia +katie hernandez +katie ichabod +katie ichabod +katie ichabod +katie king +katie king +katie king +katie miller +katie miller +katie nixon +katie ovid +katie polk +katie polk +katie robinson +katie van buren +katie van buren +katie white +katie white +katie xylophone +katie young +katie young +katie young +katie zipper +katie zipper +luke allen +luke allen +luke allen +luke allen +luke allen +luke brown +luke davidson +luke davidson +luke ellison +luke ellison +luke ellison +luke falkner +luke falkner +luke garcia +luke garcia +luke ichabod +luke ichabod +luke johnson +luke johnson +luke johnson +luke laertes +luke laertes +luke laertes +luke laertes +luke laertes +luke miller +luke ovid +luke ovid +luke polk +luke polk +luke quirinius +luke robinson +luke robinson +luke thompson +luke underhill +luke underhill +luke underhill +luke van buren +luke white +luke xylophone +luke zipper +mike allen +mike brown +mike carson +mike carson +mike carson +mike davidson +mike davidson +mike ellison +mike ellison +mike ellison +mike ellison +mike ellison +mike falkner +mike garcia +mike garcia +mike garcia +mike hernandez +mike hernandez +mike ichabod +mike king +mike king +mike king +mike king +mike king +mike king +mike miller +mike nixon +mike nixon +mike polk +mike polk +mike polk +mike quirinius +mike steinbeck +mike steinbeck +mike steinbeck +mike steinbeck +mike van buren +mike van buren +mike white +mike white +mike white +mike white +mike young +mike young +mike young +mike zipper +mike zipper +mike zipper +nick allen +nick allen +nick brown +nick davidson +nick ellison +nick ellison +nick falkner +nick falkner +nick garcia +nick garcia +nick garcia +nick ichabod +nick ichabod +nick ichabod +nick johnson +nick johnson +nick laertes +nick miller +nick nixon +nick ovid +nick polk +nick quirinius +nick quirinius +nick robinson +nick robinson +nick steinbeck +nick thompson +nick underhill +nick van buren +nick xylophone +nick young +nick young +nick zipper +nick zipper +oscar allen +oscar brown +oscar carson +oscar carson +oscar carson +oscar carson +oscar carson +oscar davidson +oscar ellison +oscar ellison +oscar falkner +oscar garcia +oscar hernandez +oscar hernandez +oscar ichabod +oscar ichabod +oscar ichabod +oscar ichabod +oscar johnson +oscar johnson +oscar king +oscar king +oscar king +oscar laertes +oscar laertes +oscar laertes +oscar laertes +oscar nixon +oscar ovid +oscar ovid +oscar ovid +oscar polk +oscar polk +oscar quirinius +oscar quirinius +oscar quirinius +oscar quirinius +oscar robinson +oscar robinson +oscar robinson +oscar robinson +oscar steinbeck +oscar thompson +oscar thompson +oscar thompson +oscar thompson +oscar underhill +oscar van buren +oscar van buren +oscar van buren +oscar white +oscar white +oscar white +oscar white +oscar xylophone +oscar xylophone +oscar xylophone +oscar zipper +oscar zipper +oscar zipper +priscilla brown +priscilla brown +priscilla brown +priscilla carson +priscilla carson +priscilla carson +priscilla ichabod +priscilla ichabod +priscilla johnson +priscilla johnson +priscilla johnson +priscilla johnson +priscilla johnson +priscilla king +priscilla nixon +priscilla nixon +priscilla ovid +priscilla ovid +priscilla polk +priscilla quirinius +priscilla thompson +priscilla underhill +priscilla underhill +priscilla van buren +priscilla van buren +priscilla van buren +priscilla white +priscilla xylophone +priscilla xylophone +priscilla xylophone +priscilla young +priscilla young +priscilla zipper +priscilla zipper +quinn allen +quinn allen +quinn brown +quinn brown +quinn brown +quinn davidson +quinn davidson +quinn davidson +quinn davidson +quinn ellison +quinn ellison +quinn garcia +quinn garcia +quinn garcia +quinn garcia +quinn ichabod +quinn king +quinn king +quinn laertes +quinn laertes +quinn laertes +quinn nixon +quinn ovid +quinn quirinius +quinn robinson +quinn steinbeck +quinn steinbeck +quinn thompson +quinn thompson +quinn underhill +quinn underhill +quinn underhill +quinn van buren +quinn young +quinn zipper +quinn zipper +rachel allen +rachel allen +rachel brown +rachel brown +rachel brown +rachel brown +rachel brown +rachel carson +rachel carson +rachel davidson +rachel ellison +rachel falkner +rachel falkner +rachel falkner +rachel falkner +rachel johnson +rachel king +rachel king +rachel laertes +rachel laertes +rachel ovid +rachel ovid +rachel polk +rachel quirinius +rachel robinson +rachel robinson +rachel robinson +rachel thompson +rachel thompson +rachel thompson +rachel underhill +rachel white +rachel white +rachel young +rachel zipper +rachel zipper +sarah carson +sarah carson +sarah carson +sarah ellison +sarah falkner +sarah falkner +sarah garcia +sarah garcia +sarah garcia +sarah ichabod +sarah ichabod +sarah johnson +sarah johnson +sarah johnson +sarah johnson +sarah king +sarah king +sarah miller +sarah ovid +sarah robinson +sarah robinson +sarah steinbeck +sarah white +sarah white +sarah xylophone +sarah young +sarah zipper +tom brown +tom brown +tom carson +tom carson +tom carson +tom davidson +tom ellison +tom ellison +tom ellison +tom falkner +tom falkner +tom hernandez +tom hernandez +tom ichabod +tom johnson +tom johnson +tom king +tom laertes +tom laertes +tom miller +tom miller +tom miller +tom nixon +tom ovid +tom polk +tom polk +tom quirinius +tom quirinius +tom robinson +tom robinson +tom robinson +tom robinson +tom steinbeck +tom van buren +tom van buren +tom van buren +tom white +tom young +tom young +tom zipper +ulysses brown +ulysses carson +ulysses carson +ulysses carson +ulysses carson +ulysses davidson +ulysses ellison +ulysses garcia +ulysses hernandez +ulysses hernandez +ulysses hernandez +ulysses ichabod +ulysses ichabod +ulysses johnson +ulysses king +ulysses laertes +ulysses laertes +ulysses laertes +ulysses miller +ulysses miller +ulysses nixon +ulysses ovid +ulysses polk +ulysses polk +ulysses polk +ulysses polk +ulysses quirinius +ulysses robinson +ulysses steinbeck +ulysses steinbeck +ulysses thompson +ulysses underhill +ulysses underhill +ulysses underhill +ulysses underhill +ulysses underhill +ulysses underhill +ulysses underhill +ulysses van buren +ulysses white +ulysses white +ulysses xylophone +ulysses xylophone +ulysses xylophone +ulysses young +ulysses young +ulysses young +victor allen +victor allen +victor brown +victor brown +victor brown +victor brown +victor davidson +victor davidson +victor davidson +victor ellison +victor ellison +victor hernandez +victor hernandez +victor hernandez +victor hernandez +victor hernandez +victor johnson +victor johnson +victor johnson +victor king +victor king +victor laertes +victor laertes +victor miller +victor nixon +victor nixon +victor ovid +victor polk +victor quirinius +victor quirinius +victor robinson +victor robinson +victor steinbeck +victor steinbeck +victor steinbeck +victor thompson +victor van buren +victor van buren +victor white +victor white +victor xylophone +victor xylophone +victor xylophone +victor xylophone +victor xylophone +victor young +victor zipper +wendy allen +wendy allen +wendy allen +wendy brown +wendy brown +wendy ellison +wendy ellison +wendy falkner +wendy falkner +wendy falkner +wendy garcia +wendy garcia +wendy garcia +wendy garcia +wendy hernandez +wendy ichabod +wendy king +wendy king +wendy king +wendy laertes +wendy laertes +wendy laertes +wendy miller +wendy miller +wendy nixon +wendy nixon +wendy ovid +wendy ovid +wendy polk +wendy polk +wendy quirinius +wendy quirinius +wendy robinson +wendy robinson +wendy robinson +wendy steinbeck +wendy thompson +wendy thompson +wendy underhill +wendy underhill +wendy underhill +wendy van buren +wendy van buren +wendy white +wendy xylophone +wendy xylophone +wendy young +wendy young +xavier allen +xavier allen +xavier allen +xavier brown +xavier brown +xavier brown +xavier carson +xavier carson +xavier davidson +xavier davidson +xavier davidson +xavier ellison +xavier ellison +xavier garcia +xavier hernandez +xavier hernandez +xavier hernandez +xavier ichabod +xavier ichabod +xavier johnson +xavier johnson +xavier king +xavier king +xavier laertes +xavier ovid +xavier polk +xavier polk +xavier polk +xavier polk +xavier quirinius +xavier quirinius +xavier quirinius +xavier quirinius +xavier thompson +xavier underhill +xavier white +xavier white +xavier xylophone +xavier zipper +yuri allen +yuri allen +yuri brown +yuri brown +yuri carson +yuri carson +yuri ellison +yuri ellison +yuri falkner +yuri falkner +yuri garcia +yuri hernandez +yuri johnson +yuri johnson +yuri johnson +yuri king +yuri laertes +yuri laertes +yuri nixon +yuri nixon +yuri polk +yuri polk +yuri polk +yuri quirinius +yuri quirinius +yuri quirinius +yuri steinbeck +yuri steinbeck +yuri thompson +yuri underhill +yuri underhill +yuri white +yuri xylophone +zach allen +zach brown +zach brown +zach brown +zach brown +zach brown +zach carson +zach ellison +zach falkner +zach falkner +zach garcia +zach garcia +zach garcia +zach garcia +zach ichabod +zach ichabod +zach king +zach king +zach king +zach miller +zach miller +zach miller +zach ovid +zach ovid +zach ovid +zach ovid +zach quirinius +zach robinson +zach steinbeck +zach steinbeck +zach thompson +zach thompson +zach underhill +zach white +zach xylophone +zach xylophone +zach young +zach zipper +zach zipper +zach zipper diff --git a/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-5-8d0ee3e1605f38214bfad28a5ce897cc b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-5-8d0ee3e1605f38214bfad28a5ce897cc new file mode 100644 index 000000000000..ddb15e338263 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_navfn.q (deterministic)-5-8d0ee3e1605f38214bfad28a5ce897cc @@ -0,0 +1 @@ +10 oscar carson 65549 65549 diff --git a/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-0-b7cb25303831392a51cd996e758ac79a b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-0-b7cb25303831392a51cd996e758ac79a new file mode 100644 index 000000000000..42e5151fe211 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-0-b7cb25303831392a51cd996e758ac79a @@ -0,0 +1,1049 @@ +65536 1 +65536 1 +65536 1 +65536 1 +65536 1 +65536 1 +65536 1 +65537 1 +65537 1 +65537 1 +65537 1 +65537 1 +65538 1 +65538 1 +65538 1 +65539 1 +65539 1 +65539 1 +65540 1 +65541 1 +65541 1 +65541 1 +65541 1 +65541 1 +65541 1 +65541 1 +65541 1 +65541 1 +65542 1 +65542 1 +65542 1 +65542 1 +65542 1 +65543 1 +65543 1 +65544 1 +65544 1 +65544 1 +65544 1 +65544 1 +65544 1 +65544 1 +65544 2 +65545 1 +65545 1 +65545 1 +65546 2 +65547 1 +65547 1 +65547 1 +65547 1 +65548 1 +65548 1 +65548 1 +65548 1 +65548 1 +65548 1 +65548 2 +65549 1 +65549 1 +65549 1 +65549 1 +65549 1 +65549 1 +65549 1 +65549 3 +65550 1 +65550 1 +65550 1 +65550 1 +65550 1 +65550 1 +65550 2 +65551 1 +65551 1 +65551 1 +65552 1 +65552 1 +65552 1 +65552 1 +65552 2 +65553 1 +65553 1 +65553 1 +65553 1 +65553 1 +65554 1 +65554 1 +65555 2 +65556 1 +65556 1 +65556 1 +65556 1 +65556 1 +65557 1 +65557 1 +65558 1 +65559 1 +65559 1 +65559 1 +65559 1 +65559 1 +65560 1 +65560 1 +65560 1 +65560 2 +65560 2 +65561 1 +65561 2 +65561 2 +65562 1 +65562 1 +65562 1 +65562 1 +65562 1 +65562 2 +65563 1 +65563 1 +65563 1 +65563 1 +65564 1 +65564 1 +65564 1 +65564 1 +65564 2 +65564 2 +65564 2 +65565 1 +65565 1 +65566 1 +65566 2 +65567 1 +65568 1 +65568 1 +65568 1 +65569 1 +65570 1 +65570 1 +65570 1 +65570 1 +65570 1 +65570 2 +65571 1 +65571 1 +65571 1 +65571 2 +65572 1 +65572 1 +65572 1 +65572 1 +65572 1 +65573 1 +65573 1 +65573 3 +65574 1 +65574 1 +65574 1 +65575 1 +65575 1 +65575 1 +65575 2 +65576 1 +65576 1 +65576 1 +65576 3 +65577 1 +65578 1 +65578 1 +65578 1 +65578 1 +65578 1 +65578 1 +65578 2 +65579 1 +65579 1 +65579 1 +65579 1 +65579 3 +65580 1 +65580 2 +65580 2 +65581 1 +65581 1 +65581 1 +65581 1 +65581 2 +65581 2 +65582 1 +65582 1 +65582 1 +65582 1 +65583 1 +65583 1 +65583 2 +65583 2 +65584 1 +65584 1 +65584 1 +65585 1 +65585 1 +65585 1 +65585 1 +65585 2 +65586 1 +65586 1 +65586 1 +65586 1 +65586 2 +65587 1 +65587 1 +65587 1 +65587 2 +65587 3 +65588 1 +65588 1 +65588 1 +65588 1 +65588 2 +65589 1 +65589 1 +65589 1 +65589 1 +65589 2 +65590 1 +65590 1 +65590 1 +65590 1 +65590 1 +65591 1 +65591 1 +65591 2 +65591 3 +65592 1 +65593 1 +65593 1 +65594 1 +65594 1 +65594 1 +65594 3 +65595 1 +65595 1 +65595 1 +65595 2 +65595 2 +65595 4 +65596 1 +65596 1 +65596 1 +65596 2 +65596 2 +65597 1 +65597 1 +65597 2 +65598 1 +65598 2 +65599 1 +65599 1 +65599 1 +65599 1 +65599 3 +65600 1 +65600 1 +65600 2 +65600 2 +65601 1 +65601 2 +65601 2 +65602 1 +65602 1 +65602 1 +65602 1 +65602 1 +65602 2 +65603 1 +65603 1 +65603 1 +65603 2 +65603 2 +65604 1 +65604 2 +65604 2 +65604 3 +65605 2 +65606 1 +65606 1 +65606 1 +65606 1 +65606 2 +65606 2 +65607 1 +65607 1 +65607 1 +65607 1 +65607 2 +65607 3 +65608 1 +65608 1 +65608 1 +65608 2 +65609 1 +65610 1 +65610 1 +65610 3 +65610 4 +65610 5 +65611 1 +65611 1 +65611 1 +65611 1 +65612 1 +65612 2 +65612 2 +65612 3 +65613 1 +65614 1 +65614 1 +65614 2 +65615 1 +65615 1 +65615 2 +65615 2 +65616 1 +65616 2 +65617 1 +65617 2 +65617 2 +65617 2 +65618 1 +65618 1 +65618 2 +65618 2 +65618 2 +65619 1 +65619 1 +65619 1 +65619 1 +65619 2 +65619 2 +65620 1 +65620 1 +65620 1 +65620 1 +65620 1 +65620 3 +65621 1 +65621 1 +65622 1 +65622 1 +65622 1 +65622 3 +65622 3 +65622 4 +65623 1 +65623 1 +65623 1 +65623 1 +65623 1 +65623 2 +65623 4 +65624 1 +65624 2 +65624 2 +65624 2 +65624 4 +65625 1 +65625 1 +65625 1 +65625 1 +65626 1 +65626 1 +65626 1 +65626 2 +65627 1 +65627 1 +65627 2 +65627 3 +65628 1 +65628 1 +65628 1 +65628 2 +65628 2 +65628 3 +65629 1 +65629 1 +65629 1 +65629 3 +65630 1 +65630 2 +65631 1 +65632 2 +65632 2 +65633 1 +65633 2 +65633 2 +65633 2 +65634 1 +65634 1 +65634 1 +65634 2 +65635 1 +65635 2 +65635 3 +65636 1 +65636 1 +65636 2 +65637 1 +65637 1 +65637 1 +65637 1 +65637 1 +65637 1 +65637 2 +65637 3 +65638 1 +65638 1 +65638 2 +65639 1 +65640 2 +65641 1 +65641 1 +65641 1 +65641 5 +65642 4 +65643 1 +65643 1 +65643 2 +65643 2 +65643 2 +65643 2 +65643 2 +65643 2 +65643 4 +65644 1 +65644 1 +65644 1 +65644 1 +65644 2 +65644 2 +65644 3 +65645 1 +65645 1 +65645 1 +65646 1 +65646 1 +65646 1 +65647 1 +65647 2 +65648 1 +65648 1 +65648 2 +65648 2 +65649 1 +65649 1 +65649 2 +65649 2 +65650 1 +65650 1 +65650 1 +65650 1 +65650 2 +65650 2 +65651 1 +65651 1 +65651 2 +65651 2 +65651 3 +65652 1 +65652 1 +65652 2 +65653 1 +65653 2 +65653 2 +65653 3 +65654 1 +65654 1 +65654 1 +65654 1 +65654 1 +65654 1 +65654 2 +65654 2 +65654 2 +65655 1 +65655 1 +65656 1 +65656 1 +65656 1 +65656 1 +65656 2 +65656 3 +65657 1 +65657 2 +65657 2 +65657 2 +65657 2 +65658 1 +65658 1 +65658 1 +65658 1 +65658 1 +65658 2 +65658 2 +65658 2 +65658 2 +65658 2 +65659 1 +65659 1 +65659 1 +65659 1 +65659 1 +65659 2 +65659 2 +65659 3 +65660 1 +65660 1 +65660 2 +65661 1 +65661 2 +65661 2 +65661 2 +65661 3 +65662 1 +65662 1 +65662 2 +65662 2 +65662 2 +65662 2 +65663 2 +65663 2 +65663 2 +65663 3 +65664 1 +65664 1 +65664 1 +65664 1 +65664 1 +65664 2 +65664 2 +65664 2 +65665 2 +65666 1 +65666 1 +65666 2 +65667 1 +65667 1 +65667 1 +65667 1 +65668 3 +65669 1 +65669 1 +65669 1 +65669 1 +65669 1 +65669 3 +65670 1 +65670 1 +65670 2 +65670 2 +65670 3 +65671 2 +65671 2 +65671 3 +65672 1 +65672 1 +65672 1 +65672 2 +65672 2 +65673 2 +65673 2 +65673 3 +65673 4 +65673 4 +65674 1 +65675 1 +65675 1 +65675 2 +65675 2 +65675 2 +65675 2 +65675 3 +65675 3 +65676 1 +65676 1 +65676 2 +65676 3 +65677 1 +65677 1 +65677 1 +65677 2 +65677 3 +65677 4 +65678 1 +65678 1 +65678 1 +65679 1 +65679 2 +65680 1 +65680 1 +65680 2 +65680 2 +65680 2 +65680 5 +65681 1 +65681 2 +65681 2 +65681 2 +65681 4 +65681 4 +65682 1 +65682 2 +65682 2 +65683 1 +65683 1 +65683 2 +65684 1 +65684 1 +65684 2 +65685 2 +65685 2 +65685 2 +65685 3 +65685 3 +65685 4 +65686 1 +65686 2 +65686 3 +65686 3 +65687 1 +65687 1 +65687 2 +65687 3 +65688 1 +65688 2 +65689 1 +65689 3 +65690 2 +65690 2 +65691 1 +65691 1 +65691 1 +65691 1 +65691 1 +65691 2 +65691 3 +65691 3 +65692 1 +65692 2 +65693 1 +65693 2 +65693 2 +65693 2 +65693 2 +65693 3 +65693 5 +65694 1 +65694 2 +65694 2 +65694 3 +65695 1 +65695 2 +65695 2 +65695 3 +65696 1 +65696 2 +65696 2 +65696 4 +65697 1 +65697 1 +65697 1 +65697 1 +65697 2 +65697 5 +65698 1 +65698 2 +65699 1 +65699 1 +65699 1 +65699 2 +65699 2 +65699 2 +65700 1 +65700 2 +65700 2 +65701 1 +65701 2 +65701 3 +65702 2 +65702 2 +65702 2 +65702 3 +65703 1 +65703 2 +65703 3 +65703 3 +65704 1 +65704 1 +65704 2 +65704 2 +65704 3 +65704 3 +65705 1 +65705 1 +65705 3 +65705 4 +65706 1 +65706 1 +65706 1 +65706 3 +65706 4 +65707 2 +65707 2 +65708 2 +65708 3 +65709 1 +65709 1 +65709 2 +65710 1 +65710 2 +65711 1 +65711 2 +65711 2 +65711 2 +65711 2 +65711 2 +65712 1 +65712 2 +65712 3 +65712 3 +65712 4 +65713 1 +65713 2 +65713 3 +65713 3 +65713 6 +65714 1 +65714 1 +65714 2 +65715 1 +65715 1 +65715 2 +65715 2 +65715 4 +65716 1 +65716 2 +65716 2 +65716 2 +65716 4 +65716 4 +65717 1 +65717 2 +65717 2 +65717 2 +65717 5 +65718 1 +65718 2 +65718 3 +65718 3 +65719 1 +65719 1 +65719 2 +65719 3 +65720 1 +65720 2 +65720 2 +65720 3 +65720 4 +65721 1 +65721 1 +65721 1 +65721 1 +65721 3 +65721 3 +65721 3 +65722 2 +65722 3 +65722 5 +65723 2 +65723 3 +65724 1 +65724 1 +65724 2 +65724 2 +65724 3 +65724 3 +65725 1 +65726 2 +65726 2 +65726 4 +65727 1 +65727 1 +65727 3 +65727 3 +65727 4 +65728 1 +65728 2 +65729 2 +65730 1 +65730 1 +65730 2 +65730 4 +65731 1 +65731 1 +65731 1 +65731 3 +65732 1 +65732 2 +65732 2 +65733 1 +65733 1 +65733 2 +65733 3 +65733 3 +65733 6 +65734 2 +65734 2 +65735 1 +65735 4 +65735 4 +65736 1 +65736 2 +65736 2 +65736 3 +65737 1 +65737 1 +65737 2 +65737 3 +65737 4 +65738 3 +65738 3 +65738 4 +65738 4 +65739 1 +65739 2 +65739 3 +65740 2 +65740 2 +65740 3 +65741 1 +65742 1 +65742 2 +65742 3 +65743 1 +65743 1 +65743 1 +65743 2 +65744 1 +65744 2 +65745 2 +65745 2 +65745 3 +65745 6 +65746 1 +65746 2 +65746 2 +65746 2 +65746 2 +65747 1 +65747 1 +65747 1 +65747 2 +65747 2 +65747 3 +65747 3 +65748 1 +65748 3 +65749 2 +65749 3 +65749 3 +65749 3 +65749 4 +65750 1 +65750 1 +65750 2 +65750 3 +65750 3 +65751 1 +65751 2 +65751 2 +65751 3 +65751 3 +65751 4 +65752 1 +65752 3 +65753 2 +65753 3 +65754 2 +65754 4 +65755 1 +65755 2 +65755 2 +65755 3 +65755 3 +65755 3 +65755 3 +65755 5 +65755 5 +65756 2 +65756 3 +65756 3 +65756 5 +65757 1 +65757 1 +65757 1 +65757 2 +65757 3 +65758 1 +65758 2 +65758 2 +65758 3 +65758 4 +65759 2 +65759 2 +65759 2 +65759 4 +65759 4 +65760 2 +65760 4 +65760 5 +65760 7 +65761 1 +65762 1 +65762 1 +65762 2 +65762 4 +65762 5 +65763 1 +65763 2 +65763 2 +65763 2 +65763 4 +65764 3 +65764 3 +65764 4 +65765 2 +65765 2 +65765 3 +65766 1 +65766 1 +65766 3 +65766 3 +65766 3 +65767 2 +65767 3 +65767 3 +65768 3 +65769 1 +65769 2 +65769 2 +65769 2 +65769 3 +65769 5 +65770 2 +65770 3 +65771 2 +65771 3 +65771 3 +65772 2 +65773 1 +65773 2 +65773 2 +65773 2 +65773 2 +65773 3 +65773 3 +65773 4 +65774 2 +65774 2 +65774 2 +65774 2 +65774 3 +65774 3 +65775 1 +65775 2 +65775 3 +65775 3 +65775 5 +65776 1 +65776 3 +65776 3 +65776 5 +65776 6 +65777 2 +65777 3 +65777 4 +65777 4 +65778 1 +65778 2 +65778 2 +65778 2 +65778 3 +65778 3 +65778 4 +65778 5 +65779 3 +65779 3 +65779 3 +65779 3 +65779 4 +65779 4 +65779 5 +65780 1 +65780 2 +65780 3 +65781 1 +65781 3 +65781 3 +65782 2 +65782 2 +65782 3 +65782 4 +65782 5 +65783 2 +65783 3 +65783 3 +65783 3 +65783 3 +65783 3 +65784 2 +65784 2 +65784 2 +65784 4 +65784 6 +65785 2 +65785 7 +65786 1 +65786 2 +65786 3 +65786 3 +65786 4 +65787 1 +65787 2 +65787 3 +65787 3 +65788 1 +65788 1 +65788 2 +65788 4 +65789 1 +65789 1 +65789 2 +65789 2 +65789 2 +65789 2 +65789 4 +65790 2 +65790 2 +65790 4 +65791 2 +65791 2 diff --git a/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-1-a3d352560ac835993001665db6954965 b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-1-a3d352560ac835993001665db6954965 new file mode 100644 index 000000000000..dc72606a83db --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-1-a3d352560ac835993001665db6954965 @@ -0,0 +1,1049 @@ + 1 + 1 + 1 +alice allen 1 +alice allen 1 +alice allen 1 +alice brown 1 +alice carson 1 +alice davidson 1 +alice falkner 1 +alice garcia 1 +alice hernandez 1 +alice hernandez 1 +alice johnson 1 +alice king 1 +alice king 1 +alice king 1 +alice laertes 1 +alice laertes 2 +alice miller 1 +alice nixon 1 +alice nixon 1 +alice nixon 1 +alice ovid 2 +alice polk 1 +alice quirinius 1 +alice quirinius 1 +alice robinson 1 +alice robinson 1 +alice steinbeck 1 +alice steinbeck 1 +alice steinbeck 1 +alice underhill 1 +alice van buren 1 +alice xylophone 1 +alice xylophone 1 +alice xylophone 2 +alice zipper 1 +alice zipper 1 +alice zipper 2 +bob brown 1 +bob brown 1 +bob brown 1 +bob carson 1 +bob davidson 1 +bob davidson 1 +bob davidson 1 +bob ellison 1 +bob ellison 1 +bob ellison 1 +bob ellison 1 +bob falkner 2 +bob garcia 1 +bob garcia 1 +bob garcia 1 +bob garcia 2 +bob garcia 2 +bob hernandez 1 +bob ichabod 1 +bob king 1 +bob king 1 +bob king 2 +bob laertes 1 +bob laertes 1 +bob miller 1 +bob ovid 1 +bob ovid 1 +bob ovid 1 +bob ovid 1 +bob polk 1 +bob quirinius 1 +bob steinbeck 1 +bob van buren 1 +bob white 1 +bob white 1 +bob xylophone 1 +bob xylophone 1 +bob young 1 +bob zipper 1 +bob zipper 2 +bob zipper 2 +calvin allen 3 +calvin brown 1 +calvin brown 1 +calvin brown 1 +calvin carson 2 +calvin davidson 1 +calvin davidson 2 +calvin ellison 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 2 +calvin falkner 2 +calvin falkner 2 +calvin garcia 2 +calvin hernandez 3 +calvin johnson 1 +calvin laertes 1 +calvin laertes 1 +calvin nixon 1 +calvin nixon 1 +calvin nixon 1 +calvin ovid 1 +calvin ovid 1 +calvin ovid 2 +calvin ovid 2 +calvin polk 2 +calvin quirinius 1 +calvin quirinius 1 +calvin robinson 1 +calvin steinbeck 1 +calvin steinbeck 1 +calvin steinbeck 2 +calvin thompson 1 +calvin thompson 2 +calvin underhill 1 +calvin van buren 1 +calvin van buren 1 +calvin white 1 +calvin white 2 +calvin xylophone 1 +calvin xylophone 2 +calvin xylophone 2 +calvin young 1 +calvin young 2 +calvin zipper 3 +calvin zipper 4 +david allen 1 +david allen 1 +david brown 2 +david brown 3 +david davidson 1 +david davidson 2 +david davidson 3 +david davidson 3 +david ellison 1 +david ellison 2 +david ellison 3 +david hernandez 1 +david ichabod 1 +david ichabod 3 +david laertes 3 +david nixon 1 +david ovid 1 +david ovid 1 +david quirinius 1 +david quirinius 1 +david quirinius 3 +david robinson 1 +david robinson 4 +david thompson 1 +david underhill 1 +david underhill 2 +david underhill 3 +david van buren 1 +david van buren 2 +david white 1 +david xylophone 1 +david xylophone 1 +david xylophone 2 +david young 1 +david young 1 +ethan allen 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan brown 2 +ethan brown 2 +ethan carson 1 +ethan ellison 1 +ethan ellison 2 +ethan falkner 1 +ethan falkner 1 +ethan garcia 1 +ethan hernandez 1 +ethan johnson 2 +ethan king 2 +ethan laertes 1 +ethan laertes 1 +ethan laertes 2 +ethan laertes 2 +ethan laertes 2 +ethan laertes 2 +ethan laertes 3 +ethan miller 1 +ethan nixon 2 +ethan ovid 2 +ethan polk 1 +ethan polk 1 +ethan polk 1 +ethan polk 2 +ethan quirinius 1 +ethan quirinius 1 +ethan quirinius 2 +ethan robinson 1 +ethan robinson 2 +ethan underhill 2 +ethan van buren 1 +ethan white 1 +ethan white 2 +ethan xylophone 2 +ethan zipper 1 +ethan zipper 3 +fred davidson 2 +fred davidson 2 +fred davidson 3 +fred ellison 1 +fred ellison 2 +fred ellison 2 +fred falkner 1 +fred falkner 3 +fred falkner 4 +fred hernandez 2 +fred ichabod 2 +fred ichabod 3 +fred johnson 2 +fred king 2 +fred king 2 +fred laertes 2 +fred miller 3 +fred nixon 1 +fred nixon 1 +fred nixon 1 +fred nixon 3 +fred polk 1 +fred polk 1 +fred polk 1 +fred polk 2 +fred quirinius 2 +fred quirinius 3 +fred robinson 2 +fred steinbeck 1 +fred steinbeck 1 +fred steinbeck 1 +fred underhill 1 +fred van buren 2 +fred van buren 3 +fred van buren 3 +fred van buren 4 +fred white 2 +fred young 1 +fred young 2 +fred zipper 3 +gabriella allen 1 +gabriella allen 3 +gabriella brown 1 +gabriella brown 1 +gabriella carson 2 +gabriella davidson 1 +gabriella ellison 1 +gabriella ellison 3 +gabriella falkner 1 +gabriella falkner 1 +gabriella falkner 3 +gabriella garcia 1 +gabriella hernandez 1 +gabriella hernandez 1 +gabriella ichabod 1 +gabriella ichabod 2 +gabriella ichabod 2 +gabriella ichabod 2 +gabriella ichabod 4 +gabriella king 1 +gabriella king 2 +gabriella laertes 1 +gabriella miller 2 +gabriella ovid 1 +gabriella ovid 2 +gabriella polk 1 +gabriella polk 2 +gabriella steinbeck 1 +gabriella steinbeck 1 +gabriella thompson 1 +gabriella thompson 2 +gabriella thompson 3 +gabriella van buren 1 +gabriella van buren 2 +gabriella white 1 +gabriella young 1 +gabriella young 2 +gabriella zipper 1 +gabriella zipper 2 +holly allen 3 +holly brown 2 +holly brown 2 +holly falkner 2 +holly hernandez 2 +holly hernandez 2 +holly hernandez 2 +holly hernandez 3 +holly ichabod 1 +holly ichabod 2 +holly ichabod 2 +holly johnson 1 +holly johnson 3 +holly johnson 4 +holly king 2 +holly king 2 +holly laertes 3 +holly miller 2 +holly nixon 1 +holly nixon 2 +holly polk 1 +holly polk 2 +holly robinson 3 +holly thompson 1 +holly thompson 3 +holly thompson 4 +holly underhill 2 +holly underhill 2 +holly underhill 3 +holly underhill 3 +holly van buren 1 +holly white 4 +holly white 4 +holly xylophone 2 +holly young 1 +holly young 2 +holly zipper 1 +holly zipper 4 +irene allen 3 +irene brown 1 +irene brown 2 +irene brown 3 +irene carson 2 +irene ellison 2 +irene ellison 2 +irene falkner 1 +irene falkner 1 +irene garcia 1 +irene garcia 2 +irene garcia 3 +irene ichabod 1 +irene ichabod 2 +irene johnson 2 +irene laertes 1 +irene laertes 3 +irene laertes 4 +irene miller 1 +irene nixon 1 +irene nixon 3 +irene nixon 3 +irene ovid 2 +irene ovid 2 +irene ovid 2 +irene polk 1 +irene polk 1 +irene polk 2 +irene polk 2 +irene polk 4 +irene quirinius 2 +irene quirinius 3 +irene quirinius 4 +irene robinson 2 +irene steinbeck 1 +irene thompson 1 +irene underhill 2 +irene underhill 3 +irene van buren 2 +irene van buren 3 +irene xylophone 2 +jessica brown 2 +jessica carson 1 +jessica carson 2 +jessica carson 4 +jessica davidson 1 +jessica davidson 2 +jessica davidson 3 +jessica davidson 3 +jessica ellison 1 +jessica ellison 3 +jessica falkner 2 +jessica garcia 1 +jessica garcia 5 +jessica ichabod 2 +jessica johnson 2 +jessica johnson 3 +jessica miller 2 +jessica nixon 2 +jessica nixon 3 +jessica ovid 2 +jessica ovid 3 +jessica polk 5 +jessica quirinius 2 +jessica quirinius 2 +jessica quirinius 3 +jessica quirinius 3 +jessica robinson 1 +jessica thompson 2 +jessica thompson 3 +jessica underhill 2 +jessica underhill 2 +jessica underhill 4 +jessica van buren 2 +jessica white 3 +jessica white 3 +jessica white 3 +jessica white 3 +jessica white 4 +jessica xylophone 4 +jessica young 4 +jessica young 4 +jessica zipper 1 +jessica zipper 2 +jessica zipper 4 +katie allen 2 +katie brown 4 +katie davidson 3 +katie ellison 3 +katie ellison 3 +katie falkner 2 +katie garcia 2 +katie garcia 3 +katie hernandez 2 +katie ichabod 2 +katie ichabod 2 +katie ichabod 2 +katie king 1 +katie king 1 +katie king 2 +katie miller 2 +katie miller 3 +katie nixon 5 +katie ovid 1 +katie polk 2 +katie polk 3 +katie robinson 4 +katie van buren 2 +katie van buren 4 +katie white 1 +katie white 2 +katie xylophone 3 +katie young 2 +katie young 2 +katie young 3 +katie zipper 1 +katie zipper 3 +luke allen 2 +luke allen 2 +luke allen 2 +luke allen 3 +luke allen 3 +luke brown 2 +luke davidson 1 +luke davidson 3 +luke ellison 3 +luke ellison 5 +luke ellison 5 +luke falkner 2 +luke falkner 4 +luke garcia 1 +luke garcia 5 +luke ichabod 3 +luke ichabod 3 +luke johnson 1 +luke johnson 2 +luke johnson 3 +luke laertes 2 +luke laertes 3 +luke laertes 3 +luke laertes 3 +luke laertes 3 +luke miller 2 +luke ovid 1 +luke ovid 2 +luke polk 2 +luke polk 3 +luke quirinius 2 +luke robinson 1 +luke robinson 4 +luke thompson 1 +luke underhill 2 +luke underhill 3 +luke underhill 5 +luke van buren 2 +luke white 3 +luke xylophone 2 +luke zipper 1 +mike allen 4 +mike brown 4 +mike carson 1 +mike carson 2 +mike carson 4 +mike davidson 3 +mike davidson 4 +mike ellison 2 +mike ellison 3 +mike ellison 3 +mike ellison 4 +mike ellison 4 +mike falkner 1 +mike garcia 1 +mike garcia 2 +mike garcia 3 +mike hernandez 2 +mike hernandez 3 +mike ichabod 1 +mike king 1 +mike king 1 +mike king 3 +mike king 3 +mike king 4 +mike king 4 +mike miller 4 +mike nixon 3 +mike nixon 4 +mike polk 4 +mike polk 5 +mike polk 5 +mike quirinius 3 +mike steinbeck 2 +mike steinbeck 3 +mike steinbeck 3 +mike steinbeck 4 +mike van buren 2 +mike van buren 3 +mike white 3 +mike white 4 +mike white 5 +mike white 6 +mike young 2 +mike young 2 +mike young 4 +mike zipper 1 +mike zipper 4 +mike zipper 6 +nick allen 2 +nick allen 2 +nick brown 5 +nick davidson 1 +nick ellison 3 +nick ellison 4 +nick falkner 2 +nick falkner 3 +nick garcia 2 +nick garcia 4 +nick garcia 4 +nick ichabod 3 +nick ichabod 3 +nick ichabod 3 +nick johnson 4 +nick johnson 4 +nick laertes 2 +nick miller 3 +nick nixon 2 +nick ovid 3 +nick polk 4 +nick quirinius 1 +nick quirinius 3 +nick robinson 4 +nick robinson 4 +nick steinbeck 1 +nick thompson 2 +nick underhill 6 +nick van buren 2 +nick xylophone 3 +nick young 3 +nick young 5 +nick zipper 3 +nick zipper 4 +oscar allen 4 +oscar brown 3 +oscar carson 3 +oscar carson 4 +oscar carson 4 +oscar carson 5 +oscar carson 5 +oscar davidson 4 +oscar ellison 1 +oscar ellison 2 +oscar falkner 1 +oscar garcia 5 +oscar hernandez 1 +oscar hernandez 2 +oscar ichabod 2 +oscar ichabod 5 +oscar ichabod 5 +oscar ichabod 6 +oscar johnson 3 +oscar johnson 7 +oscar king 3 +oscar king 4 +oscar king 4 +oscar laertes 1 +oscar laertes 2 +oscar laertes 3 +oscar laertes 3 +oscar nixon 4 +oscar ovid 3 +oscar ovid 3 +oscar ovid 6 +oscar polk 5 +oscar polk 5 +oscar quirinius 2 +oscar quirinius 4 +oscar quirinius 5 +oscar quirinius 6 +oscar robinson 2 +oscar robinson 3 +oscar robinson 5 +oscar robinson 6 +oscar steinbeck 1 +oscar thompson 2 +oscar thompson 3 +oscar thompson 3 +oscar thompson 4 +oscar underhill 2 +oscar van buren 1 +oscar van buren 2 +oscar van buren 5 +oscar white 1 +oscar white 2 +oscar white 5 +oscar white 5 +oscar xylophone 3 +oscar xylophone 3 +oscar xylophone 4 +oscar zipper 2 +oscar zipper 2 +oscar zipper 2 +priscilla brown 2 +priscilla brown 2 +priscilla brown 4 +priscilla carson 3 +priscilla carson 5 +priscilla carson 7 +priscilla ichabod 1 +priscilla ichabod 4 +priscilla johnson 1 +priscilla johnson 2 +priscilla johnson 4 +priscilla johnson 4 +priscilla johnson 6 +priscilla king 3 +priscilla nixon 3 +priscilla nixon 6 +priscilla ovid 3 +priscilla ovid 7 +priscilla polk 4 +priscilla quirinius 3 +priscilla thompson 6 +priscilla underhill 1 +priscilla underhill 4 +priscilla van buren 3 +priscilla van buren 5 +priscilla van buren 5 +priscilla white 1 +priscilla xylophone 2 +priscilla xylophone 3 +priscilla xylophone 3 +priscilla young 5 +priscilla young 8 +priscilla zipper 3 +priscilla zipper 3 +quinn allen 1 +quinn allen 4 +quinn brown 3 +quinn brown 4 +quinn brown 4 +quinn davidson 2 +quinn davidson 4 +quinn davidson 6 +quinn davidson 7 +quinn ellison 3 +quinn ellison 8 +quinn garcia 2 +quinn garcia 3 +quinn garcia 3 +quinn garcia 5 +quinn ichabod 7 +quinn king 1 +quinn king 1 +quinn laertes 2 +quinn laertes 4 +quinn laertes 5 +quinn nixon 4 +quinn ovid 4 +quinn quirinius 5 +quinn robinson 3 +quinn steinbeck 4 +quinn steinbeck 5 +quinn thompson 4 +quinn thompson 6 +quinn underhill 2 +quinn underhill 3 +quinn underhill 7 +quinn van buren 1 +quinn young 2 +quinn zipper 3 +quinn zipper 4 +rachel allen 2 +rachel allen 3 +rachel brown 2 +rachel brown 3 +rachel brown 4 +rachel brown 4 +rachel brown 5 +rachel carson 2 +rachel carson 4 +rachel davidson 7 +rachel ellison 1 +rachel falkner 1 +rachel falkner 3 +rachel falkner 5 +rachel falkner 6 +rachel johnson 9 +rachel king 3 +rachel king 7 +rachel laertes 4 +rachel laertes 6 +rachel ovid 3 +rachel ovid 4 +rachel polk 3 +rachel quirinius 4 +rachel robinson 4 +rachel robinson 4 +rachel robinson 6 +rachel thompson 4 +rachel thompson 5 +rachel thompson 5 +rachel underhill 2 +rachel white 4 +rachel white 5 +rachel young 4 +rachel zipper 1 +rachel zipper 5 +sarah carson 1 +sarah carson 4 +sarah carson 7 +sarah ellison 1 +sarah falkner 4 +sarah falkner 5 +sarah garcia 2 +sarah garcia 2 +sarah garcia 4 +sarah ichabod 3 +sarah ichabod 3 +sarah johnson 3 +sarah johnson 5 +sarah johnson 5 +sarah johnson 6 +sarah king 3 +sarah king 5 +sarah miller 2 +sarah ovid 5 +sarah robinson 5 +sarah robinson 5 +sarah steinbeck 5 +sarah white 4 +sarah white 6 +sarah xylophone 3 +sarah young 5 +sarah zipper 6 +tom brown 2 +tom brown 5 +tom carson 1 +tom carson 3 +tom carson 5 +tom davidson 2 +tom ellison 3 +tom ellison 4 +tom ellison 6 +tom falkner 3 +tom falkner 4 +tom hernandez 1 +tom hernandez 3 +tom ichabod 4 +tom johnson 6 +tom johnson 7 +tom king 3 +tom laertes 3 +tom laertes 3 +tom miller 3 +tom miller 4 +tom miller 5 +tom nixon 4 +tom ovid 3 +tom polk 2 +tom polk 2 +tom quirinius 3 +tom quirinius 5 +tom robinson 2 +tom robinson 3 +tom robinson 3 +tom robinson 5 +tom steinbeck 2 +tom van buren 2 +tom van buren 3 +tom van buren 6 +tom white 5 +tom young 1 +tom young 5 +tom zipper 7 +ulysses brown 2 +ulysses carson 2 +ulysses carson 5 +ulysses carson 6 +ulysses carson 8 +ulysses davidson 3 +ulysses ellison 4 +ulysses garcia 3 +ulysses hernandez 3 +ulysses hernandez 3 +ulysses hernandez 4 +ulysses ichabod 1 +ulysses ichabod 3 +ulysses johnson 5 +ulysses king 2 +ulysses laertes 2 +ulysses laertes 5 +ulysses laertes 6 +ulysses miller 5 +ulysses miller 7 +ulysses nixon 4 +ulysses ovid 4 +ulysses polk 2 +ulysses polk 2 +ulysses polk 3 +ulysses polk 4 +ulysses quirinius 4 +ulysses robinson 1 +ulysses steinbeck 2 +ulysses steinbeck 5 +ulysses thompson 3 +ulysses underhill 2 +ulysses underhill 2 +ulysses underhill 3 +ulysses underhill 4 +ulysses underhill 4 +ulysses underhill 4 +ulysses underhill 5 +ulysses van buren 2 +ulysses white 6 +ulysses white 7 +ulysses xylophone 2 +ulysses xylophone 3 +ulysses xylophone 6 +ulysses young 1 +ulysses young 4 +ulysses young 7 +victor allen 2 +victor allen 3 +victor brown 1 +victor brown 4 +victor brown 5 +victor brown 7 +victor davidson 4 +victor davidson 4 +victor davidson 6 +victor ellison 4 +victor ellison 4 +victor hernandez 2 +victor hernandez 3 +victor hernandez 4 +victor hernandez 4 +victor hernandez 6 +victor johnson 4 +victor johnson 5 +victor johnson 6 +victor king 2 +victor king 6 +victor laertes 3 +victor laertes 5 +victor miller 5 +victor nixon 2 +victor nixon 3 +victor ovid 3 +victor polk 4 +victor quirinius 5 +victor quirinius 5 +victor robinson 5 +victor robinson 5 +victor steinbeck 3 +victor steinbeck 4 +victor steinbeck 5 +victor thompson 6 +victor van buren 5 +victor van buren 6 +victor white 2 +victor white 7 +victor xylophone 4 +victor xylophone 6 +victor xylophone 6 +victor xylophone 8 +victor xylophone 8 +victor young 5 +victor zipper 3 +wendy allen 5 +wendy allen 6 +wendy allen 6 +wendy brown 3 +wendy brown 5 +wendy ellison 3 +wendy ellison 5 +wendy falkner 2 +wendy falkner 4 +wendy falkner 6 +wendy garcia 4 +wendy garcia 4 +wendy garcia 7 +wendy garcia 7 +wendy hernandez 4 +wendy ichabod 4 +wendy king 4 +wendy king 5 +wendy king 7 +wendy laertes 2 +wendy laertes 3 +wendy laertes 5 +wendy miller 4 +wendy miller 4 +wendy nixon 3 +wendy nixon 5 +wendy ovid 5 +wendy ovid 9 +wendy polk 2 +wendy polk 5 +wendy quirinius 3 +wendy quirinius 4 +wendy robinson 5 +wendy robinson 6 +wendy robinson 6 +wendy steinbeck 3 +wendy thompson 4 +wendy thompson 5 +wendy underhill 4 +wendy underhill 5 +wendy underhill 6 +wendy van buren 6 +wendy van buren 6 +wendy white 4 +wendy xylophone 4 +wendy xylophone 6 +wendy young 1 +wendy young 6 +xavier allen 3 +xavier allen 4 +xavier allen 5 +xavier brown 2 +xavier brown 4 +xavier brown 6 +xavier carson 4 +xavier carson 5 +xavier davidson 7 +xavier davidson 8 +xavier davidson 9 +xavier ellison 8 +xavier ellison 8 +xavier garcia 5 +xavier hernandez 5 +xavier hernandez 6 +xavier hernandez 9 +xavier ichabod 3 +xavier ichabod 4 +xavier johnson 2 +xavier johnson 9 +xavier king 3 +xavier king 5 +xavier laertes 4 +xavier ovid 4 +xavier polk 3 +xavier polk 4 +xavier polk 4 +xavier polk 8 +xavier quirinius 3 +xavier quirinius 5 +xavier quirinius 6 +xavier quirinius 6 +xavier thompson 4 +xavier underhill 2 +xavier white 3 +xavier white 3 +xavier xylophone 4 +xavier zipper 3 +yuri allen 2 +yuri allen 3 +yuri brown 2 +yuri brown 3 +yuri carson 5 +yuri carson 6 +yuri ellison 6 +yuri ellison 6 +yuri falkner 6 +yuri falkner 10 +yuri garcia 1 +yuri hernandez 5 +yuri johnson 5 +yuri johnson 5 +yuri johnson 6 +yuri king 7 +yuri laertes 7 +yuri laertes 8 +yuri nixon 3 +yuri nixon 3 +yuri polk 3 +yuri polk 5 +yuri polk 6 +yuri quirinius 3 +yuri quirinius 4 +yuri quirinius 7 +yuri steinbeck 1 +yuri steinbeck 2 +yuri thompson 3 +yuri underhill 4 +yuri underhill 4 +yuri white 8 +yuri xylophone 3 +zach allen 4 +zach brown 5 +zach brown 5 +zach brown 5 +zach brown 5 +zach brown 7 +zach carson 5 +zach ellison 2 +zach falkner 4 +zach falkner 6 +zach garcia 4 +zach garcia 5 +zach garcia 7 +zach garcia 8 +zach ichabod 4 +zach ichabod 4 +zach king 4 +zach king 5 +zach king 8 +zach miller 1 +zach miller 3 +zach miller 4 +zach ovid 4 +zach ovid 5 +zach ovid 5 +zach ovid 7 +zach quirinius 8 +zach robinson 5 +zach steinbeck 4 +zach steinbeck 6 +zach thompson 3 +zach thompson 4 +zach underhill 3 +zach white 6 +zach xylophone 3 +zach xylophone 5 +zach young 4 +zach zipper 4 +zach zipper 4 +zach zipper 5 diff --git a/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-2-fafa16c0f7697ca28aeb6f2698799562 b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-2-fafa16c0f7697ca28aeb6f2698799562 new file mode 100644 index 000000000000..76cbeb254c0e --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-2-fafa16c0f7697ca28aeb6f2698799562 @@ -0,0 +1,1049 @@ +0.08 1 +0.1 1 +0.13 1 +0.15 1 +0.27 1 +0.28 1 +0.43 1 +0.52 1 +0.56 1 +0.6 1 +0.61 1 +0.79 1 +0.84 1 +0.98 1 +1.02 1 +1.08 1 +1.08 1 +1.12 1 +1.21 1 +1.25 1 +1.27 1 +1.29 1 +1.31 1 +1.58 1 +1.87 1 +1.91 1 +1.92 1 +2.07 1 +2.18 1 +2.2 1 +2.35 1 +2.6 1 +2.79 1 +2.92 1 +2.96 1 +2.96 1 +2.97 1 +3.0 1 +3.21 1 +3.28 1 +3.33 1 +3.61 1 +3.62 1 +3.82 1 +3.86 1 +3.96 1 +3.97 1 +4.17 1 +4.32 1 +4.35 1 +4.41 1 +4.46 1 +4.47 1 +4.57 1 +4.59 1 +4.71 1 +4.72 1 +4.79 1 +4.8 1 +4.92 1 +5.08 1 +5.24 1 +5.28 1 +5.4 1 +5.44 1 +5.45 1 +5.51 1 +5.54 1 +5.62 1 +5.67 1 +5.85 1 +5.88 1 +6.29 1 +6.55 1 +6.57 1 +6.63 1 +6.67 1 +6.72 1 +6.74 1 +6.84 1 +6.87 1 +7.05 1 +7.06 1 +7.11 1 +7.54 1 +7.56 1 +7.79 1 +7.82 1 +7.96 1 +7.96 1 +7.98 1 +8.07 1 +8.07 1 +8.32 1 +8.37 1 +8.42 1 +8.45 1 +8.45 1 +8.45 1 +8.45 1 +8.57 1 +8.61 1 +8.67 2 +8.71 1 +8.79 1 +8.91 1 +9.04 1 +9.13 1 +9.19 1 +9.22 1 +9.25 1 +9.26 1 +9.35 1 +9.48 1 +9.56 1 +9.57 1 +9.57 1 +9.68 1 +9.7 1 +9.71 1 +9.74 1 +9.8 1 +9.81 1 +9.93 1 +10.09 1 +10.09 1 +10.13 1 +10.16 1 +10.17 1 +10.19 1 +10.2 1 +10.22 1 +10.25 1 +10.26 1 +10.29 1 +10.6 1 +10.66 1 +10.67 1 +10.73 1 +11.15 1 +11.18 1 +11.19 1 +11.22 1 +11.34 1 +11.55 1 +11.57 1 +11.68 1 +11.82 1 +11.89 1 +11.91 1 +12.02 1 +12.16 1 +12.19 1 +12.32 1 +12.42 1 +12.44 1 +12.45 1 +12.46 1 +12.5 1 +12.54 1 +12.85 1 +12.9 1 +13.01 1 +13.1 1 +13.15 1 +13.35 1 +13.87 1 +13.89 1 +13.94 1 +13.99 1 +14.13 1 +14.21 2 +14.3 1 +14.44 1 +14.84 2 +14.92 1 +14.92 1 +14.93 1 +15.1 1 +15.15 1 +15.18 1 +15.22 1 +15.26 2 +15.3 1 +15.37 1 +15.45 1 +15.63 2 +15.75 1 +15.81 1 +15.86 1 +15.9 1 +15.92 1 +16.08 1 +16.09 1 +16.24 1 +16.25 1 +16.48 1 +16.69 1 +16.99 1 +16.99 1 +17.16 1 +17.37 1 +17.74 1 +17.79 1 +17.87 1 +18.2 1 +18.5 1 +18.56 1 +18.63 1 +18.63 1 +18.86 1 +18.89 1 +18.93 1 +19.0 1 +19.03 1 +19.06 1 +19.06 1 +19.13 1 +19.14 1 +19.28 1 +19.69 1 +20.07 1 +20.38 1 +20.64 1 +20.67 1 +20.79 1 +20.81 1 +20.82 1 +20.82 1 +21.18 1 +21.19 1 +21.23 1 +21.28 1 +21.32 1 +21.45 1 +21.49 1 +21.61 1 +21.7 1 +21.8 1 +21.94 1 +22.01 1 +22.08 1 +22.12 1 +22.12 1 +22.25 1 +22.27 1 +22.36 1 +22.68 1 +22.78 1 +22.85 1 +22.85 1 +22.94 1 +23.07 1 +23.13 1 +23.17 1 +23.19 1 +23.44 1 +23.45 1 +23.6 1 +23.77 1 +23.96 1 +24.02 1 +24.28 1 +24.49 1 +24.52 1 +24.73 1 +24.79 1 +24.8 1 +24.83 1 +24.86 1 +25.11 1 +25.28 1 +25.37 1 +25.42 1 +25.55 1 +25.67 1 +25.88 1 +26.08 1 +26.39 1 +26.43 1 +26.47 1 +26.49 1 +26.49 1 +26.64 1 +26.71 2 +26.73 1 +26.76 1 +27.07 1 +27.12 1 +27.3 1 +27.31 1 +27.63 1 +27.66 1 +27.72 2 +27.87 1 +28.11 1 +28.31 1 +28.45 1 +28.5 1 +28.56 1 +28.69 1 +28.71 1 +28.79 1 +28.89 1 +28.95 1 +29.02 1 +29.24 1 +29.36 1 +29.4 1 +29.41 2 +29.54 1 +29.59 1 +29.78 1 +30.25 2 +30.36 1 +30.37 1 +30.61 1 +30.62 1 +30.63 1 +30.65 1 +30.71 1 +30.81 1 +31.01 1 +31.15 1 +31.4 1 +31.61 1 +31.67 1 +31.77 1 +31.86 1 +31.91 1 +32.01 1 +32.18 1 +32.2 1 +32.23 1 +32.25 1 +32.37 1 +32.41 1 +32.47 1 +32.52 1 +32.75 1 +32.89 2 +32.92 1 +33.36 1 +33.52 1 +33.55 1 +33.58 1 +33.67 1 +33.76 1 +33.83 1 +33.85 2 +33.87 1 +34.03 1 +34.21 1 +34.35 1 +34.41 2 +34.58 1 +34.73 1 +34.97 1 +35.0 2 +35.08 1 +35.13 1 +35.17 1 +35.17 1 +35.56 1 +35.62 1 +35.65 1 +35.68 1 +35.72 1 +35.8 1 +35.89 1 +36.22 1 +36.26 1 +36.58 1 +36.7 1 +36.79 1 +36.89 1 +36.95 1 +37.07 2 +37.1 1 +37.14 1 +37.14 1 +37.24 1 +37.59 1 +37.6 1 +37.72 2 +37.78 1 +37.8 1 +37.85 1 +37.9 1 +38.05 1 +38.05 1 +38.3 2 +38.33 1 +38.57 1 +38.62 1 +38.79 1 +38.85 1 +38.88 1 +38.94 1 +39.01 1 +39.03 1 +39.18 1 +39.21 1 +39.34 1 +39.69 1 +39.81 1 +39.82 1 +39.83 1 +39.87 1 +39.9 1 +39.98 1 +40.0 1 +40.04 1 +40.17 1 +40.24 1 +40.42 1 +40.44 1 +40.78 1 +40.8 1 +40.98 1 +41.2 2 +41.29 1 +41.29 1 +41.31 2 +41.34 1 +41.34 1 +41.36 1 +41.44 1 +41.45 2 +41.62 1 +41.68 1 +41.71 1 +41.81 1 +41.85 1 +41.87 1 +41.89 2 +42.24 1 +42.31 1 +42.42 2 +42.48 1 +42.51 1 +42.55 1 +42.56 1 +42.67 2 +42.76 1 +42.85 1 +43.01 1 +43.02 1 +43.13 1 +43.16 1 +43.17 1 +43.19 1 +43.31 1 +43.37 1 +43.57 1 +43.71 1 +43.73 1 +43.92 1 +44.1 1 +44.22 1 +44.27 1 +44.43 1 +44.57 1 +45.06 2 +45.1 1 +45.19 1 +45.19 1 +45.24 1 +45.34 1 +45.35 1 +45.42 1 +45.45 1 +45.56 1 +45.59 1 +45.68 1 +45.92 1 +45.99 2 +46.02 1 +46.09 1 +46.1 1 +46.15 1 +46.18 1 +46.21 1 +46.27 1 +46.43 1 +46.45 1 +46.62 1 +46.8 1 +46.86 1 +46.87 1 +46.88 1 +46.97 2 +47.08 2 +47.27 1 +47.32 2 +47.57 1 +47.59 1 +47.69 1 +47.88 1 +48.01 1 +48.08 1 +48.11 1 +48.15 1 +48.22 1 +48.23 1 +48.25 1 +48.28 1 +48.37 1 +48.45 1 +48.45 1 +48.52 1 +48.59 1 +49.12 1 +49.28 2 +49.44 1 +49.68 1 +49.77 2 +49.78 1 +50.02 3 +50.08 1 +50.09 1 +50.26 1 +50.28 1 +50.31 1 +50.32 1 +50.4 1 +50.41 1 +50.66 1 +50.7 1 +50.83 2 +50.92 2 +50.96 1 +51.25 1 +51.29 1 +51.29 2 +51.72 1 +51.79 2 +51.84 1 +51.85 1 +52.17 1 +52.23 1 +52.44 1 +52.5 1 +52.53 2 +52.72 1 +52.73 1 +52.85 2 +52.87 1 +53.02 1 +53.06 1 +53.18 2 +53.27 1 +53.59 2 +53.78 1 +53.93 1 +53.94 1 +54.1 2 +54.31 1 +54.34 1 +54.43 1 +54.44 1 +54.47 1 +54.73 1 +54.75 1 +54.83 1 +54.99 1 +55.1 1 +55.18 1 +55.2 1 +55.39 1 +55.51 1 +55.63 1 +55.99 1 +56.04 1 +56.07 1 +56.1 1 +56.15 1 +56.33 1 +56.62 1 +56.68 2 +56.81 1 +57.08 1 +57.11 1 +57.12 2 +57.23 1 +57.25 1 +57.29 1 +57.35 1 +57.37 1 +57.46 1 +57.64 1 +57.67 1 +57.89 2 +57.93 1 +58.0 1 +58.08 2 +58.09 2 +58.13 1 +58.43 1 +58.52 1 +58.66 1 +58.67 1 +58.75 1 +58.86 2 +59.07 1 +59.16 1 +59.21 1 +59.34 1 +59.43 1 +59.45 1 +59.45 2 +59.5 1 +59.55 1 +59.61 1 +59.62 1 +59.68 1 +59.68 1 +59.7 1 +59.71 1 +59.83 1 +59.87 1 +59.99 1 +60.02 1 +60.06 1 +60.12 1 +60.13 1 +60.22 2 +60.26 1 +60.26 1 +60.53 1 +60.6 2 +60.71 1 +60.85 1 +61.21 1 +61.7 1 +61.86 1 +61.88 1 +61.92 2 +61.94 1 +62.14 1 +62.2 1 +62.23 1 +62.3 1 +62.39 1 +62.52 1 +62.72 1 +62.74 1 +62.85 2 +62.9 3 +62.92 1 +63.12 1 +63.33 1 +63.35 2 +63.42 1 +63.51 1 +63.9 1 +64.0 1 +64.22 1 +64.25 1 +64.3 1 +64.36 2 +64.46 1 +64.65 2 +64.67 1 +64.77 1 +64.87 1 +64.95 1 +65.02 1 +65.02 1 +65.38 1 +65.43 1 +65.43 1 +65.44 2 +65.55 1 +65.62 1 +65.7 1 +65.72 1 +66.17 1 +66.17 2 +66.36 1 +66.51 1 +66.61 1 +66.61 1 +66.67 1 +66.89 1 +67.12 1 +67.18 1 +67.26 1 +67.38 1 +67.45 1 +67.48 1 +67.59 1 +67.94 1 +67.98 1 +68.01 2 +68.04 1 +68.22 1 +68.25 1 +68.25 1 +68.32 1 +68.41 1 +68.5 1 +68.81 1 +68.85 2 +68.89 1 +68.95 1 +68.96 1 +69.32 2 +69.53 1 +69.74 3 +69.8 2 +69.88 1 +69.96 1 +69.97 1 +70.0 2 +70.04 1 +70.06 1 +70.24 1 +70.35 1 +70.38 1 +70.39 1 +70.52 1 +70.53 1 +70.56 1 +70.85 1 +70.89 1 +70.93 1 +71.01 1 +71.07 2 +71.13 1 +71.19 1 +71.26 1 +71.31 1 +71.32 1 +71.35 1 +71.5 1 +71.54 1 +71.55 3 +71.68 1 +71.68 2 +71.78 2 +71.8 1 +71.89 2 +72.04 1 +72.18 1 +72.51 1 +72.53 2 +72.56 1 +72.62 1 +72.79 1 +72.98 1 +73.18 1 +73.32 1 +73.48 2 +73.63 2 +73.65 1 +73.68 1 +73.88 1 +73.93 1 +74.0 1 +74.02 1 +74.15 1 +74.19 1 +74.19 1 +74.3 1 +74.42 1 +74.45 1 +74.52 1 +74.53 1 +74.59 1 +74.62 1 +74.72 1 +74.78 1 +75.03 2 +75.1 3 +75.19 1 +75.29 1 +75.35 2 +75.42 2 +75.66 2 +75.73 2 +75.83 1 +75.88 1 +76.05 3 +76.1 1 +76.28 1 +76.28 1 +76.33 1 +76.52 1 +76.69 1 +76.7 1 +76.71 1 +76.72 2 +76.72 2 +76.74 1 +76.92 1 +76.93 1 +77.02 3 +77.1 1 +77.36 2 +77.42 1 +77.57 2 +77.66 1 +77.81 1 +77.84 2 +77.89 1 +77.97 1 +78.21 2 +78.26 1 +78.28 1 +78.3 2 +78.31 1 +78.62 1 +78.64 1 +78.73 1 +78.89 1 +78.98 2 +79.12 1 +79.19 1 +79.21 1 +79.38 1 +79.42 1 +79.48 1 +79.48 2 +79.49 2 +79.54 1 +79.55 2 +79.75 1 +79.83 1 +79.96 1 +79.97 2 +79.99 1 +80.23 2 +80.3 1 +80.3 2 +80.46 1 +80.52 1 +80.58 1 +80.6 1 +80.71 1 +80.74 1 +80.84 2 +80.92 1 +80.96 1 +80.97 1 +80.99 1 +81.17 1 +81.32 1 +81.32 1 +81.47 1 +81.58 1 +81.64 1 +81.66 1 +82.24 1 +82.3 1 +82.34 2 +82.41 1 +82.52 1 +82.55 1 +82.56 1 +82.72 1 +82.97 1 +83.08 1 +83.27 1 +83.33 1 +83.4 1 +83.54 2 +83.57 1 +83.58 1 +83.87 1 +83.92 1 +83.93 1 +84.03 1 +84.23 2 +84.31 1 +84.38 1 +84.4 1 +84.69 1 +84.72 1 +84.83 1 +85.0 1 +85.03 2 +85.1 1 +85.14 1 +85.23 1 +85.49 1 +85.49 2 +85.51 2 +85.74 1 +85.76 1 +85.87 1 +85.9 1 +86.0 1 +86.22 1 +86.23 1 +86.63 1 +86.69 1 +86.92 2 +86.93 1 +86.93 1 +87.14 2 +87.22 1 +87.4 1 +87.48 1 +87.57 1 +87.61 1 +87.67 1 +87.83 2 +87.94 1 +87.99 1 +88.02 1 +88.05 1 +88.07 2 +88.17 1 +88.22 1 +88.36 1 +88.47 1 +88.48 1 +88.55 1 +88.77 1 +88.78 1 +88.8 1 +88.91 2 +89.01 2 +89.03 1 +89.1 3 +89.15 2 +89.28 1 +89.38 1 +89.53 1 +89.55 1 +89.55 1 +89.55 1 +89.8 1 +89.81 1 +89.93 1 +90.05 1 +90.05 1 +90.07 1 +90.12 1 +90.2 1 +90.25 1 +90.28 2 +90.35 1 +90.38 1 +90.51 1 +90.56 2 +90.69 1 +90.69 1 +90.73 1 +90.77 1 +91.05 1 +91.16 1 +91.42 1 +91.48 1 +91.53 1 +91.61 1 +91.63 1 +91.78 1 +91.88 1 +91.97 1 +92.05 2 +92.11 2 +92.33 2 +92.37 1 +92.4 2 +92.55 1 +92.61 1 +92.82 1 +92.96 2 +92.98 1 +93.03 1 +93.09 1 +93.11 1 +93.61 1 +93.64 1 +93.73 1 +94.08 1 +94.15 1 +94.25 1 +94.27 1 +94.31 1 +94.33 1 +94.34 1 +94.38 1 +94.43 1 +94.54 1 +94.66 1 +94.68 1 +94.68 1 +94.72 1 +95.07 1 +95.11 1 +95.28 1 +95.33 1 +95.34 1 +95.38 2 +95.48 1 +95.53 1 +95.53 1 +95.81 1 +95.81 2 +95.84 1 +96.09 1 +96.23 1 +96.25 1 +96.29 1 +96.38 1 +96.62 1 +96.68 1 +96.73 1 +96.78 1 +96.91 2 +96.94 1 +97.09 1 +97.24 1 +97.26 1 +97.39 1 +97.46 1 +97.51 2 +97.56 1 +97.57 1 +97.65 2 +97.68 1 +97.71 4 +97.81 1 +97.83 1 +97.87 1 +98.18 2 +98.22 1 +98.23 1 +98.31 1 +98.48 1 +98.51 1 +98.57 1 +98.72 1 +98.96 1 +99.13 1 +99.15 1 +99.21 1 +99.24 1 +99.29 1 +99.36 1 +99.62 1 +99.65 1 +99.67 1 +99.68 1 +99.91 1 +99.92 1 diff --git a/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-3-bda0e7c77d6f4712a03389cb5032bc6d b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-3-bda0e7c77d6f4712a03389cb5032bc6d new file mode 100644 index 000000000000..a9ec53c0cb21 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_ntile.q (deterministic)-3-bda0e7c77d6f4712a03389cb5032bc6d @@ -0,0 +1,1049 @@ +0.02 1 +0.21 1 +0.27 1 +0.37 1 +0.37 1 +0.47 1 +0.48 1 +0.52 1 +0.6 1 +0.63 1 +0.63 1 +0.66 1 +0.73 1 +0.74 1 +0.74 1 +0.78 1 +0.8 1 +0.86 1 +0.86 1 +0.98 1 +1.17 1 +1.29 1 +1.3 1 +1.31 1 +1.37 1 +1.42 1 +1.45 1 +1.5 1 +1.53 1 +1.61 1 +1.62 1 +1.69 1 +1.71 1 +1.79 1 +1.98 1 +2.07 1 +2.09 1 +2.13 1 +2.16 1 +2.26 1 +2.29 1 +2.34 1 +2.34 1 +2.36 1 +2.43 1 +2.43 1 +2.52 1 +2.53 1 +2.53 1 +2.61 1 +2.63 1 +2.74 1 +2.75 1 +2.75 1 +2.79 1 +2.8 1 +2.82 1 +2.89 1 +2.89 1 +2.89 1 +2.9 1 +2.92 1 +3.03 1 +3.07 1 +3.07 1 +3.12 1 +3.15 1 +3.17 1 +3.27 1 +3.27 1 +3.29 1 +3.4 1 +3.4 1 +3.4 1 +3.4 1 +3.42 1 +3.66 1 +3.67 1 +3.69 1 +3.71 1 +3.78 1 +3.83 1 +3.86 1 +3.92 1 +3.98 1 +3.99 1 +4.04 1 +4.16 1 +4.25 1 +4.27 1 +4.44 1 +4.48 1 +4.53 1 +4.57 1 +4.58 1 +4.62 1 +4.81 1 +4.82 1 +4.83 1 +4.92 1 +4.95 1 +4.96 1 +4.97 1 +4.98 1 +5.09 1 +5.11 1 +5.19 1 +5.23 1 +5.3 1 +5.31 1 +5.31 1 +5.35 1 +5.42 1 +5.51 1 +5.55 1 +5.58 1 +5.74 1 +5.82 1 +5.84 1 +5.93 1 +5.93 1 +5.96 1 +6.06 1 +6.06 1 +6.21 1 +6.28 1 +6.46 1 +6.52 2 +6.54 1 +6.56 1 +6.57 1 +6.58 1 +6.61 1 +6.62 1 +6.76 1 +6.81 1 +6.81 1 +6.96 1 +6.98 1 +7.02 1 +7.03 1 +7.14 1 +7.18 1 +7.24 1 +7.24 1 +7.31 1 +7.36 1 +7.37 1 +7.45 1 +7.53 1 +7.62 1 +7.66 1 +7.71 1 +7.71 1 +7.8 1 +7.92 1 +8.05 1 +8.09 1 +8.21 1 +8.33 1 +8.33 1 +8.49 1 +8.49 1 +8.52 1 +8.56 1 +8.61 1 +8.62 2 +8.72 1 +8.76 1 +8.79 1 +8.82 1 +8.84 1 +8.95 1 +8.98 1 +9.14 1 +9.19 1 +9.21 1 +9.22 1 +9.26 1 +9.27 1 +9.39 2 +9.4 1 +9.42 1 +9.5 1 +9.51 1 +9.56 1 +9.6 1 +9.61 1 +9.62 1 +9.64 1 +9.81 1 +9.87 1 +9.88 1 +9.93 1 +9.94 1 +9.96 1 +9.99 1 +10.15 1 +10.21 1 +10.22 1 +10.23 1 +10.24 1 +10.36 1 +10.38 1 +10.38 1 +10.41 1 +10.47 1 +10.49 1 +10.49 1 +10.51 1 +10.52 1 +10.7 1 +10.71 1 +10.85 1 +10.99 1 +11.02 1 +11.12 1 +11.12 1 +11.16 1 +11.2 1 +11.26 1 +11.27 1 +11.35 1 +11.35 1 +11.4 1 +11.43 1 +11.44 1 +11.44 1 +11.46 1 +11.48 1 +11.5 1 +11.54 1 +11.63 1 +11.66 1 +11.69 1 +11.83 1 +11.9 1 +11.91 1 +11.96 1 +12.02 1 +12.13 1 +12.14 1 +12.15 1 +12.3 1 +12.3 1 +12.3 2 +12.34 1 +12.35 1 +12.43 1 +12.43 1 +12.64 1 +12.66 1 +12.7 1 +12.72 1 +12.73 1 +12.74 2 +12.82 1 +12.85 1 +13.02 1 +13.04 1 +13.08 1 +13.14 1 +13.2 1 +13.2 1 +13.22 1 +13.23 1 +13.3 1 +13.3 1 +13.44 1 +13.44 1 +13.44 1 +13.49 1 +13.6 1 +13.66 1 +13.71 1 +13.72 1 +13.8 1 +13.83 1 +13.84 1 +13.88 1 +13.95 1 +14.07 1 +14.16 1 +14.17 1 +14.22 1 +14.24 1 +14.26 1 +14.29 1 +14.33 1 +14.39 1 +14.44 1 +14.51 1 +14.51 1 +14.52 1 +14.62 1 +14.69 1 +14.72 1 +14.75 1 +14.83 1 +14.83 1 +14.84 1 +14.9 1 +14.91 1 +14.92 1 +14.99 1 +15.0 1 +15.01 1 +15.09 1 +15.09 1 +15.09 1 +15.1 1 +15.12 1 +15.13 1 +15.16 1 +15.18 1 +15.22 1 +15.27 1 +15.28 1 +15.32 1 +15.38 1 +15.46 1 +15.46 1 +15.51 1 +15.54 1 +15.87 1 +15.94 1 +15.97 1 +15.98 1 +16.04 2 +16.1 1 +16.12 1 +16.13 1 +16.15 1 +16.29 1 +16.35 1 +16.36 1 +16.38 1 +16.4 1 +16.42 1 +16.47 1 +16.49 1 +16.54 1 +16.61 1 +16.66 1 +16.79 1 +16.79 1 +16.82 1 +16.87 1 +16.87 1 +16.9 1 +16.9 1 +16.91 1 +16.92 1 +17.03 1 +17.03 2 +17.08 1 +17.15 1 +17.19 1 +17.29 1 +17.33 1 +17.44 1 +17.46 1 +17.47 1 +17.51 1 +17.52 2 +17.55 1 +17.59 1 +17.63 1 +17.69 1 +17.76 1 +17.86 1 +17.89 1 +17.99 1 +18.09 1 +18.19 1 +18.2 1 +18.28 1 +18.29 1 +18.31 1 +18.34 1 +18.35 1 +18.36 1 +18.38 1 +18.38 1 +18.41 1 +18.47 1 +18.48 1 +18.79 1 +18.82 1 +18.83 1 +18.86 1 +18.86 1 +19.03 1 +19.12 1 +19.15 1 +19.2 1 +19.31 1 +19.32 1 +19.41 1 +19.47 1 +19.47 2 +19.56 1 +19.59 1 +19.63 1 +19.65 1 +19.72 1 +19.72 1 +19.79 1 +19.79 1 +19.85 1 +19.87 1 +19.9 1 +19.92 1 +19.93 1 +19.98 1 +20.02 1 +20.02 1 +20.17 1 +20.19 1 +20.22 1 +20.3 1 +20.3 1 +20.34 1 +20.39 1 +20.42 1 +20.42 1 +20.44 1 +20.55 1 +20.55 1 +20.56 1 +20.56 1 +20.58 1 +20.58 1 +20.64 1 +20.65 1 +20.75 1 +20.76 1 +20.76 1 +20.8 1 +20.82 1 +20.91 1 +20.93 1 +20.94 1 +20.94 1 +20.97 1 +21.0 1 +21.01 1 +21.01 1 +21.02 1 +21.02 1 +21.11 1 +21.11 1 +21.14 1 +21.16 1 +21.26 1 +21.27 1 +21.3 1 +21.3 1 +21.33 1 +21.33 1 +21.38 1 +21.42 1 +21.52 1 +21.53 1 +21.57 1 +21.66 1 +21.67 1 +21.69 1 +21.77 1 +21.81 1 +21.86 1 +21.91 1 +21.95 1 +22.15 1 +22.19 1 +22.19 1 +22.22 1 +22.27 1 +22.27 1 +22.33 1 +22.48 1 +22.6 1 +22.61 1 +22.64 1 +22.68 1 +22.73 1 +22.75 1 +22.94 1 +22.95 1 +23.03 1 +23.07 1 +23.15 1 +23.15 1 +23.18 1 +23.18 1 +23.25 1 +23.27 1 +23.3 1 +23.31 1 +23.45 1 +23.48 1 +23.53 1 +23.55 1 +23.59 1 +23.61 1 +23.63 1 +23.73 1 +23.77 1 +23.78 2 +23.88 1 +23.91 1 +24.03 1 +24.03 1 +24.13 1 +24.17 1 +24.18 1 +24.25 1 +24.35 1 +24.35 2 +24.42 1 +24.46 1 +24.53 2 +24.59 1 +24.61 1 +24.61 1 +24.84 1 +24.85 1 +24.86 1 +24.95 1 +25.01 1 +25.02 1 +25.03 1 +25.08 1 +25.11 1 +25.17 1 +25.32 2 +25.36 1 +25.36 1 +25.43 1 +25.49 1 +25.51 1 +25.51 1 +25.58 1 +25.59 1 +25.63 1 +25.71 1 +25.75 1 +25.8 1 +25.92 1 +25.92 1 +25.95 1 +25.97 2 +26.0 1 +26.17 2 +26.21 1 +26.22 1 +26.24 1 +26.28 1 +26.44 1 +26.55 1 +26.55 1 +26.65 1 +26.67 1 +26.71 1 +26.71 1 +26.73 1 +26.74 1 +26.79 1 +26.84 1 +26.87 1 +27.02 1 +27.14 1 +27.2 1 +27.27 1 +27.29 1 +27.36 1 +27.39 1 +27.4 1 +27.42 1 +27.46 1 +27.54 1 +27.54 1 +27.61 1 +27.62 1 +27.89 1 +28.02 1 +28.1 1 +28.13 1 +28.14 1 +28.15 1 +28.17 1 +28.19 1 +28.29 1 +28.36 1 +28.4 1 +28.42 2 +28.44 1 +28.52 1 +28.52 1 +28.61 1 +28.64 1 +28.68 1 +28.69 1 +28.69 1 +28.71 1 +28.71 2 +28.77 1 +28.77 1 +28.85 1 +28.86 1 +28.91 1 +28.96 1 +28.96 1 +28.98 1 +29.0 1 +29.11 1 +29.19 1 +29.22 1 +29.24 1 +29.25 1 +29.36 1 +29.41 1 +29.46 1 +29.49 1 +29.52 2 +29.62 1 +29.63 1 +29.66 1 +29.73 1 +29.76 1 +29.78 1 +29.88 1 +29.96 1 +30.04 1 +30.04 1 +30.09 1 +30.12 1 +30.13 1 +30.16 1 +30.17 1 +30.21 1 +30.22 1 +30.28 1 +30.35 1 +30.37 2 +30.41 1 +30.41 1 +30.49 1 +30.55 1 +30.55 1 +30.58 1 +30.61 1 +30.62 1 +30.66 1 +30.67 1 +30.78 1 +30.78 1 +30.87 1 +30.89 1 +30.9 1 +30.92 1 +30.98 2 +30.99 1 +31.01 1 +31.16 1 +31.23 1 +31.3 1 +31.33 1 +31.36 1 +31.45 1 +31.46 1 +31.5 1 +31.61 1 +31.63 1 +31.64 1 +31.66 1 +31.68 1 +31.74 1 +31.75 1 +31.76 1 +31.84 1 +32.04 1 +32.06 1 +32.13 1 +32.17 1 +32.2 1 +32.25 1 +32.25 1 +32.31 1 +32.33 1 +32.39 1 +32.56 1 +32.56 1 +32.61 1 +32.74 1 +32.85 1 +32.89 1 +32.98 1 +33.0 1 +33.01 1 +33.02 1 +33.02 1 +33.02 1 +33.11 1 +33.12 1 +33.18 1 +33.18 1 +33.19 1 +33.24 1 +33.3 1 +33.36 1 +33.36 1 +33.38 1 +33.49 1 +33.52 2 +33.6 2 +33.64 1 +33.64 1 +33.66 1 +33.67 1 +33.72 1 +33.76 1 +33.9 1 +34.05 1 +34.06 1 +34.11 1 +34.14 1 +34.15 1 +34.17 1 +34.2 1 +34.41 2 +34.48 2 +34.49 1 +34.52 1 +34.53 1 +34.54 1 +34.62 1 +34.68 1 +34.68 1 +34.72 1 +34.81 1 +34.83 1 +34.84 1 +34.9 1 +34.95 1 +34.95 1 +34.97 1 +34.97 1 +34.98 1 +35.01 1 +35.02 1 +35.1 1 +35.15 1 +35.16 1 +35.23 1 +35.24 1 +35.36 1 +35.36 1 +35.49 1 +35.62 1 +35.68 1 +35.72 1 +35.84 1 +35.85 1 +35.9 1 +36.05 2 +36.09 1 +36.11 1 +36.12 1 +36.13 1 +36.22 1 +36.56 1 +36.57 1 +36.57 1 +36.62 1 +36.7 1 +36.72 1 +36.73 1 +36.84 1 +36.86 1 +36.93 1 +36.93 1 +37.02 1 +37.08 1 +37.12 1 +37.23 2 +37.32 1 +37.34 1 +37.37 1 +37.76 1 +37.77 1 +37.8 1 +37.93 1 +37.94 1 +37.96 1 +38.0 1 +38.04 1 +38.04 1 +38.04 1 +38.05 1 +38.07 1 +38.14 1 +38.22 1 +38.28 1 +38.37 2 +38.39 1 +38.43 1 +38.53 1 +38.6 1 +38.62 1 +38.66 1 +38.67 1 +38.73 1 +38.74 1 +38.92 1 +38.94 1 +39.01 1 +39.04 1 +39.05 1 +39.11 1 +39.27 1 +39.29 1 +39.41 1 +39.46 1 +39.49 1 +39.55 1 +39.57 1 +39.6 1 +39.63 1 +39.73 1 +39.74 1 +39.78 2 +39.8 1 +39.84 1 +39.84 1 +39.85 1 +39.92 1 +40.01 1 +40.04 1 +40.15 1 +40.21 2 +40.22 1 +40.24 1 +40.26 1 +40.27 1 +40.39 2 +40.43 1 +40.46 1 +40.5 1 +40.59 1 +40.6 1 +40.63 1 +40.76 1 +40.79 1 +40.84 1 +40.89 1 +40.91 1 +40.94 1 +40.96 1 +41.02 1 +41.08 1 +41.2 1 +41.24 1 +41.33 1 +41.34 1 +41.34 2 +41.36 1 +41.36 1 +41.37 1 +41.54 1 +41.56 1 +41.69 1 +41.73 1 +41.75 1 +41.75 1 +41.83 1 +41.86 1 +41.9 1 +42.0 2 +42.02 1 +42.03 1 +42.04 1 +42.22 1 +42.3 1 +42.37 1 +42.41 2 +42.47 1 +42.55 1 +42.56 1 +42.57 1 +42.76 1 +42.89 1 +42.96 1 +42.96 1 +43.0 1 +43.04 1 +43.04 3 +43.18 1 +43.23 1 +43.3 1 +43.34 1 +43.34 2 +43.4 1 +43.42 1 +43.58 1 +43.64 1 +43.67 1 +43.76 2 +43.84 1 +43.85 1 +43.92 1 +43.95 1 +43.95 1 +43.96 1 +43.96 1 +44.04 1 +44.11 1 +44.12 1 +44.12 1 +44.22 3 +44.24 1 +44.27 1 +44.32 1 +44.36 1 +44.4 1 +44.57 1 +44.6 1 +44.63 1 +44.66 1 +44.73 1 +44.75 1 +44.8 1 +44.83 1 +44.9 1 +44.92 1 +44.93 1 +45.0 1 +45.02 1 +45.06 1 +45.06 1 +45.09 1 +45.1 1 +45.1 1 +45.11 1 +45.12 2 +45.14 1 +45.14 2 +45.28 2 +45.29 1 +45.46 1 +45.49 1 +45.53 1 +45.53 2 +45.54 1 +45.69 1 +45.71 1 +45.78 1 +45.81 2 +45.86 1 +45.9 1 +45.94 1 +46.03 1 +46.03 1 +46.09 1 +46.18 1 +46.19 1 +46.28 1 +46.3 1 +46.31 1 +46.33 1 +46.36 1 +46.39 1 +46.52 1 +46.53 1 +46.54 1 +46.57 1 +46.59 1 +46.67 1 +46.69 1 +46.73 1 +46.73 1 +46.73 2 +46.74 1 +46.81 1 +46.87 1 +46.88 1 +46.9 1 +46.93 1 +46.98 1 +47.0 1 +47.03 1 +47.03 1 +47.06 1 +47.15 1 +47.22 1 +47.3 1 +47.31 2 +47.37 2 +47.4 1 +47.46 1 +47.49 1 +47.55 1 +47.6 1 +47.66 1 +47.68 1 +47.71 1 +47.72 1 +47.82 1 +47.86 2 +47.91 1 +47.91 2 +47.95 1 +47.98 1 +48.0 1 +48.08 1 +48.5 1 +48.52 1 +48.71 1 +48.78 1 +48.8 1 +48.85 1 +48.89 1 +48.96 1 +48.98 1 +49.04 1 +49.05 1 +49.16 1 +49.21 1 +49.32 1 +49.34 1 +49.34 1 +49.38 1 +49.44 2 +49.45 1 +49.45 1 +49.46 1 +49.46 1 +49.52 1 +49.56 1 +49.59 1 +49.63 1 +49.67 1 +49.69 1 +49.71 2 +49.72 1 +49.73 1 +49.79 1 +49.84 1 +49.85 2 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-0-2e0cbc2d7c5f16657edacd9e7209e6e7 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-0-2e0cbc2d7c5f16657edacd9e7209e6e7 new file mode 100644 index 000000000000..612bdf44c0cd --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-0-2e0cbc2d7c5f16657edacd9e7209e6e7 @@ -0,0 +1,1049 @@ + 1 + 1 + 1 +alice allen 1 +alice allen 1 +alice allen 1 +alice brown 1 +alice carson 1 +alice davidson 1 +alice falkner 1 +alice garcia 2 +alice hernandez 1 +alice hernandez 1 +alice johnson 1 +alice king 1 +alice king 1 +alice king 1 +alice laertes 1 +alice laertes 1 +alice miller 1 +alice nixon 1 +alice nixon 1 +alice nixon 1 +alice ovid 1 +alice polk 1 +alice quirinius 1 +alice quirinius 1 +alice robinson 1 +alice robinson 1 +alice steinbeck 1 +alice steinbeck 1 +alice steinbeck 1 +alice underhill 1 +alice van buren 1 +alice xylophone 1 +alice xylophone 1 +alice xylophone 1 +alice zipper 1 +alice zipper 1 +alice zipper 1 +bob brown 1 +bob brown 1 +bob brown 2 +bob carson 1 +bob davidson 1 +bob davidson 1 +bob davidson 1 +bob ellison 1 +bob ellison 1 +bob ellison 1 +bob ellison 2 +bob falkner 1 +bob garcia 1 +bob garcia 1 +bob garcia 1 +bob garcia 1 +bob garcia 2 +bob hernandez 1 +bob ichabod 1 +bob king 1 +bob king 1 +bob king 1 +bob laertes 1 +bob laertes 1 +bob miller 1 +bob ovid 1 +bob ovid 1 +bob ovid 1 +bob ovid 1 +bob polk 1 +bob quirinius 1 +bob steinbeck 1 +bob van buren 1 +bob white 1 +bob white 1 +bob xylophone 1 +bob xylophone 1 +bob young 1 +bob zipper 1 +bob zipper 1 +bob zipper 1 +calvin allen 1 +calvin brown 1 +calvin brown 1 +calvin brown 1 +calvin carson 1 +calvin davidson 1 +calvin davidson 1 +calvin ellison 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin falkner 1 +calvin garcia 1 +calvin hernandez 1 +calvin johnson 2 +calvin laertes 1 +calvin laertes 1 +calvin nixon 1 +calvin nixon 1 +calvin nixon 1 +calvin ovid 1 +calvin ovid 1 +calvin ovid 1 +calvin ovid 1 +calvin polk 1 +calvin quirinius 1 +calvin quirinius 1 +calvin robinson 1 +calvin steinbeck 1 +calvin steinbeck 1 +calvin steinbeck 1 +calvin thompson 1 +calvin thompson 1 +calvin underhill 1 +calvin van buren 1 +calvin van buren 1 +calvin white 1 +calvin white 1 +calvin xylophone 1 +calvin xylophone 1 +calvin xylophone 1 +calvin young 1 +calvin young 1 +calvin zipper 1 +calvin zipper 1 +david allen 1 +david allen 1 +david brown 1 +david brown 1 +david davidson 1 +david davidson 1 +david davidson 1 +david davidson 2 +david ellison 1 +david ellison 1 +david ellison 1 +david hernandez 1 +david ichabod 1 +david ichabod 1 +david laertes 1 +david nixon 1 +david ovid 1 +david ovid 1 +david quirinius 1 +david quirinius 1 +david quirinius 1 +david robinson 1 +david robinson 1 +david thompson 1 +david underhill 1 +david underhill 1 +david underhill 1 +david van buren 1 +david van buren 1 +david white 2 +david xylophone 1 +david xylophone 1 +david xylophone 1 +david young 1 +david young 1 +ethan allen 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan carson 1 +ethan ellison 1 +ethan ellison 1 +ethan falkner 1 +ethan falkner 1 +ethan garcia 1 +ethan hernandez 1 +ethan johnson 1 +ethan king 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan laertes 1 +ethan miller 1 +ethan nixon 1 +ethan ovid 1 +ethan polk 1 +ethan polk 1 +ethan polk 1 +ethan polk 1 +ethan quirinius 1 +ethan quirinius 1 +ethan quirinius 1 +ethan robinson 1 +ethan robinson 1 +ethan underhill 1 +ethan van buren 1 +ethan white 1 +ethan white 1 +ethan xylophone 1 +ethan zipper 1 +ethan zipper 1 +fred davidson 1 +fred davidson 1 +fred davidson 1 +fred ellison 1 +fred ellison 1 +fred ellison 1 +fred falkner 1 +fred falkner 1 +fred falkner 1 +fred hernandez 1 +fred ichabod 1 +fred ichabod 2 +fred johnson 1 +fred king 1 +fred king 1 +fred laertes 1 +fred miller 1 +fred nixon 1 +fred nixon 1 +fred nixon 1 +fred nixon 2 +fred polk 1 +fred polk 1 +fred polk 1 +fred polk 1 +fred quirinius 1 +fred quirinius 1 +fred robinson 1 +fred steinbeck 1 +fred steinbeck 1 +fred steinbeck 1 +fred underhill 1 +fred van buren 1 +fred van buren 1 +fred van buren 1 +fred van buren 1 +fred white 1 +fred young 1 +fred young 1 +fred zipper 1 +gabriella allen 1 +gabriella allen 1 +gabriella brown 1 +gabriella brown 1 +gabriella carson 1 +gabriella davidson 1 +gabriella ellison 1 +gabriella ellison 1 +gabriella falkner 1 +gabriella falkner 1 +gabriella falkner 1 +gabriella garcia 1 +gabriella hernandez 1 +gabriella hernandez 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella ichabod 1 +gabriella king 1 +gabriella king 1 +gabriella laertes 1 +gabriella miller 1 +gabriella ovid 1 +gabriella ovid 1 +gabriella polk 1 +gabriella polk 1 +gabriella steinbeck 1 +gabriella steinbeck 1 +gabriella thompson 1 +gabriella thompson 1 +gabriella thompson 1 +gabriella van buren 1 +gabriella van buren 1 +gabriella white 1 +gabriella young 1 +gabriella young 1 +gabriella zipper 1 +gabriella zipper 1 +holly allen 1 +holly brown 1 +holly brown 1 +holly falkner 1 +holly hernandez 1 +holly hernandez 1 +holly hernandez 1 +holly hernandez 2 +holly ichabod 1 +holly ichabod 1 +holly ichabod 1 +holly johnson 1 +holly johnson 1 +holly johnson 1 +holly king 1 +holly king 1 +holly laertes 1 +holly miller 1 +holly nixon 1 +holly nixon 1 +holly polk 1 +holly polk 1 +holly robinson 1 +holly thompson 1 +holly thompson 1 +holly thompson 1 +holly underhill 1 +holly underhill 1 +holly underhill 1 +holly underhill 1 +holly van buren 1 +holly white 1 +holly white 2 +holly xylophone 1 +holly young 1 +holly young 1 +holly zipper 1 +holly zipper 1 +irene allen 1 +irene brown 1 +irene brown 1 +irene brown 1 +irene carson 1 +irene ellison 1 +irene ellison 1 +irene falkner 1 +irene falkner 1 +irene garcia 1 +irene garcia 1 +irene garcia 1 +irene ichabod 1 +irene ichabod 1 +irene johnson 1 +irene laertes 1 +irene laertes 1 +irene laertes 1 +irene miller 1 +irene nixon 1 +irene nixon 1 +irene nixon 1 +irene ovid 1 +irene ovid 1 +irene ovid 1 +irene polk 1 +irene polk 1 +irene polk 1 +irene polk 1 +irene polk 1 +irene quirinius 1 +irene quirinius 1 +irene quirinius 1 +irene robinson 1 +irene steinbeck 1 +irene thompson 1 +irene underhill 1 +irene underhill 1 +irene van buren 1 +irene van buren 1 +irene xylophone 2 +jessica brown 2 +jessica carson 1 +jessica carson 1 +jessica carson 1 +jessica davidson 1 +jessica davidson 1 +jessica davidson 1 +jessica davidson 1 +jessica ellison 1 +jessica ellison 1 +jessica falkner 1 +jessica garcia 1 +jessica garcia 1 +jessica ichabod 1 +jessica johnson 1 +jessica johnson 1 +jessica miller 1 +jessica nixon 1 +jessica nixon 1 +jessica ovid 1 +jessica ovid 2 +jessica polk 1 +jessica quirinius 1 +jessica quirinius 1 +jessica quirinius 1 +jessica quirinius 1 +jessica robinson 1 +jessica thompson 1 +jessica thompson 3 +jessica underhill 1 +jessica underhill 1 +jessica underhill 1 +jessica van buren 1 +jessica white 1 +jessica white 1 +jessica white 1 +jessica white 1 +jessica white 1 +jessica xylophone 1 +jessica young 1 +jessica young 1 +jessica zipper 1 +jessica zipper 1 +jessica zipper 1 +katie allen 1 +katie brown 1 +katie davidson 1 +katie ellison 1 +katie ellison 1 +katie falkner 1 +katie garcia 1 +katie garcia 1 +katie hernandez 1 +katie ichabod 1 +katie ichabod 1 +katie ichabod 1 +katie king 1 +katie king 1 +katie king 1 +katie miller 1 +katie miller 1 +katie nixon 1 +katie ovid 1 +katie polk 1 +katie polk 1 +katie robinson 1 +katie van buren 1 +katie van buren 1 +katie white 1 +katie white 1 +katie xylophone 1 +katie young 1 +katie young 1 +katie young 1 +katie zipper 1 +katie zipper 1 +luke allen 1 +luke allen 1 +luke allen 1 +luke allen 1 +luke allen 2 +luke brown 1 +luke davidson 1 +luke davidson 1 +luke ellison 1 +luke ellison 1 +luke ellison 1 +luke falkner 1 +luke falkner 1 +luke garcia 1 +luke garcia 1 +luke ichabod 1 +luke ichabod 1 +luke johnson 1 +luke johnson 1 +luke johnson 1 +luke laertes 1 +luke laertes 1 +luke laertes 1 +luke laertes 1 +luke laertes 1 +luke miller 1 +luke ovid 1 +luke ovid 1 +luke polk 1 +luke polk 1 +luke quirinius 1 +luke robinson 1 +luke robinson 1 +luke thompson 1 +luke underhill 1 +luke underhill 1 +luke underhill 2 +luke van buren 1 +luke white 1 +luke xylophone 1 +luke zipper 1 +mike allen 1 +mike brown 1 +mike carson 1 +mike carson 1 +mike carson 1 +mike davidson 1 +mike davidson 1 +mike ellison 1 +mike ellison 1 +mike ellison 1 +mike ellison 1 +mike ellison 1 +mike falkner 1 +mike garcia 1 +mike garcia 1 +mike garcia 1 +mike hernandez 1 +mike hernandez 2 +mike ichabod 1 +mike king 1 +mike king 1 +mike king 1 +mike king 1 +mike king 1 +mike king 2 +mike miller 1 +mike nixon 1 +mike nixon 1 +mike polk 1 +mike polk 1 +mike polk 1 +mike quirinius 1 +mike steinbeck 1 +mike steinbeck 1 +mike steinbeck 1 +mike steinbeck 1 +mike van buren 1 +mike van buren 1 +mike white 1 +mike white 1 +mike white 1 +mike white 2 +mike young 1 +mike young 1 +mike young 1 +mike zipper 1 +mike zipper 1 +mike zipper 1 +nick allen 1 +nick allen 1 +nick brown 1 +nick davidson 1 +nick ellison 1 +nick ellison 1 +nick falkner 1 +nick falkner 1 +nick garcia 1 +nick garcia 1 +nick garcia 1 +nick ichabod 1 +nick ichabod 1 +nick ichabod 1 +nick johnson 1 +nick johnson 1 +nick laertes 1 +nick miller 1 +nick nixon 1 +nick ovid 1 +nick polk 1 +nick quirinius 1 +nick quirinius 1 +nick robinson 1 +nick robinson 1 +nick steinbeck 1 +nick thompson 1 +nick underhill 1 +nick van buren 1 +nick xylophone 1 +nick young 1 +nick young 1 +nick zipper 1 +nick zipper 1 +oscar allen 2 +oscar brown 1 +oscar carson 1 +oscar carson 1 +oscar carson 1 +oscar carson 1 +oscar carson 1 +oscar davidson 1 +oscar ellison 1 +oscar ellison 1 +oscar falkner 1 +oscar garcia 1 +oscar hernandez 1 +oscar hernandez 1 +oscar ichabod 1 +oscar ichabod 1 +oscar ichabod 1 +oscar ichabod 1 +oscar johnson 1 +oscar johnson 1 +oscar king 1 +oscar king 1 +oscar king 2 +oscar laertes 1 +oscar laertes 1 +oscar laertes 1 +oscar laertes 1 +oscar nixon 1 +oscar ovid 1 +oscar ovid 1 +oscar ovid 2 +oscar polk 1 +oscar polk 1 +oscar quirinius 1 +oscar quirinius 1 +oscar quirinius 1 +oscar quirinius 1 +oscar robinson 1 +oscar robinson 1 +oscar robinson 1 +oscar robinson 1 +oscar steinbeck 1 +oscar thompson 1 +oscar thompson 1 +oscar thompson 1 +oscar thompson 2 +oscar underhill 1 +oscar van buren 1 +oscar van buren 1 +oscar van buren 1 +oscar white 1 +oscar white 1 +oscar white 1 +oscar white 1 +oscar xylophone 1 +oscar xylophone 1 +oscar xylophone 1 +oscar zipper 1 +oscar zipper 1 +oscar zipper 1 +priscilla brown 1 +priscilla brown 1 +priscilla brown 1 +priscilla carson 1 +priscilla carson 1 +priscilla carson 1 +priscilla ichabod 1 +priscilla ichabod 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla johnson 1 +priscilla king 1 +priscilla nixon 1 +priscilla nixon 2 +priscilla ovid 1 +priscilla ovid 1 +priscilla polk 1 +priscilla quirinius 1 +priscilla thompson 1 +priscilla underhill 1 +priscilla underhill 1 +priscilla van buren 1 +priscilla van buren 1 +priscilla van buren 1 +priscilla white 1 +priscilla xylophone 1 +priscilla xylophone 1 +priscilla xylophone 1 +priscilla young 1 +priscilla young 1 +priscilla zipper 1 +priscilla zipper 1 +quinn allen 1 +quinn allen 1 +quinn brown 1 +quinn brown 1 +quinn brown 1 +quinn davidson 1 +quinn davidson 1 +quinn davidson 1 +quinn davidson 1 +quinn ellison 1 +quinn ellison 1 +quinn garcia 1 +quinn garcia 1 +quinn garcia 1 +quinn garcia 1 +quinn ichabod 1 +quinn king 1 +quinn king 1 +quinn laertes 1 +quinn laertes 1 +quinn laertes 1 +quinn nixon 1 +quinn ovid 1 +quinn quirinius 1 +quinn robinson 1 +quinn steinbeck 1 +quinn steinbeck 4 +quinn thompson 1 +quinn thompson 2 +quinn underhill 1 +quinn underhill 1 +quinn underhill 2 +quinn van buren 1 +quinn young 1 +quinn zipper 1 +quinn zipper 1 +rachel allen 1 +rachel allen 1 +rachel brown 1 +rachel brown 1 +rachel brown 1 +rachel brown 1 +rachel brown 2 +rachel carson 1 +rachel carson 1 +rachel davidson 1 +rachel ellison 1 +rachel falkner 1 +rachel falkner 1 +rachel falkner 1 +rachel falkner 1 +rachel johnson 1 +rachel king 1 +rachel king 1 +rachel laertes 1 +rachel laertes 1 +rachel ovid 1 +rachel ovid 1 +rachel polk 1 +rachel quirinius 1 +rachel robinson 1 +rachel robinson 1 +rachel robinson 1 +rachel thompson 1 +rachel thompson 1 +rachel thompson 1 +rachel underhill 1 +rachel white 1 +rachel white 1 +rachel young 1 +rachel zipper 1 +rachel zipper 1 +sarah carson 1 +sarah carson 1 +sarah carson 1 +sarah ellison 2 +sarah falkner 1 +sarah falkner 1 +sarah garcia 1 +sarah garcia 1 +sarah garcia 2 +sarah ichabod 1 +sarah ichabod 1 +sarah johnson 1 +sarah johnson 1 +sarah johnson 1 +sarah johnson 1 +sarah king 1 +sarah king 1 +sarah miller 1 +sarah ovid 1 +sarah robinson 1 +sarah robinson 1 +sarah steinbeck 1 +sarah white 1 +sarah white 1 +sarah xylophone 1 +sarah young 1 +sarah zipper 1 +tom brown 1 +tom brown 1 +tom carson 1 +tom carson 1 +tom carson 1 +tom davidson 1 +tom ellison 1 +tom ellison 1 +tom ellison 1 +tom falkner 1 +tom falkner 1 +tom hernandez 1 +tom hernandez 1 +tom ichabod 1 +tom johnson 1 +tom johnson 2 +tom king 1 +tom laertes 1 +tom laertes 1 +tom miller 1 +tom miller 1 +tom miller 2 +tom nixon 1 +tom ovid 1 +tom polk 1 +tom polk 1 +tom quirinius 1 +tom quirinius 1 +tom robinson 1 +tom robinson 1 +tom robinson 2 +tom robinson 2 +tom steinbeck 2 +tom van buren 1 +tom van buren 1 +tom van buren 1 +tom white 1 +tom young 1 +tom young 2 +tom zipper 1 +ulysses brown 1 +ulysses carson 1 +ulysses carson 1 +ulysses carson 1 +ulysses carson 1 +ulysses davidson 1 +ulysses ellison 1 +ulysses garcia 1 +ulysses hernandez 1 +ulysses hernandez 1 +ulysses hernandez 2 +ulysses ichabod 1 +ulysses ichabod 1 +ulysses johnson 1 +ulysses king 1 +ulysses laertes 1 +ulysses laertes 1 +ulysses laertes 1 +ulysses miller 1 +ulysses miller 1 +ulysses nixon 1 +ulysses ovid 1 +ulysses polk 1 +ulysses polk 1 +ulysses polk 1 +ulysses polk 2 +ulysses quirinius 1 +ulysses robinson 1 +ulysses steinbeck 1 +ulysses steinbeck 1 +ulysses thompson 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses underhill 1 +ulysses van buren 1 +ulysses white 1 +ulysses white 1 +ulysses xylophone 1 +ulysses xylophone 1 +ulysses xylophone 1 +ulysses young 1 +ulysses young 1 +ulysses young 1 +victor allen 1 +victor allen 1 +victor brown 1 +victor brown 1 +victor brown 1 +victor brown 1 +victor davidson 1 +victor davidson 1 +victor davidson 2 +victor ellison 1 +victor ellison 1 +victor hernandez 1 +victor hernandez 1 +victor hernandez 1 +victor hernandez 1 +victor hernandez 1 +victor johnson 1 +victor johnson 1 +victor johnson 1 +victor king 1 +victor king 1 +victor laertes 1 +victor laertes 1 +victor miller 2 +victor nixon 1 +victor nixon 1 +victor ovid 1 +victor polk 1 +victor quirinius 1 +victor quirinius 1 +victor robinson 1 +victor robinson 1 +victor steinbeck 1 +victor steinbeck 1 +victor steinbeck 1 +victor thompson 1 +victor van buren 1 +victor van buren 1 +victor white 1 +victor white 1 +victor xylophone 1 +victor xylophone 1 +victor xylophone 1 +victor xylophone 1 +victor xylophone 2 +victor young 1 +victor zipper 1 +wendy allen 1 +wendy allen 1 +wendy allen 1 +wendy brown 1 +wendy brown 1 +wendy ellison 1 +wendy ellison 1 +wendy falkner 1 +wendy falkner 1 +wendy falkner 1 +wendy garcia 1 +wendy garcia 1 +wendy garcia 1 +wendy garcia 1 +wendy hernandez 1 +wendy ichabod 1 +wendy king 1 +wendy king 1 +wendy king 1 +wendy laertes 1 +wendy laertes 1 +wendy laertes 1 +wendy miller 1 +wendy miller 1 +wendy nixon 1 +wendy nixon 1 +wendy ovid 1 +wendy ovid 1 +wendy polk 1 +wendy polk 1 +wendy quirinius 1 +wendy quirinius 1 +wendy robinson 1 +wendy robinson 1 +wendy robinson 1 +wendy steinbeck 1 +wendy thompson 1 +wendy thompson 1 +wendy underhill 1 +wendy underhill 1 +wendy underhill 1 +wendy van buren 1 +wendy van buren 1 +wendy white 1 +wendy xylophone 1 +wendy xylophone 1 +wendy young 1 +wendy young 3 +xavier allen 1 +xavier allen 1 +xavier allen 1 +xavier brown 1 +xavier brown 1 +xavier brown 1 +xavier carson 1 +xavier carson 1 +xavier davidson 1 +xavier davidson 1 +xavier davidson 1 +xavier ellison 1 +xavier ellison 1 +xavier garcia 1 +xavier hernandez 1 +xavier hernandez 1 +xavier hernandez 1 +xavier ichabod 1 +xavier ichabod 1 +xavier johnson 1 +xavier johnson 1 +xavier king 1 +xavier king 1 +xavier laertes 1 +xavier ovid 1 +xavier polk 1 +xavier polk 1 +xavier polk 1 +xavier polk 1 +xavier quirinius 1 +xavier quirinius 1 +xavier quirinius 1 +xavier quirinius 1 +xavier thompson 1 +xavier underhill 1 +xavier white 1 +xavier white 1 +xavier xylophone 1 +xavier zipper 2 +yuri allen 1 +yuri allen 1 +yuri brown 1 +yuri brown 1 +yuri carson 1 +yuri carson 1 +yuri ellison 1 +yuri ellison 1 +yuri falkner 1 +yuri falkner 1 +yuri garcia 1 +yuri hernandez 1 +yuri johnson 1 +yuri johnson 1 +yuri johnson 1 +yuri king 1 +yuri laertes 1 +yuri laertes 1 +yuri nixon 1 +yuri nixon 1 +yuri polk 1 +yuri polk 1 +yuri polk 1 +yuri quirinius 1 +yuri quirinius 1 +yuri quirinius 1 +yuri steinbeck 1 +yuri steinbeck 1 +yuri thompson 1 +yuri underhill 1 +yuri underhill 1 +yuri white 1 +yuri xylophone 1 +zach allen 2 +zach brown 1 +zach brown 1 +zach brown 1 +zach brown 1 +zach brown 1 +zach carson 1 +zach ellison 1 +zach falkner 1 +zach falkner 1 +zach garcia 1 +zach garcia 1 +zach garcia 1 +zach garcia 1 +zach ichabod 1 +zach ichabod 1 +zach king 1 +zach king 1 +zach king 2 +zach miller 1 +zach miller 1 +zach miller 1 +zach ovid 1 +zach ovid 1 +zach ovid 1 +zach ovid 1 +zach quirinius 1 +zach robinson 2 +zach steinbeck 2 +zach steinbeck 2 +zach thompson 1 +zach thompson 1 +zach underhill 1 +zach white 1 +zach xylophone 1 +zach xylophone 1 +zach young 1 +zach zipper 1 +zach zipper 1 +zach zipper 1 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-1-5c5f373e325115d710a7a23fcb1626f1 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-1-5c5f373e325115d710a7a23fcb1626f1 new file mode 100644 index 000000000000..22a6f27253dc --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-1-5c5f373e325115d710a7a23fcb1626f1 @@ -0,0 +1,1049 @@ +zach zipper 4 +zach zipper 3 +zach zipper 1 +zach young 4 +zach xylophone 4 +zach xylophone 1 +zach white 1 +zach underhill 1 +zach thompson 2 +zach thompson 2 +zach steinbeck 5 +zach steinbeck 1 +zach robinson 1 +zach quirinius 3 +zach ovid 5 +zach ovid 4 +zach ovid 3 +zach ovid 1 +zach miller 5 +zach miller 3 +zach miller 1 +zach king 6 +zach king 4 +zach king 1 +zach ichabod 3 +zach ichabod 2 +zach garcia 6 +zach garcia 3 +zach garcia 1 +zach garcia 1 +zach falkner 2 +zach falkner 1 +zach ellison 4 +zach carson 3 +zach brown 5 +zach brown 4 +zach brown 3 +zach brown 2 +zach brown 1 +zach allen 4 +yuri xylophone 3 +yuri white 2 +yuri underhill 6 +yuri underhill 4 +yuri thompson 4 +yuri steinbeck 6 +yuri steinbeck 2 +yuri quirinius 4 +yuri quirinius 3 +yuri quirinius 1 +yuri polk 4 +yuri polk 3 +yuri polk 2 +yuri nixon 3 +yuri nixon 2 +yuri laertes 3 +yuri laertes 1 +yuri king 5 +yuri johnson 4 +yuri johnson 3 +yuri johnson 1 +yuri hernandez 4 +yuri garcia 3 +yuri falkner 7 +yuri falkner 3 +yuri ellison 1 +yuri ellison 1 +yuri carson 7 +yuri carson 4 +yuri brown 3 +yuri brown 1 +yuri allen 3 +yuri allen 2 +xavier zipper 1 +xavier xylophone 1 +xavier white 3 +xavier white 3 +xavier underhill 2 +xavier thompson 3 +xavier quirinius 6 +xavier quirinius 5 +xavier quirinius 2 +xavier quirinius 1 +xavier polk 5 +xavier polk 3 +xavier polk 3 +xavier polk 3 +xavier ovid 5 +xavier laertes 4 +xavier king 3 +xavier king 1 +xavier johnson 3 +xavier johnson 1 +xavier ichabod 2 +xavier ichabod 2 +xavier hernandez 3 +xavier hernandez 1 +xavier hernandez 1 +xavier garcia 4 +xavier ellison 1 +xavier ellison 1 +xavier davidson 5 +xavier davidson 4 +xavier davidson 1 +xavier carson 5 +xavier carson 3 +xavier brown 4 +xavier brown 2 +xavier brown 2 +xavier allen 6 +xavier allen 3 +xavier allen 1 +wendy young 8 +wendy young 2 +wendy xylophone 6 +wendy xylophone 4 +wendy white 5 +wendy van buren 2 +wendy van buren 2 +wendy underhill 6 +wendy underhill 5 +wendy underhill 4 +wendy thompson 5 +wendy thompson 2 +wendy steinbeck 1 +wendy robinson 5 +wendy robinson 3 +wendy robinson 2 +wendy quirinius 6 +wendy quirinius 4 +wendy polk 2 +wendy polk 2 +wendy ovid 4 +wendy ovid 1 +wendy nixon 3 +wendy nixon 1 +wendy miller 2 +wendy miller 1 +wendy laertes 3 +wendy laertes 3 +wendy laertes 1 +wendy king 5 +wendy king 4 +wendy king 1 +wendy ichabod 3 +wendy hernandez 1 +wendy garcia 7 +wendy garcia 5 +wendy garcia 4 +wendy garcia 1 +wendy falkner 3 +wendy falkner 1 +wendy falkner 1 +wendy ellison 2 +wendy ellison 1 +wendy brown 5 +wendy brown 2 +wendy allen 6 +wendy allen 2 +wendy allen 2 +victor zipper 3 +victor young 1 +victor xylophone 6 +victor xylophone 6 +victor xylophone 2 +victor xylophone 1 +victor xylophone 1 +victor white 2 +victor white 1 +victor van buren 4 +victor van buren 4 +victor thompson 2 +victor steinbeck 5 +victor steinbeck 2 +victor steinbeck 1 +victor robinson 4 +victor robinson 2 +victor quirinius 3 +victor quirinius 1 +victor polk 3 +victor ovid 2 +victor nixon 6 +victor nixon 4 +victor miller 1 +victor laertes 4 +victor laertes 3 +victor king 6 +victor king 1 +victor johnson 2 +victor johnson 2 +victor johnson 1 +victor hernandez 6 +victor hernandez 4 +victor hernandez 3 +victor hernandez 1 +victor hernandez 1 +victor ellison 7 +victor ellison 4 +victor davidson 6 +victor davidson 2 +victor davidson 2 +victor brown 4 +victor brown 3 +victor brown 2 +victor brown 1 +victor allen 4 +victor allen 2 +ulysses young 7 +ulysses young 6 +ulysses young 3 +ulysses xylophone 6 +ulysses xylophone 3 +ulysses xylophone 2 +ulysses white 6 +ulysses white 2 +ulysses van buren 3 +ulysses underhill 8 +ulysses underhill 4 +ulysses underhill 3 +ulysses underhill 2 +ulysses underhill 2 +ulysses underhill 1 +ulysses underhill 1 +ulysses thompson 5 +ulysses steinbeck 3 +ulysses steinbeck 1 +ulysses robinson 5 +ulysses quirinius 8 +ulysses polk 6 +ulysses polk 4 +ulysses polk 1 +ulysses polk 1 +ulysses ovid 3 +ulysses nixon 1 +ulysses miller 3 +ulysses miller 2 +ulysses laertes 5 +ulysses laertes 4 +ulysses laertes 2 +ulysses king 2 +ulysses johnson 5 +ulysses ichabod 1 +ulysses ichabod 1 +ulysses hernandez 6 +ulysses hernandez 3 +ulysses hernandez 2 +ulysses garcia 2 +ulysses ellison 2 +ulysses davidson 8 +ulysses carson 4 +ulysses carson 3 +ulysses carson 2 +ulysses carson 1 +ulysses brown 3 +tom zipper 5 +tom young 2 +tom young 1 +tom white 1 +tom van buren 5 +tom van buren 2 +tom van buren 1 +tom steinbeck 4 +tom robinson 8 +tom robinson 4 +tom robinson 3 +tom robinson 2 +tom quirinius 5 +tom quirinius 1 +tom polk 3 +tom polk 2 +tom ovid 2 +tom nixon 5 +tom miller 1 +tom miller 1 +tom miller 1 +tom laertes 4 +tom laertes 2 +tom king 1 +tom johnson 8 +tom johnson 1 +tom ichabod 1 +tom hernandez 3 +tom hernandez 2 +tom falkner 3 +tom falkner 2 +tom ellison 5 +tom ellison 3 +tom ellison 1 +tom davidson 7 +tom carson 3 +tom carson 3 +tom carson 1 +tom brown 4 +tom brown 2 +sarah zipper 1 +sarah young 1 +sarah xylophone 2 +sarah white 4 +sarah white 3 +sarah steinbeck 6 +sarah robinson 3 +sarah robinson 2 +sarah ovid 1 +sarah miller 1 +sarah king 3 +sarah king 2 +sarah johnson 7 +sarah johnson 6 +sarah johnson 4 +sarah johnson 2 +sarah ichabod 4 +sarah ichabod 3 +sarah garcia 2 +sarah garcia 2 +sarah garcia 2 +sarah falkner 7 +sarah falkner 1 +sarah ellison 1 +sarah carson 6 +sarah carson 4 +sarah carson 4 +rachel zipper 8 +rachel zipper 5 +rachel young 3 +rachel white 2 +rachel white 2 +rachel underhill 2 +rachel thompson 5 +rachel thompson 4 +rachel thompson 3 +rachel robinson 10 +rachel robinson 3 +rachel robinson 1 +rachel quirinius 5 +rachel polk 4 +rachel ovid 5 +rachel ovid 4 +rachel laertes 1 +rachel laertes 1 +rachel king 3 +rachel king 1 +rachel johnson 1 +rachel falkner 8 +rachel falkner 5 +rachel falkner 5 +rachel falkner 2 +rachel ellison 6 +rachel davidson 6 +rachel carson 7 +rachel carson 2 +rachel brown 5 +rachel brown 4 +rachel brown 3 +rachel brown 3 +rachel brown 1 +rachel allen 5 +rachel allen 1 +quinn zipper 2 +quinn zipper 2 +quinn young 2 +quinn van buren 2 +quinn underhill 7 +quinn underhill 6 +quinn underhill 2 +quinn thompson 5 +quinn thompson 2 +quinn steinbeck 3 +quinn steinbeck 2 +quinn robinson 2 +quinn quirinius 5 +quinn ovid 6 +quinn nixon 3 +quinn laertes 2 +quinn laertes 2 +quinn laertes 1 +quinn king 2 +quinn king 1 +quinn ichabod 1 +quinn garcia 6 +quinn garcia 3 +quinn garcia 2 +quinn garcia 1 +quinn ellison 7 +quinn ellison 5 +quinn davidson 7 +quinn davidson 4 +quinn davidson 3 +quinn davidson 2 +quinn brown 5 +quinn brown 3 +quinn brown 2 +quinn allen 5 +quinn allen 2 +priscilla zipper 5 +priscilla zipper 2 +priscilla young 4 +priscilla young 1 +priscilla xylophone 7 +priscilla xylophone 2 +priscilla xylophone 1 +priscilla white 4 +priscilla van buren 3 +priscilla van buren 3 +priscilla van buren 2 +priscilla underhill 5 +priscilla underhill 4 +priscilla thompson 2 +priscilla quirinius 4 +priscilla polk 5 +priscilla ovid 4 +priscilla ovid 1 +priscilla nixon 2 +priscilla nixon 1 +priscilla king 4 +priscilla johnson 4 +priscilla johnson 2 +priscilla johnson 2 +priscilla johnson 2 +priscilla johnson 1 +priscilla ichabod 3 +priscilla ichabod 2 +priscilla carson 6 +priscilla carson 5 +priscilla carson 4 +priscilla brown 5 +priscilla brown 5 +priscilla brown 3 +oscar zipper 4 +oscar zipper 4 +oscar zipper 2 +oscar xylophone 7 +oscar xylophone 5 +oscar xylophone 3 +oscar white 5 +oscar white 5 +oscar white 3 +oscar white 2 +oscar van buren 5 +oscar van buren 3 +oscar van buren 2 +oscar underhill 1 +oscar thompson 6 +oscar thompson 3 +oscar thompson 3 +oscar thompson 2 +oscar steinbeck 7 +oscar robinson 7 +oscar robinson 3 +oscar robinson 3 +oscar robinson 1 +oscar quirinius 3 +oscar quirinius 3 +oscar quirinius 2 +oscar quirinius 1 +oscar polk 2 +oscar polk 2 +oscar ovid 4 +oscar ovid 2 +oscar ovid 1 +oscar nixon 1 +oscar laertes 6 +oscar laertes 4 +oscar laertes 3 +oscar laertes 2 +oscar king 4 +oscar king 2 +oscar king 1 +oscar johnson 6 +oscar johnson 3 +oscar ichabod 3 +oscar ichabod 3 +oscar ichabod 1 +oscar ichabod 1 +oscar hernandez 6 +oscar hernandez 6 +oscar garcia 4 +oscar falkner 2 +oscar ellison 2 +oscar ellison 1 +oscar davidson 1 +oscar carson 4 +oscar carson 2 +oscar carson 2 +oscar carson 1 +oscar carson 1 +oscar brown 4 +oscar allen 2 +nick zipper 7 +nick zipper 5 +nick young 4 +nick young 2 +nick xylophone 2 +nick van buren 2 +nick underhill 2 +nick thompson 2 +nick steinbeck 4 +nick robinson 3 +nick robinson 1 +nick quirinius 5 +nick quirinius 1 +nick polk 5 +nick ovid 6 +nick nixon 4 +nick miller 2 +nick laertes 3 +nick johnson 4 +nick johnson 4 +nick ichabod 3 +nick ichabod 3 +nick ichabod 1 +nick garcia 5 +nick garcia 4 +nick garcia 4 +nick falkner 3 +nick falkner 1 +nick ellison 3 +nick ellison 2 +nick davidson 4 +nick brown 3 +nick allen 5 +nick allen 4 +mike zipper 4 +mike zipper 4 +mike zipper 1 +mike young 3 +mike young 1 +mike young 1 +mike white 9 +mike white 7 +mike white 5 +mike white 2 +mike van buren 2 +mike van buren 1 +mike steinbeck 4 +mike steinbeck 2 +mike steinbeck 2 +mike steinbeck 1 +mike quirinius 7 +mike polk 4 +mike polk 2 +mike polk 2 +mike nixon 3 +mike nixon 2 +mike miller 1 +mike king 6 +mike king 5 +mike king 4 +mike king 3 +mike king 1 +mike king 1 +mike ichabod 3 +mike hernandez 2 +mike hernandez 1 +mike garcia 3 +mike garcia 2 +mike garcia 1 +mike falkner 2 +mike ellison 6 +mike ellison 5 +mike ellison 3 +mike ellison 1 +mike ellison 1 +mike davidson 5 +mike davidson 5 +mike carson 9 +mike carson 4 +mike carson 3 +mike brown 2 +mike allen 3 +luke zipper 2 +luke xylophone 1 +luke white 1 +luke van buren 2 +luke underhill 2 +luke underhill 2 +luke underhill 1 +luke thompson 3 +luke robinson 6 +luke robinson 1 +luke quirinius 3 +luke polk 3 +luke polk 1 +luke ovid 3 +luke ovid 1 +luke miller 4 +luke laertes 4 +luke laertes 2 +luke laertes 2 +luke laertes 2 +luke laertes 1 +luke johnson 4 +luke johnson 2 +luke johnson 1 +luke ichabod 4 +luke ichabod 1 +luke garcia 5 +luke garcia 2 +luke falkner 4 +luke falkner 2 +luke ellison 3 +luke ellison 2 +luke ellison 1 +luke davidson 2 +luke davidson 2 +luke brown 5 +luke allen 5 +luke allen 2 +luke allen 1 +luke allen 1 +luke allen 1 +katie zipper 1 +katie zipper 1 +katie young 11 +katie young 6 +katie young 1 +katie xylophone 1 +katie white 5 +katie white 3 +katie van buren 6 +katie van buren 4 +katie robinson 2 +katie polk 5 +katie polk 2 +katie ovid 3 +katie nixon 1 +katie miller 1 +katie miller 1 +katie king 7 +katie king 5 +katie king 4 +katie ichabod 6 +katie ichabod 2 +katie ichabod 1 +katie hernandez 1 +katie garcia 4 +katie garcia 3 +katie falkner 4 +katie ellison 5 +katie ellison 4 +katie davidson 1 +katie brown 6 +katie allen 1 +jessica zipper 7 +jessica zipper 6 +jessica zipper 1 +jessica young 4 +jessica young 3 +jessica xylophone 3 +jessica white 8 +jessica white 6 +jessica white 3 +jessica white 1 +jessica white 1 +jessica van buren 1 +jessica underhill 5 +jessica underhill 3 +jessica underhill 2 +jessica thompson 3 +jessica thompson 2 +jessica robinson 2 +jessica quirinius 4 +jessica quirinius 4 +jessica quirinius 3 +jessica quirinius 1 +jessica polk 4 +jessica ovid 2 +jessica ovid 1 +jessica nixon 3 +jessica nixon 2 +jessica miller 5 +jessica johnson 4 +jessica johnson 3 +jessica ichabod 5 +jessica garcia 4 +jessica garcia 3 +jessica falkner 2 +jessica ellison 5 +jessica ellison 2 +jessica davidson 5 +jessica davidson 2 +jessica davidson 2 +jessica davidson 1 +jessica carson 4 +jessica carson 2 +jessica carson 1 +jessica brown 3 +irene xylophone 3 +irene van buren 2 +irene van buren 1 +irene underhill 5 +irene underhill 1 +irene thompson 6 +irene steinbeck 1 +irene robinson 1 +irene quirinius 6 +irene quirinius 5 +irene quirinius 5 +irene polk 3 +irene polk 2 +irene polk 2 +irene polk 1 +irene polk 1 +irene ovid 6 +irene ovid 6 +irene ovid 5 +irene nixon 4 +irene nixon 4 +irene nixon 1 +irene miller 6 +irene laertes 5 +irene laertes 3 +irene laertes 3 +irene johnson 2 +irene ichabod 7 +irene ichabod 1 +irene garcia 4 +irene garcia 2 +irene garcia 2 +irene falkner 5 +irene falkner 2 +irene ellison 4 +irene ellison 3 +irene carson 1 +irene brown 4 +irene brown 4 +irene brown 1 +irene allen 2 +holly zipper 3 +holly zipper 3 +holly young 2 +holly young 2 +holly xylophone 1 +holly white 3 +holly white 1 +holly van buren 4 +holly underhill 6 +holly underhill 3 +holly underhill 3 +holly underhill 2 +holly thompson 2 +holly thompson 1 +holly thompson 1 +holly robinson 2 +holly polk 7 +holly polk 4 +holly nixon 5 +holly nixon 1 +holly miller 4 +holly laertes 5 +holly king 4 +holly king 1 +holly johnson 2 +holly johnson 2 +holly johnson 2 +holly ichabod 4 +holly ichabod 4 +holly ichabod 2 +holly hernandez 9 +holly hernandez 3 +holly hernandez 3 +holly hernandez 2 +holly falkner 6 +holly brown 3 +holly brown 2 +holly allen 1 +gabriella zipper 5 +gabriella zipper 1 +gabriella young 3 +gabriella young 1 +gabriella white 3 +gabriella van buren 3 +gabriella van buren 1 +gabriella thompson 5 +gabriella thompson 5 +gabriella thompson 5 +gabriella steinbeck 4 +gabriella steinbeck 1 +gabriella polk 4 +gabriella polk 4 +gabriella ovid 2 +gabriella ovid 1 +gabriella miller 1 +gabriella laertes 4 +gabriella king 3 +gabriella king 3 +gabriella ichabod 3 +gabriella ichabod 3 +gabriella ichabod 3 +gabriella ichabod 2 +gabriella ichabod 1 +gabriella hernandez 9 +gabriella hernandez 5 +gabriella garcia 2 +gabriella falkner 4 +gabriella falkner 3 +gabriella falkner 2 +gabriella ellison 3 +gabriella ellison 1 +gabriella davidson 2 +gabriella carson 1 +gabriella brown 6 +gabriella brown 3 +gabriella allen 5 +gabriella allen 3 +fred zipper 1 +fred young 2 +fred young 1 +fred white 3 +fred van buren 7 +fred van buren 1 +fred van buren 1 +fred van buren 1 +fred underhill 4 +fred steinbeck 4 +fred steinbeck 2 +fred steinbeck 1 +fred robinson 3 +fred quirinius 7 +fred quirinius 4 +fred polk 7 +fred polk 6 +fred polk 4 +fred polk 2 +fred nixon 7 +fred nixon 5 +fred nixon 1 +fred nixon 1 +fred miller 1 +fred laertes 4 +fred king 6 +fred king 3 +fred johnson 4 +fred ichabod 3 +fred ichabod 2 +fred hernandez 1 +fred falkner 4 +fred falkner 3 +fred falkner 3 +fred ellison 5 +fred ellison 2 +fred ellison 1 +fred davidson 2 +fred davidson 2 +fred davidson 1 +ethan zipper 2 +ethan zipper 1 +ethan xylophone 3 +ethan white 5 +ethan white 2 +ethan van buren 1 +ethan underhill 1 +ethan robinson 3 +ethan robinson 1 +ethan quirinius 6 +ethan quirinius 2 +ethan quirinius 1 +ethan polk 3 +ethan polk 1 +ethan polk 1 +ethan polk 1 +ethan ovid 2 +ethan nixon 7 +ethan miller 5 +ethan laertes 4 +ethan laertes 4 +ethan laertes 3 +ethan laertes 2 +ethan laertes 2 +ethan laertes 2 +ethan laertes 1 +ethan king 1 +ethan johnson 1 +ethan hernandez 3 +ethan garcia 8 +ethan falkner 2 +ethan falkner 1 +ethan ellison 6 +ethan ellison 4 +ethan carson 6 +ethan brown 4 +ethan brown 3 +ethan brown 3 +ethan brown 1 +ethan brown 1 +ethan brown 1 +ethan allen 4 +david young 4 +david young 1 +david xylophone 6 +david xylophone 4 +david xylophone 1 +david white 2 +david van buren 3 +david van buren 2 +david underhill 7 +david underhill 4 +david underhill 1 +david thompson 1 +david robinson 3 +david robinson 2 +david quirinius 4 +david quirinius 4 +david quirinius 2 +david ovid 4 +david ovid 3 +david nixon 1 +david laertes 4 +david ichabod 6 +david ichabod 3 +david hernandez 7 +david ellison 5 +david ellison 3 +david ellison 3 +david davidson 4 +david davidson 3 +david davidson 1 +david davidson 1 +david brown 6 +david brown 2 +david allen 5 +david allen 2 +calvin zipper 9 +calvin zipper 3 +calvin young 3 +calvin young 1 +calvin xylophone 6 +calvin xylophone 3 +calvin xylophone 1 +calvin white 1 +calvin white 1 +calvin van buren 9 +calvin van buren 1 +calvin underhill 4 +calvin thompson 3 +calvin thompson 2 +calvin steinbeck 6 +calvin steinbeck 3 +calvin steinbeck 3 +calvin robinson 2 +calvin quirinius 4 +calvin quirinius 3 +calvin polk 2 +calvin ovid 5 +calvin ovid 4 +calvin ovid 3 +calvin ovid 1 +calvin nixon 7 +calvin nixon 3 +calvin nixon 2 +calvin laertes 3 +calvin laertes 1 +calvin johnson 2 +calvin hernandez 1 +calvin garcia 3 +calvin falkner 8 +calvin falkner 4 +calvin falkner 4 +calvin falkner 3 +calvin falkner 2 +calvin falkner 1 +calvin ellison 3 +calvin davidson 1 +calvin davidson 1 +calvin carson 1 +calvin brown 5 +calvin brown 3 +calvin brown 1 +calvin allen 1 +bob zipper 4 +bob zipper 1 +bob zipper 1 +bob young 1 +bob xylophone 3 +bob xylophone 2 +bob white 3 +bob white 1 +bob van buren 3 +bob steinbeck 2 +bob quirinius 4 +bob polk 2 +bob ovid 7 +bob ovid 2 +bob ovid 2 +bob ovid 1 +bob miller 1 +bob laertes 5 +bob laertes 1 +bob king 3 +bob king 3 +bob king 2 +bob ichabod 1 +bob hernandez 1 +bob garcia 4 +bob garcia 3 +bob garcia 2 +bob garcia 1 +bob garcia 1 +bob falkner 6 +bob ellison 3 +bob ellison 2 +bob ellison 1 +bob ellison 1 +bob davidson 5 +bob davidson 2 +bob davidson 2 +bob carson 3 +bob brown 8 +bob brown 6 +bob brown 2 +alice zipper 2 +alice zipper 1 +alice zipper 1 +alice xylophone 2 +alice xylophone 2 +alice xylophone 1 +alice van buren 2 +alice underhill 2 +alice steinbeck 7 +alice steinbeck 3 +alice steinbeck 1 +alice robinson 4 +alice robinson 1 +alice quirinius 6 +alice quirinius 4 +alice polk 1 +alice ovid 2 +alice nixon 2 +alice nixon 2 +alice nixon 1 +alice miller 2 +alice laertes 3 +alice laertes 2 +alice king 8 +alice king 4 +alice king 2 +alice johnson 5 +alice hernandez 8 +alice hernandez 8 +alice garcia 1 +alice falkner 5 +alice davidson 2 +alice carson 1 +alice brown 5 +alice allen 5 +alice allen 5 +alice allen 4 + 5 + 4 + 3 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-2-ac487cc1b94130bf9ce00e07c7075f65 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-2-ac487cc1b94130bf9ce00e07c7075f65 new file mode 100644 index 000000000000..c38e7bbabc21 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-2-ac487cc1b94130bf9ce00e07c7075f65 @@ -0,0 +1,1049 @@ + 0.25047801147227533 + 0.47992351816443596 + 0.6197718631178707 +alice allen 0.7300380228136882 +alice allen 0.8954372623574145 +alice allen 0.9216061185468452 +alice brown 0.22053231939163498 +alice carson 0.2889733840304182 +alice davidson 0.8593155893536122 +alice falkner 0.08604206500956023 +alice garcia 0.2870722433460076 +alice hernandez 0.011472275334608031 +alice hernandez 0.07604562737642585 +alice johnson 0.5181644359464627 +alice king 0.3652007648183556 +alice king 0.8536121673003803 +alice king 0.9771863117870723 +alice laertes 0.870722433460076 +alice laertes 0.870722433460076 +alice miller 0.12045889101338432 +alice nixon 0.4372623574144487 +alice nixon 0.47036328871892924 +alice nixon 0.768642447418738 +alice ovid 0.4665391969407266 +alice polk 0.279467680608365 +alice quirinius 0.8432122370936902 +alice quirinius 0.9923518164435946 +alice robinson 0.5722433460076045 +alice robinson 0.7984790874524715 +alice steinbeck 0.27151051625239003 +alice steinbeck 0.739961759082218 +alice steinbeck 0.9923954372623575 +alice underhill 0.5513307984790875 +alice van buren 0.4923954372623574 +alice xylophone 0.2376425855513308 +alice xylophone 0.26806083650190116 +alice xylophone 0.8776290630975143 +alice zipper 0.33460076045627374 +alice zipper 0.8814531548757171 +alice zipper 0.9445506692160612 +bob brown 0.5038022813688213 +bob brown 0.5066921606118547 +bob brown 0.5372848948374761 +bob carson 0.43346007604562736 +bob davidson 0.21673003802281368 +bob davidson 0.5285171102661597 +bob davidson 0.8413001912045889 +bob ellison 0.2045889101338432 +bob ellison 0.26577437858508607 +bob ellison 0.5793499043977055 +bob ellison 0.9144486692015209 +bob falkner 0.6940726577437859 +bob garcia 0.08555133079847908 +bob garcia 0.17680608365019013 +bob garcia 0.2887189292543021 +bob garcia 0.5418250950570342 +bob garcia 0.5736137667304015 +bob hernandez 0.7813688212927756 +bob ichabod 0.5200764818355641 +bob king 0.0076481835564053535 +bob king 0.5627376425855514 +bob king 0.9524714828897338 +bob laertes 0.32887189292543023 +bob laertes 0.6825095057034221 +bob miller 0.19771863117870722 +bob ovid 0.40304182509505704 +bob ovid 0.40344168260038243 +bob ovid 0.42065009560229444 +bob ovid 0.8403041825095057 +bob polk 0.15019011406844107 +bob quirinius 0.1844106463878327 +bob steinbeck 0.16920152091254753 +bob van buren 0.5086042065009561 +bob white 0.26045627376425856 +bob white 0.7623574144486692 +bob xylophone 0.4474187380497132 +bob xylophone 0.6539923954372624 +bob young 0.4722753346080306 +bob zipper 0.009505703422053232 +bob zipper 0.24091778202676864 +bob zipper 0.4600760456273764 +calvin allen 0.30975143403441685 +calvin brown 0.4448669201520912 +calvin brown 0.5361216730038023 +calvin brown 0.9196940726577438 +calvin carson 0.9315589353612167 +calvin davidson 0.5869980879541109 +calvin davidson 0.6653992395437263 +calvin ellison 0.6977186311787072 +calvin falkner 0.02091254752851711 +calvin falkner 0.03824091778202677 +calvin falkner 0.21223709369024857 +calvin falkner 0.46577946768060835 +calvin falkner 0.5114068441064639 +calvin falkner 0.5950570342205324 +calvin garcia 0.7896749521988528 +calvin hernandez 0.16730038022813687 +calvin johnson 0.9790874524714829 +calvin laertes 0.5487571701720841 +calvin laertes 0.8145315487571702 +calvin nixon 0.019120458891013385 +calvin nixon 0.4467680608365019 +calvin nixon 0.7395437262357415 +calvin ovid 0.14531548757170173 +calvin ovid 0.17490494296577946 +calvin ovid 0.19961977186311788 +calvin ovid 0.9407265774378585 +calvin polk 0.4619771863117871 +calvin quirinius 0.8802281368821293 +calvin quirinius 0.9254302103250478 +calvin robinson 0.13193116634799235 +calvin steinbeck 0.4818355640535373 +calvin steinbeck 0.7418738049713193 +calvin steinbeck 0.8060836501901141 +calvin thompson 0.2179732313575526 +calvin thompson 0.8422053231939164 +calvin underhill 0.7495219885277247 +calvin van buren 0.022813688212927757 +calvin van buren 0.8508604206500956 +calvin white 0.04182509505703422 +calvin white 0.9674952198852772 +calvin xylophone 0.011406844106463879 +calvin xylophone 0.3193116634799235 +calvin xylophone 0.6634799235181644 +calvin young 0.1988527724665392 +calvin young 0.4391634980988593 +calvin zipper 0.5532319391634981 +calvin zipper 0.8726235741444867 +david allen 0.30019120458891013 +david allen 0.3326959847036329 +david brown 0.1338432122370937 +david brown 0.9694072657743786 +david davidson 0.21414913957934992 +david davidson 0.655893536121673 +david davidson 0.7319391634980988 +david davidson 0.8878326996197718 +david ellison 0.6863117870722434 +david ellison 0.6883365200764818 +david ellison 0.7243346007604563 +david hernandez 0.12237093690248566 +david ichabod 0.35564053537284895 +david ichabod 0.7338403041825095 +david laertes 0.3575525812619503 +david nixon 0.33460803059273425 +david ovid 0.3916349809885932 +david ovid 0.6022944550669216 +david quirinius 0.3155893536121673 +david quirinius 0.6577437858508605 +david quirinius 0.9163498098859315 +david robinson 0.6673003802281369 +david robinson 0.6998087954110899 +david thompson 0.25285171102661597 +david underhill 0.1586998087954111 +david underhill 0.35181644359464626 +david underhill 0.7189292543021033 +david van buren 0.05927342256214149 +david van buren 0.5889101338432122 +david white 0.49429657794676807 +david xylophone 0.4875717017208413 +david xylophone 0.6901140684410646 +david xylophone 0.7571701720841301 +david young 0.0019011406844106464 +david young 0.040152963671128104 +ethan allen 0.20532319391634982 +ethan brown 0.10707456978967496 +ethan brown 0.13307984790874525 +ethan brown 0.4340344168260038 +ethan brown 0.4752851711026616 +ethan brown 0.5219885277246654 +ethan brown 0.745697896749522 +ethan carson 0.20912547528517111 +ethan ellison 0.45124282982791586 +ethan ellison 0.8680688336520076 +ethan falkner 0.0994263862332696 +ethan falkner 0.6845124282982792 +ethan garcia 0.06653992395437262 +ethan hernandez 0.2237093690248566 +ethan johnson 0.2300380228136882 +ethan king 0.47418738049713194 +ethan laertes 0.022944550669216062 +ethan laertes 0.2908745247148289 +ethan laertes 0.42638623326959846 +ethan laertes 0.48098859315589354 +ethan laertes 0.6596558317399618 +ethan laertes 0.7839388145315488 +ethan laertes 0.9201520912547528 +ethan miller 0.23709369024856597 +ethan nixon 0.8164435946462715 +ethan ovid 0.6121673003802282 +ethan polk 0.12167300380228137 +ethan polk 0.3384321223709369 +ethan polk 0.6920152091254753 +ethan polk 0.9619771863117871 +ethan quirinius 0.19391634980988592 +ethan quirinius 0.23135755258126195 +ethan quirinius 0.7908745247148289 +ethan robinson 0.24282982791587 +ethan robinson 0.8003802281368821 +ethan underhill 0.6615969581749049 +ethan van buren 0.8365019011406845 +ethan white 0.48859315589353614 +ethan white 0.5741444866920152 +ethan xylophone 0.9695817490494296 +ethan zipper 0.21102661596958175 +ethan zipper 0.6425855513307985 +fred davidson 0.5239005736137667 +fred davidson 0.7414448669201521 +fred davidson 0.8604206500956023 +fred ellison 0.3977055449330784 +fred ellison 0.5506692160611855 +fred ellison 0.7208413001912046 +fred falkner 0.024714828897338403 +fred falkner 0.19120458891013384 +fred falkner 0.9809885931558935 +fred hernandez 0.2734225621414914 +fred ichabod 0.17110266159695817 +fred ichabod 0.780114722753346 +fred johnson 0.30038022813688214 +fred king 0.2198852772466539 +fred king 0.47718631178707227 +fred laertes 0.2332695984703633 +fred miller 0.7858508604206501 +fred nixon 0.005703422053231939 +fred nixon 0.31749049429657794 +fred nixon 0.7648183556405354 +fred nixon 0.8460076045627376 +fred polk 0.16252390057361377 +fred polk 0.564638783269962 +fred polk 0.6273764258555133 +fred polk 0.8155893536121673 +fred quirinius 0.4866920152091255 +fred quirinius 0.8973384030418251 +fred robinson 0.6387832699619772 +fred steinbeck 0.14722753346080306 +fred steinbeck 0.4627151051625239 +fred steinbeck 0.7265774378585086 +fred underhill 0.35361216730038025 +fred van buren 0.3365200764818356 +fred van buren 0.5057034220532319 +fred van buren 0.6463878326996197 +fred van buren 0.904397705544933 +fred white 0.5171102661596958 +fred young 0.7705544933078394 +fred young 0.7992351816443595 +fred zipper 0.615678776290631 +gabriella allen 0.4435946462715105 +gabriella allen 0.9334600760456274 +gabriella brown 0.4359464627151052 +gabriella brown 0.9636711281070746 +gabriella carson 0.9562737642585551 +gabriella davidson 0.8174904942965779 +gabriella ellison 0.1931166347992352 +gabriella ellison 0.38022813688212925 +gabriella falkner 0.3231939163498099 +gabriella falkner 0.5659655831739961 +gabriella falkner 0.8948374760994264 +gabriella garcia 0.4695817490494297 +gabriella hernandez 0.6444866920152091 +gabriella hernandez 0.7015209125475285 +gabriella ichabod 0.09125475285171103 +gabriella ichabod 0.1520912547528517 +gabriella ichabod 0.1835564053537285 +gabriella ichabod 0.372848948374761 +gabriella ichabod 0.8107074569789675 +gabriella king 0.39961759082217974 +gabriella king 0.5190114068441065 +gabriella laertes 0.4569789674952199 +gabriella miller 0.26996197718631176 +gabriella ovid 0.7091254752851711 +gabriella ovid 0.8897338403041825 +gabriella polk 0.030418250950570342 +gabriella polk 0.44106463878326996 +gabriella steinbeck 0.5755258126195029 +gabriella steinbeck 0.8221797323135756 +gabriella thompson 0.013307984790874524 +gabriella thompson 0.44866920152091255 +gabriella thompson 0.7224334600760456 +gabriella van buren 0.6216730038022814 +gabriella van buren 0.6730038022813688 +gabriella white 0.17208413001912046 +gabriella young 0.5076045627376425 +gabriella young 0.7934990439770554 +gabriella zipper 0.23193916349809887 +gabriella zipper 0.8565965583173997 +holly allen 0.11596958174904944 +holly brown 0.11281070745697896 +holly brown 0.155893536121673 +holly falkner 0.124282982791587 +holly hernandez 0.055449330783938815 +holly hernandez 0.32509505703422054 +holly hernandez 0.97131931166348 +holly hernandez 0.9714828897338403 +holly ichabod 0.12357414448669202 +holly ichabod 0.17300380228136883 +holly ichabod 0.629277566539924 +holly johnson 0.33078393881453155 +holly johnson 0.8612167300380228 +holly johnson 0.9391634980988594 +holly king 0.25475285171102663 +holly king 0.3745247148288973 +holly laertes 0.42775665399239543 +holly miller 0.37476099426386233 +holly nixon 0.10076045627376426 +holly nixon 0.34608030592734224 +holly polk 0.40535372848948376 +holly polk 0.5209125475285171 +holly robinson 0.9273422562141491 +holly thompson 0.1596958174904943 +holly thompson 0.311787072243346 +holly thompson 0.9125475285171103 +holly underhill 0.3479923518164436 +holly underhill 0.5812619502868069 +holly underhill 0.8384030418250951 +holly underhill 0.903041825095057 +holly van buren 0.9464627151051626 +holly white 0.1089866156787763 +holly white 0.4780114722753346 +holly xylophone 0.5304182509505704 +holly young 0.7357414448669202 +holly young 0.8240917782026769 +holly zipper 0.15399239543726237 +holly zipper 0.8546845124282982 +irene allen 0.8738049713193117 +irene brown 0.4588910133843212 +irene brown 0.49619771863117873 +irene brown 0.5678776290630975 +irene carson 0.6844106463878327 +irene ellison 0.32504780114722753 +irene ellison 0.48565965583174 +irene falkner 0.41825095057034223 +irene falkner 0.9866920152091255 +irene garcia 0.11663479923518165 +irene garcia 0.29277566539923955 +irene garcia 0.8126195028680688 +irene ichabod 0.8307984790874525 +irene ichabod 0.9177820267686424 +irene johnson 0.7112810707456979 +irene laertes 0.01338432122370937 +irene laertes 0.1482889733840304 +irene laertes 0.7034220532319392 +irene miller 0.367112810707457 +irene nixon 0.11854684512428298 +irene nixon 0.7927756653992395 +irene nixon 0.9426386233269598 +irene ovid 0.24714828897338403 +irene ovid 0.30210325047801145 +irene ovid 0.779467680608365 +irene polk 0.0038022813688212928 +irene polk 0.45315487571701724 +irene polk 0.6577946768060836 +irene polk 0.8891013384321224 +irene polk 0.9789674952198852 +irene quirinius 0.27533460803059273 +irene quirinius 0.35946462715105165 +irene quirinius 0.384321223709369 +irene robinson 0.18631178707224336 +irene steinbeck 0.9942965779467681 +irene thompson 0.6939163498098859 +irene underhill 0.30401529636711283 +irene underhill 0.3403041825095057 +irene van buren 0.5908221797323135 +irene van buren 0.6634980988593155 +irene xylophone 0.5342205323193916 +jessica brown 0.7680608365019012 +jessica carson 0.3574144486692015 +jessica carson 0.6195028680688337 +jessica carson 0.8269961977186312 +jessica davidson 0.10646387832699619 +jessica davidson 0.34790874524714827 +jessica davidson 0.3593155893536122 +jessica davidson 0.6768060836501901 +jessica ellison 0.0779467680608365 +jessica ellison 0.42015209125475284 +jessica falkner 0.994263862332696 +jessica garcia 0.8279158699808795 +jessica garcia 0.9581749049429658 +jessica ichabod 0.45627376425855515 +jessica johnson 0.30228136882129275 +jessica johnson 0.8049713193116634 +jessica miller 0.8011472275334608 +jessica nixon 0.06500956022944551 +jessica nixon 0.6042065009560229 +jessica ovid 0.15105162523900573 +jessica ovid 0.8992395437262357 +jessica polk 0.4378585086042065 +jessica quirinius 0.058935361216730035 +jessica quirinius 0.4714828897338403 +jessica quirinius 0.5760456273764258 +jessica quirinius 0.8935361216730038 +jessica robinson 0.9638783269961977 +jessica thompson 0.08221797323135756 +jessica thompson 0.5893536121673004 +jessica underhill 0.034220532319391636 +jessica underhill 0.06118546845124283 +jessica underhill 0.9541108986615678 +jessica van buren 0.20650095602294455 +jessica white 0.06273764258555133 +jessica white 0.4149139579349904 +jessica white 0.5798479087452472 +jessica white 0.591254752851711 +jessica white 0.7667304015296367 +jessica xylophone 0.5009560229445507 +jessica young 0.3403441682600382 +jessica young 0.8821292775665399 +jessica zipper 0.14068441064638784 +jessica zipper 0.2984790874524715 +jessica zipper 0.6007604562737643 +katie allen 0.5665399239543726 +katie brown 0.49521988527724664 +katie davidson 0.6730401529636711 +katie ellison 0.3173996175908222 +katie ellison 0.7262357414448669 +katie falkner 0.2676864244741874 +katie garcia 0.049429657794676805 +katie garcia 0.3135755258126195 +katie hernandez 0.6026615969581749 +katie ichabod 0.15296367112810708 +katie ichabod 0.4684512428298279 +katie ichabod 0.7055449330783938 +katie king 0.16159695817490494 +katie king 0.502868068833652 +katie king 0.5927342256214149 +katie miller 0.5228136882129277 +katie miller 0.5296367112810707 +katie nixon 0.7832699619771863 +katie ovid 0.8795411089866156 +katie polk 0.35372848948374763 +katie polk 0.9657794676806084 +katie robinson 0.06844106463878327 +katie van buren 0.06883365200764818 +katie van buren 0.1739961759082218 +katie white 0.045889101338432124 +katie white 0.18546845124282982 +katie xylophone 0.7281368821292775 +katie young 0.16443594646271512 +katie young 0.20152091254752852 +katie young 0.9732313575525813 +katie zipper 0.21863117870722434 +katie zipper 0.4505703422053232 +luke allen 0.03612167300380228 +luke allen 0.21606118546845124 +luke allen 0.8346007604562737 +luke allen 0.8631178707224335 +luke allen 0.9311663479923518 +luke brown 0.7304015296367112 +luke davidson 0.25239005736137665 +luke davidson 0.9961977186311787 +luke ellison 0.1147227533460803 +luke ellison 0.2447418738049713 +luke ellison 0.49809885931558934 +luke falkner 0.24524714828897337 +luke falkner 0.5124282982791587 +luke garcia 0.03441682600382409 +luke garcia 0.32695984703632885 +luke ichabod 0.10266159695817491 +luke ichabod 0.5551330798479087 +luke johnson 0.25430210325047803 +luke johnson 0.6787762906309751 +luke johnson 0.9082217973231358 +luke laertes 0.06309751434034416 +luke laertes 0.3690248565965583 +luke laertes 0.7743785850860421 +luke laertes 0.8079847908745247 +luke laertes 0.811787072243346 +luke miller 0.8068833652007649 +luke ovid 0.435361216730038 +luke ovid 0.7547528517110266 +luke polk 0.13957934990439771 +luke polk 0.9770554493307839 +luke quirinius 0.09315589353612168 +luke robinson 0.015209125475285171 +luke robinson 0.053231939163498096 +luke thompson 0.8840304182509505 +luke underhill 0.08745247148288973 +luke underhill 0.40152963671128106 +luke underhill 0.4608030592734226 +luke van buren 0.4847908745247148 +luke white 0.8098859315589354 +luke xylophone 0.34220532319391633 +luke zipper 0.21292775665399238 +mike allen 0.7036328871892925 +mike brown 0.29063097514340347 +mike carson 0.623574144486692 +mike carson 0.7476099426386233 +mike carson 0.9885931558935361 +mike davidson 0.6520912547528517 +mike davidson 0.8298279158699808 +mike ellison 0.24665391969407266 +mike ellison 0.3821292775665399 +mike ellison 0.8355640535372849 +mike ellison 0.8986615678776291 +mike ellison 0.94106463878327 +mike falkner 0.0248565965583174 +mike garcia 0.39543726235741444 +mike garcia 0.5391969407265774 +mike garcia 0.6482889733840305 +mike hernandez 0.07984790874524715 +mike hernandez 0.7186311787072244 +mike ichabod 0.7642585551330798 +mike king 0.09695817490494296 +mike king 0.188212927756654 +mike king 0.4049429657794677 +mike king 0.5544933078393881 +mike king 0.6045627376425855 +mike king 0.9011406844106464 +mike miller 0.621414913957935 +mike nixon 0.688212927756654 +mike nixon 0.9068441064638784 +mike polk 0.3612167300380228 +mike polk 0.6749521988527725 +mike polk 0.8374760994263862 +mike quirinius 0.5105162523900574 +mike steinbeck 0.05736137667304015 +mike steinbeck 0.747148288973384 +mike steinbeck 0.8745247148288974 +mike steinbeck 0.9330783938814532 +mike van buren 0.8650190114068441 +mike van buren 0.973384030418251 +mike white 0.17782026768642448 +mike white 0.7151051625239006 +mike white 0.7566539923954373 +mike white 0.9808795411089866 +mike young 0.20722433460076045 +mike young 0.3840304182509506 +mike young 0.6405353728489483 +mike zipper 0.12810707456978968 +mike zipper 0.42829827915869984 +mike zipper 0.7946768060836502 +nick allen 0.021032504780114723 +nick allen 0.847036328871893 +nick brown 0.14258555133079848 +nick davidson 0.26003824091778205 +nick ellison 0.028680688336520075 +nick ellison 0.3935361216730038 +nick falkner 0.5684410646387833 +nick falkner 0.7590822179732314 +nick garcia 0.34980988593155893 +nick garcia 0.45817490494296575 +nick garcia 0.892925430210325 +nick ichabod 0.2944550669216061 +nick ichabod 0.37667304015296366 +nick ichabod 0.7074569789674953 +nick johnson 0.3973384030418251 +nick johnson 0.4646271510516252 +nick laertes 0.36311787072243346 +nick miller 0.9961759082217974 +nick nixon 0.7110266159695817 +nick ovid 0.7762906309751434 +nick polk 1.0 +nick quirinius 0.0019120458891013384 +nick quirinius 0.08795411089866156 +nick robinson 0.09505703422053231 +nick robinson 0.45506692160611856 +nick steinbeck 0.2224334600760456 +nick thompson 0.4225621414913958 +nick underhill 0.9101338432122371 +nick van buren 0.03802281368821293 +nick xylophone 0.6806883365200764 +nick young 0.4220532319391635 +nick young 0.8623326959847036 +nick zipper 0.2829827915869981 +nick zipper 0.5468451242829828 +oscar allen 0.785171102661597 +oscar brown 0.13498098859315588 +oscar carson 0.07224334600760456 +oscar carson 0.25665399239543724 +oscar carson 0.3422562141491396 +oscar carson 0.6061185468451242 +oscar carson 0.6826003824091779 +oscar davidson 0.7129277566539924 +oscar ellison 0.036328871892925434 +oscar ellison 0.5831739961759083 +oscar falkner 0.9049429657794676 +oscar garcia 0.02676864244741874 +oscar hernandez 0.20076481835564053 +oscar hernandez 0.7870722433460076 +oscar ichabod 0.12619502868068833 +oscar ichabod 0.14149139579349904 +oscar ichabod 0.4416826003824092 +oscar ichabod 0.8661567877629063 +oscar johnson 0.1806083650190114 +oscar johnson 0.467680608365019 +oscar king 0.6596958174904943 +oscar king 0.6787072243346007 +oscar king 0.9258555133079848 +oscar laertes 0.24904942965779467 +oscar laertes 0.5315487571701721 +oscar laertes 0.6328871892925431 +oscar laertes 0.9980988593155894 +oscar nixon 0.9292543021032504 +oscar ovid 0.43021032504780116 +oscar ovid 0.8288973384030418 +oscar ovid 0.8527724665391969 +oscar polk 0.10836501901140684 +oscar polk 0.37858508604206503 +oscar quirinius 0.3041825095057034 +oscar quirinius 0.46387832699619774 +oscar quirinius 0.6311787072243346 +oscar quirinius 0.8555133079847909 +oscar robinson 0.11216730038022814 +oscar robinson 0.22433460076045628 +oscar robinson 0.2294455066921606 +oscar robinson 0.2390057361376673 +oscar steinbeck 0.9904942965779467 +oscar thompson 0.015296367112810707 +oscar thompson 0.2946768060836502 +oscar thompson 0.3060836501901141 +oscar thompson 0.6140684410646388 +oscar underhill 0.31368821292775667 +oscar van buren 0.722753346080306 +oscar van buren 0.7889733840304183 +oscar van buren 0.8833652007648184 +oscar white 0.055133079847908745 +oscar white 0.22562141491395793 +oscar white 0.4321223709369025 +oscar white 0.6443594646271511 +oscar xylophone 0.10133843212237094 +oscar xylophone 0.4187380497131931 +oscar xylophone 0.4296577946768061 +oscar zipper 0.6233269598470363 +oscar zipper 0.7490494296577946 +oscar zipper 0.8783269961977186 +priscilla brown 0.2925430210325048 +priscilla brown 0.6501901140684411 +priscilla brown 0.9120458891013384 +priscilla carson 0.22753346080305928 +priscilla carson 0.5564053537284895 +priscilla carson 0.7820267686424475 +priscilla ichabod 0.3269961977186312 +priscilla ichabod 0.9828897338403042 +priscilla johnson 0.04206500956022945 +priscilla johnson 0.4011406844106464 +priscilla johnson 0.6368821292775665 +priscilla johnson 0.7131931166347992 +priscilla johnson 0.9429657794676806 +priscilla king 0.3517110266159696 +priscilla nixon 0.38049713193116635 +priscilla nixon 0.6864244741873805 +priscilla ovid 0.8193916349809885 +priscilla ovid 0.9139579349904398 +priscilla polk 0.5697896749521989 +priscilla quirinius 0.22179732313575526 +priscilla thompson 0.7737642585551331 +priscilla underhill 0.1682600382409178 +priscilla underhill 0.8852772466539197 +priscilla van buren 0.10325047801147227 +priscilla van buren 0.7877629063097514 +priscilla van buren 0.9598470363288719 +priscilla white 0.4894837476099426 +priscilla xylophone 0.596958174904943 +priscilla xylophone 0.6159695817490495 +priscilla xylophone 0.8393881453154876 +priscilla young 0.41064638783269963 +priscilla young 0.9182509505703422 +priscilla zipper 0.5247148288973384 +priscilla zipper 0.8574144486692015 +quinn allen 0.1634980988593156 +quinn allen 0.9617590822179732 +quinn brown 0.08986615678776291 +quinn brown 0.17590822179732313 +quinn brown 0.5836501901140685 +quinn davidson 0.11787072243346007 +quinn davidson 0.30592734225621415 +quinn davidson 0.3650190114068441 +quinn davidson 0.751434034416826 +quinn ellison 0.376425855513308 +quinn ellison 0.8517110266159695 +quinn garcia 0.17870722433460076 +quinn garcia 0.7323135755258127 +quinn garcia 0.844106463878327 +quinn garcia 0.9486692015209125 +quinn ichabod 0.42395437262357416 +quinn king 0.6653919694072657 +quinn king 0.9505703422053232 +quinn laertes 0.6080305927342257 +quinn laertes 0.9277566539923955 +quinn laertes 0.9847908745247148 +quinn nixon 0.5133079847908745 +quinn ovid 0.16539923954372623 +quinn quirinius 0.19011406844106463 +quinn robinson 0.27756653992395436 +quinn steinbeck 0.23954372623574144 +quinn steinbeck 0.6367112810707457 +quinn thompson 0.4068441064638783 +quinn thompson 0.7782026768642447 +quinn underhill 0.05353728489483748 +quinn underhill 0.5380228136882129 +quinn underhill 0.9349904397705545 +quinn van buren 0.2623574144486692 +quinn young 0.37832699619771865 +quinn zipper 0.51434034416826 +quinn zipper 0.8859315589353612 +rachel allen 0.1701720841300191 +rachel allen 0.3288973384030418 +rachel brown 0.04780114722753346 +rachel brown 0.057034220532319393 +rachel brown 0.532319391634981 +rachel brown 0.5946462715105163 +rachel brown 0.6064638783269962 +rachel carson 0.09177820267686425 +rachel carson 0.6406844106463878 +rachel davidson 0.37093690248565964 +rachel ellison 0.5162523900573613 +rachel falkner 0.1958174904942966 +rachel falkner 0.6330798479087453 +rachel falkner 0.6768642447418738 +rachel falkner 0.9751434034416826 +rachel johnson 0.9560229445506692 +rachel king 0.12547528517110265 +rachel king 0.6003824091778203 +rachel laertes 0.2638623326959847 +rachel laertes 0.5779467680608364 +rachel ovid 0.23518164435946462 +rachel ovid 0.7053231939163498 +rachel polk 0.14638783269961977 +rachel quirinius 0.0076045627376425855 +rachel robinson 0.14340344168260039 +rachel robinson 0.2084130019120459 +rachel robinson 0.6902485659655831 +rachel thompson 0.2718631178707224 +rachel thompson 0.5334608030592735 +rachel thompson 0.875717017208413 +rachel underhill 0.344106463878327 +rachel white 0.17973231357552583 +rachel white 0.6615678776290631 +rachel young 0.3862332695984704 +rachel zipper 0.33079847908745247 +rachel zipper 0.5717017208413002 +sarah carson 0.08935361216730038 +sarah carson 0.18250950570342206 +sarah carson 0.8041825095057035 +sarah ellison 0.967680608365019 +sarah falkner 0.6252390057361377 +sarah falkner 0.9024856596558317 +sarah garcia 0.3881453154875717 +sarah garcia 0.4072657743785851 +sarah garcia 0.6673040152963671 +sarah ichabod 0.29636711281070743 +sarah ichabod 0.9483747609942639 +sarah johnson 0.06463878326996197 +sarah johnson 0.10456273764258556 +sarah johnson 0.5640535372848948 +sarah johnson 0.7954110898661568 +sarah king 0.8030592734225621 +sarah king 0.9655831739961759 +sarah miller 0.6692160611854685 +sarah ovid 0.20342205323193915 +sarah robinson 0.47338403041825095 +sarah robinson 0.7775665399239544 +sarah steinbeck 0.6520076481835564 +sarah white 0.28517110266159695 +sarah white 0.8479087452471483 +sarah xylophone 0.25621414913957935 +sarah young 0.5570342205323194 +sarah zipper 0.5583173996175909 +tom brown 0.5602294455066922 +tom brown 0.8669201520912547 +tom carson 0.045627376425855515 +tom carson 0.35551330798479086 +tom carson 0.935361216730038 +tom davidson 0.8212927756653993 +tom ellison 0.21032504780114722 +tom ellison 0.26195028680688337 +tom ellison 0.7376425855513308 +tom falkner 0.3441682600382409 +tom falkner 0.6481835564053537 +tom hernandez 0.0038240917782026767 +tom hernandez 0.5399239543726235 +tom ichabod 0.6137667304015296 +tom johnson 0.5525812619502868 +tom johnson 0.7915869980879541 +tom king 0.16061185468451242 +tom laertes 0.0745697896749522 +tom laertes 0.5095057034220533 +tom miller 0.2262357414448669 +tom miller 0.2338403041825095 +tom miller 0.2813688212927757 +tom nixon 0.8451242829827916 +tom ovid 0.864244741873805 +tom polk 0.1491395793499044 +tom polk 0.9521988527724665 +tom quirinius 0.09369024856596558 +tom quirinius 0.8489483747609943 +tom robinson 0.060836501901140684 +tom robinson 0.6254752851711026 +tom robinson 0.6462715105162524 +tom robinson 0.9980879541108987 +tom steinbeck 0.5817490494296578 +tom van buren 0.12737642585551331 +tom van buren 0.3154875717017208 +tom van buren 0.7585551330798479 +tom white 0.47609942638623326 +tom young 0.9369024856596558 +tom young 0.9543726235741445 +tom zipper 0.9063097514340345 +ulysses brown 0.9448669201520913 +ulysses carson 0.07034220532319392 +ulysses carson 0.09885931558935361 +ulysses carson 0.2414448669201521 +ulysses carson 0.7604562737642585 +ulysses davidson 0.7093690248565966 +ulysses ellison 0.55893536121673 +ulysses garcia 0.7246653919694073 +ulysses hernandez 0.4091778202676864 +ulysses hernandez 0.627151051625239 +ulysses hernandez 0.982791586998088 +ulysses ichabod 0.21482889733840305 +ulysses ichabod 0.3193916349809886 +ulysses johnson 0.5621414913957935 +ulysses king 0.9467680608365019 +ulysses laertes 0.390057361376673 +ulysses laertes 0.7973231357552581 +ulysses laertes 0.9866156787762906 +ulysses miller 0.31166347992351817 +ulysses miller 0.5774378585086042 +ulysses nixon 0.0057361376673040155 +ulysses ovid 0.38593155893536124 +ulysses polk 0.04752851711026616 +ulysses polk 0.6083650190114068 +ulysses polk 0.7609942638623327 +ulysses polk 0.8326996197718631 +ulysses quirinius 0.6290630975143403 +ulysses robinson 0.9235181644359465 +ulysses steinbeck 0.039923954372623575 +ulysses steinbeck 0.7724665391969407 +ulysses thompson 0.3824091778202677 +ulysses underhill 0.11406844106463879 +ulysses underhill 0.23574144486692014 +ulysses underhill 0.3365019011406844 +ulysses underhill 0.42585551330798477 +ulysses underhill 0.6102661596958175 +ulysses underhill 0.6959847036328872 +ulysses underhill 0.9752851711026616 +ulysses van buren 0.5437262357414449 +ulysses white 0.5 +ulysses white 0.5931558935361216 +ulysses xylophone 0.5855513307984791 +ulysses xylophone 0.8317399617590823 +ulysses xylophone 0.9005736137667304 +ulysses young 0.18164435946462715 +ulysses young 0.3919694072657744 +ulysses young 0.49049429657794674 +victor allen 0.13575525812619502 +victor allen 0.6309751434034416 +victor brown 0.0497131931166348 +victor brown 0.20267686424474188 +victor brown 0.6178707224334601 +victor brown 0.8910133843212237 +victor davidson 0.026615969581749048 +victor davidson 0.491395793499044 +victor davidson 0.5850860420650096 +victor ellison 0.26425855513307983 +victor ellison 0.6692015209125475 +victor hernandez 0.04397705544933078 +victor hernandez 0.12927756653992395 +victor hernandez 0.1950286806883365 +victor hernandez 0.5411089866156787 +victor hernandez 0.7284894837476099 +victor johnson 0.11977186311787072 +victor johnson 0.4828897338403042 +victor johnson 0.7699619771863118 +victor king 0.41254752851711024 +victor king 0.714828897338403 +victor laertes 0.43155893536121676 +victor laertes 0.6500956022944551 +victor miller 0.4429657794676806 +victor nixon 0.33269961977186313 +victor nixon 0.5258126195028681 +victor ovid 0.22813688212927757 +victor polk 0.13878326996197718 +victor quirinius 0.13766730401529637 +victor quirinius 0.887189292543021 +victor robinson 0.5494296577946768 +victor robinson 0.7509505703422054 +victor steinbeck 0.08365019011406843 +victor steinbeck 0.15487571701720843 +victor steinbeck 0.3669201520912547 +victor thompson 0.10516252390057361 +victor van buren 0.27724665391969405 +victor van buren 0.9579349904397706 +victor white 0.41634980988593157 +victor white 0.6349809885931559 +victor xylophone 0.13688212927756654 +victor xylophone 0.3078393881453155 +victor xylophone 0.4110898661567878 +victor xylophone 0.5449330783938815 +victor xylophone 0.9296577946768061 +victor young 0.18738049713193117 +victor zipper 0.5430210325047801 +wendy allen 0.3231357552581262 +wendy allen 0.734225621414914 +wendy allen 0.869980879541109 +wendy brown 0.18929254302103252 +wendy brown 0.6996197718631179 +wendy ellison 0.7437858508604207 +wendy ellison 0.8498098859315589 +wendy falkner 0.07648183556405354 +wendy falkner 0.5353728489483748 +wendy falkner 0.7756653992395437 +wendy garcia 0.07074569789674952 +wendy garcia 0.0741444866920152 +wendy garcia 0.33840304182509506 +wendy garcia 0.38783269961977185 +wendy hernandez 0.017110266159695818 +wendy ichabod 0.8718929254302104 +wendy king 0.37072243346007605 +wendy king 0.497131931166348 +wendy king 0.5965583173996176 +wendy laertes 0.32122370936902483 +wendy laertes 0.49904397705544934 +wendy laertes 0.876425855513308 +wendy miller 0.7533460803059273 +wendy miller 0.7552581261950286 +wendy nixon 0.44933078393881454 +wendy nixon 0.7661596958174905 +wendy ovid 0.5019011406844106 +wendy ovid 0.6978967495219885 +wendy polk 0.3688212927756654 +wendy polk 0.526615969581749 +wendy quirinius 0.1444866920152091 +wendy quirinius 0.5874524714828897 +wendy robinson 0.030592734225621414 +wendy robinson 0.06692160611854685 +wendy robinson 0.27566539923954375 +wendy steinbeck 0.5703422053231939 +wendy thompson 0.028517110266159697 +wendy thompson 0.11089866156787763 +wendy underhill 0.4837476099426386 +wendy underhill 0.6424474187380497 +wendy underhill 0.9600760456273765 +wendy van buren 0.1920152091254753 +wendy van buren 0.7433460076045627 +wendy white 0.752851711026616 +wendy xylophone 0.6347992351816444 +wendy xylophone 0.7452471482889734 +wendy young 0.07839388145315487 +wendy young 0.3897338403041825 +xavier allen 0.043726235741444866 +xavier allen 0.361376673040153 +xavier allen 0.5456273764258555 +xavier brown 0.6711281070745698 +xavier brown 0.9158699808795411 +xavier brown 0.9847036328871893 +xavier carson 0.0841300191204589 +xavier carson 0.988527724665392 +xavier davidson 0.2585551330798479 +xavier davidson 0.4168260038240918 +xavier davidson 0.609942638623327 +xavier ellison 0.5984703632887189 +xavier ellison 0.7361376673040153 +xavier garcia 0.7017208413001912 +xavier hernandez 0.2509505703422053 +xavier hernandez 0.34990439770554493 +xavier hernandez 0.9220532319391636 +xavier ichabod 0.5475285171102662 +xavier ichabod 0.858508604206501 +xavier johnson 0.3938814531548757 +xavier johnson 0.8231939163498099 +xavier king 0.03231939163498099 +xavier king 0.6539196940726577 +xavier laertes 0.5988593155893536 +xavier ovid 0.4397705544933078 +xavier polk 0.4933078393881453 +xavier polk 0.762906309751434 +xavier polk 0.8136882129277566 +xavier polk 0.8260038240917782 +xavier quirinius 0.07265774378585087 +xavier quirinius 0.27915869980879543 +xavier quirinius 0.34600760456273766 +xavier quirinius 0.8022813688212928 +xavier thompson 0.6118546845124283 +xavier underhill 0.16634799235181644 +xavier white 0.6958174904942965 +xavier white 0.7380497131931166 +xavier xylophone 0.8183556405353728 +xavier zipper 0.9904397705544933 +yuri allen 0.9106463878326996 +yuri allen 1.0 +yuri brown 0.5152091254752852 +yuri brown 0.908745247148289 +yuri carson 0.09560229445506692 +yuri carson 0.9372623574144486 +yuri ellison 0.017208413001912046 +yuri ellison 0.39923954372623577 +yuri falkner 0.28680688336520077 +yuri falkner 0.8967495219885278 +yuri garcia 0.2661596958174905 +yuri hernandez 0.28489483747609945 +yuri johnson 0.5047801147227533 +yuri johnson 0.655831739961759 +yuri johnson 0.720532319391635 +yuri king 0.32129277566539927 +yuri laertes 0.4144486692015209 +yuri laertes 0.8916349809885932 +yuri nixon 0.05162523900573614 +yuri nixon 0.40874524714828897 +yuri polk 0.051330798479087454 +yuri polk 0.39579349904397704 +yuri polk 0.6749049429657795 +yuri quirinius 0.08030592734225621 +yuri quirinius 0.2982791586998088 +yuri quirinius 0.4130019120458891 +yuri steinbeck 0.15779467680608364 +yuri steinbeck 0.9388145315487572 +yuri thompson 0.6175908221797323 +yuri underhill 0.42447418738049714 +yuri underhill 0.8202676864244742 +yuri white 0.19694072657743786 +yuri xylophone 0.4790874524714829 +zach allen 0.8250950570342205 +zach brown 0.0817490494296578 +zach brown 0.09751434034416825 +zach brown 0.248565965583174 +zach brown 0.2965779467680608 +zach brown 0.4524714828897338 +zach carson 0.6921606118546845 +zach ellison 0.6806083650190115 +zach falkner 0.25812619502868067 +zach falkner 0.2695984703632887 +zach garcia 0.30798479087452474 +zach garcia 0.3632887189292543 +zach garcia 0.7072243346007605 +zach garcia 0.7167300380228137 +zach ichabod 0.30988593155893535 +zach ichabod 0.9502868068833652 +zach king 0.5277246653919694 +zach king 0.8336520076481836 +zach king 0.9239543726235742 +zach miller 0.15678776290630975 +zach miller 0.3726235741444867 +zach miller 0.5608365019011406 +zach ovid 0.1311787072243346 +zach ovid 0.2737642585551331 +zach ovid 0.4543726235741445 +zach ovid 0.6711026615969582 +zach quirinius 0.019011406844106463 +zach robinson 0.11026615969581749 +zach steinbeck 0.28107074569789675 +zach steinbeck 0.7170172084130019 +zach thompson 0.13001912045889102 +zach thompson 0.44550669216061184 +zach underhill 0.7718631178707225 +zach white 0.7965779467680608 +zach xylophone 0.032504780114722756 +zach xylophone 0.638623326959847 +zach young 0.009560229445506692 +zach zipper 0.24334600760456274 +zach zipper 0.2832699619771863 +zach zipper 0.8087954110898662 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-3-b82dfa24123047be4b4e3c27c3997d34 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-3-b82dfa24123047be4b4e3c27c3997d34 new file mode 100644 index 000000000000..1e0cf03db63a --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 1-3-b82dfa24123047be4b4e3c27c3997d34 @@ -0,0 +1,1049 @@ +zach zipper 0.0 +zach zipper 0.0 +zach zipper 0.0 +zach young 0.0 +zach xylophone 0.0 +zach xylophone 0.0 +zach white 0.0 +zach underhill 0.0 +zach thompson 0.0 +zach thompson 0.0 +zach steinbeck 0.0 +zach steinbeck 0.0 +zach robinson 0.0 +zach quirinius 0.0 +zach ovid 0.0 +zach ovid 0.0 +zach ovid 0.0 +zach ovid 0.0 +zach miller 0.0 +zach miller 0.0 +zach miller 0.0 +zach king 0.0 +zach king 0.0 +zach king 0.0 +zach ichabod 0.0 +zach ichabod 0.0 +zach garcia 0.0 +zach garcia 0.0 +zach garcia 0.0 +zach garcia 0.0 +zach falkner 0.0 +zach falkner 0.0 +zach ellison 0.0 +zach carson 0.0 +zach brown 0.0 +zach brown 0.0 +zach brown 0.0 +zach brown 0.0 +zach brown 0.0 +zach allen 0.0 +yuri xylophone 0.0 +yuri white 0.0 +yuri underhill 0.0 +yuri underhill 0.0 +yuri thompson 0.0 +yuri steinbeck 0.0 +yuri steinbeck 0.0 +yuri quirinius 0.0 +yuri quirinius 0.0 +yuri quirinius 0.0 +yuri polk 0.0 +yuri polk 0.0 +yuri polk 0.0 +yuri nixon 0.0 +yuri nixon 0.0 +yuri laertes 0.0 +yuri laertes 0.0 +yuri king 0.0 +yuri johnson 0.0 +yuri johnson 0.0 +yuri johnson 0.0 +yuri hernandez 0.0 +yuri garcia 0.0 +yuri falkner 0.0 +yuri falkner 0.0 +yuri ellison 0.0 +yuri ellison 0.0 +yuri carson 0.0 +yuri carson 0.0 +yuri brown 0.0 +yuri brown 0.0 +yuri allen 0.0 +yuri allen 0.0 +xavier zipper 1.0 +xavier xylophone 0.0 +xavier white 0.0 +xavier white 0.0 +xavier underhill 0.0 +xavier thompson 0.0 +xavier quirinius 0.0 +xavier quirinius 0.0 +xavier quirinius 0.0 +xavier quirinius 0.0 +xavier polk 0.0 +xavier polk 0.0 +xavier polk 0.0 +xavier polk 0.0 +xavier ovid 0.0 +xavier laertes 0.0 +xavier king 0.0 +xavier king 0.0 +xavier johnson 0.0 +xavier johnson 0.0 +xavier ichabod 0.0 +xavier ichabod 0.0 +xavier hernandez 0.0 +xavier hernandez 0.0 +xavier hernandez 0.0 +xavier garcia 0.0 +xavier ellison 0.0 +xavier ellison 0.0 +xavier davidson 0.0 +xavier davidson 0.0 +xavier davidson 0.0 +xavier carson 1.0 +xavier carson 0.0 +xavier brown 0.0 +xavier brown 0.0 +xavier brown 0.0 +xavier allen 0.0 +xavier allen 0.0 +xavier allen 0.0 +wendy young 1.0 +wendy young 0.0 +wendy xylophone 0.0 +wendy xylophone 0.0 +wendy white 0.0 +wendy van buren 0.0 +wendy van buren 0.0 +wendy underhill 0.0 +wendy underhill 0.0 +wendy underhill 0.0 +wendy thompson 0.0 +wendy thompson 0.0 +wendy steinbeck 0.0 +wendy robinson 0.0 +wendy robinson 0.0 +wendy robinson 0.0 +wendy quirinius 0.0 +wendy quirinius 0.0 +wendy polk 0.0 +wendy polk 0.0 +wendy ovid 0.0 +wendy ovid 0.0 +wendy nixon 0.0 +wendy nixon 0.0 +wendy miller 0.0 +wendy miller 0.0 +wendy laertes 0.0 +wendy laertes 0.0 +wendy laertes 0.0 +wendy king 0.0 +wendy king 0.0 +wendy king 0.0 +wendy ichabod 0.0 +wendy hernandez 0.0 +wendy garcia 0.0 +wendy garcia 0.0 +wendy garcia 0.0 +wendy garcia 0.0 +wendy falkner 0.0 +wendy falkner 0.0 +wendy falkner 0.0 +wendy ellison 0.0 +wendy ellison 0.0 +wendy brown 0.0 +wendy brown 0.0 +wendy allen 0.0 +wendy allen 0.0 +wendy allen 0.0 +victor zipper 0.0 +victor young 0.0 +victor xylophone 0.0 +victor xylophone 0.0 +victor xylophone 0.0 +victor xylophone 0.0 +victor xylophone 0.0 +victor white 1.0 +victor white 0.0 +victor van buren 0.0 +victor van buren 0.0 +victor thompson 0.0 +victor steinbeck 0.0 +victor steinbeck 0.0 +victor steinbeck 0.0 +victor robinson 0.0 +victor robinson 0.0 +victor quirinius 0.0 +victor quirinius 0.0 +victor polk 0.0 +victor ovid 0.0 +victor nixon 0.0 +victor nixon 0.0 +victor miller 0.0 +victor laertes 0.0 +victor laertes 0.0 +victor king 0.0 +victor king 0.0 +victor johnson 0.0 +victor johnson 0.0 +victor johnson 0.0 +victor hernandez 0.0 +victor hernandez 0.0 +victor hernandez 0.0 +victor hernandez 0.0 +victor hernandez 0.0 +victor ellison 0.0 +victor ellison 0.0 +victor davidson 0.0 +victor davidson 0.0 +victor davidson 0.0 +victor brown 0.0 +victor brown 0.0 +victor brown 0.0 +victor brown 0.0 +victor allen 0.0 +victor allen 0.0 +ulysses young 0.0 +ulysses young 0.0 +ulysses young 0.0 +ulysses xylophone 0.0 +ulysses xylophone 0.0 +ulysses xylophone 0.0 +ulysses white 0.0 +ulysses white 0.0 +ulysses van buren 0.0 +ulysses underhill 0.0 +ulysses underhill 0.0 +ulysses underhill 0.0 +ulysses underhill 0.0 +ulysses underhill 0.0 +ulysses underhill 0.0 +ulysses underhill 0.0 +ulysses thompson 1.0 +ulysses steinbeck 0.0 +ulysses steinbeck 0.0 +ulysses robinson 0.0 +ulysses quirinius 0.0 +ulysses polk 1.0 +ulysses polk 0.0 +ulysses polk 0.0 +ulysses polk 0.0 +ulysses ovid 0.0 +ulysses nixon 0.0 +ulysses miller 0.0 +ulysses miller 0.0 +ulysses laertes 0.0 +ulysses laertes 0.0 +ulysses laertes 0.0 +ulysses king 0.0 +ulysses johnson 0.0 +ulysses ichabod 0.0 +ulysses ichabod 0.0 +ulysses hernandez 1.0 +ulysses hernandez 0.0 +ulysses hernandez 0.0 +ulysses garcia 0.0 +ulysses ellison 1.0 +ulysses davidson 0.0 +ulysses carson 0.0 +ulysses carson 0.0 +ulysses carson 0.0 +ulysses carson 0.0 +ulysses brown 0.0 +tom zipper 0.0 +tom young 0.0 +tom young 0.0 +tom white 0.0 +tom van buren 0.0 +tom van buren 0.0 +tom van buren 0.0 +tom steinbeck 0.0 +tom robinson 0.0 +tom robinson 0.0 +tom robinson 0.0 +tom robinson 0.0 +tom quirinius 0.0 +tom quirinius 0.0 +tom polk 0.0 +tom polk 0.0 +tom ovid 0.0 +tom nixon 0.0 +tom miller 0.0 +tom miller 0.0 +tom miller 0.0 +tom laertes 0.0 +tom laertes 0.0 +tom king 0.0 +tom johnson 0.0 +tom johnson 0.0 +tom ichabod 0.0 +tom hernandez 0.0 +tom hernandez 0.0 +tom falkner 0.0 +tom falkner 0.0 +tom ellison 0.0 +tom ellison 0.0 +tom ellison 0.0 +tom davidson 0.0 +tom carson 0.0 +tom carson 0.0 +tom carson 0.0 +tom brown 0.0 +tom brown 0.0 +sarah zipper 0.0 +sarah young 0.0 +sarah xylophone 0.0 +sarah white 0.0 +sarah white 0.0 +sarah steinbeck 0.0 +sarah robinson 0.0 +sarah robinson 0.0 +sarah ovid 0.0 +sarah miller 0.0 +sarah king 0.0 +sarah king 0.0 +sarah johnson 0.0 +sarah johnson 0.0 +sarah johnson 0.0 +sarah johnson 0.0 +sarah ichabod 0.0 +sarah ichabod 0.0 +sarah garcia 0.0 +sarah garcia 0.0 +sarah garcia 0.0 +sarah falkner 0.0 +sarah falkner 0.0 +sarah ellison 0.0 +sarah carson 0.0 +sarah carson 0.0 +sarah carson 0.0 +rachel zipper 0.0 +rachel zipper 0.0 +rachel young 0.0 +rachel white 0.0 +rachel white 0.0 +rachel underhill 0.0 +rachel thompson 0.0 +rachel thompson 0.0 +rachel thompson 0.0 +rachel robinson 1.0 +rachel robinson 0.0 +rachel robinson 0.0 +rachel quirinius 0.0 +rachel polk 0.0 +rachel ovid 0.0 +rachel ovid 0.0 +rachel laertes 0.0 +rachel laertes 0.0 +rachel king 0.0 +rachel king 0.0 +rachel johnson 0.0 +rachel falkner 0.0 +rachel falkner 0.0 +rachel falkner 0.0 +rachel falkner 0.0 +rachel ellison 0.0 +rachel davidson 0.0 +rachel carson 0.0 +rachel carson 0.0 +rachel brown 1.0 +rachel brown 0.0 +rachel brown 0.0 +rachel brown 0.0 +rachel brown 0.0 +rachel allen 0.0 +rachel allen 0.0 +quinn zipper 0.0 +quinn zipper 0.0 +quinn young 0.0 +quinn van buren 0.0 +quinn underhill 0.0 +quinn underhill 0.0 +quinn underhill 0.0 +quinn thompson 0.0 +quinn thompson 0.0 +quinn steinbeck 0.0 +quinn steinbeck 0.0 +quinn robinson 0.0 +quinn quirinius 0.0 +quinn ovid 0.0 +quinn nixon 0.0 +quinn laertes 1.0 +quinn laertes 0.0 +quinn laertes 0.0 +quinn king 1.0 +quinn king 0.0 +quinn ichabod 0.0 +quinn garcia 1.0 +quinn garcia 0.0 +quinn garcia 0.0 +quinn garcia 0.0 +quinn ellison 0.0 +quinn ellison 0.0 +quinn davidson 1.0 +quinn davidson 0.0 +quinn davidson 0.0 +quinn davidson 0.0 +quinn brown 0.0 +quinn brown 0.0 +quinn brown 0.0 +quinn allen 1.0 +quinn allen 0.0 +priscilla zipper 0.0 +priscilla zipper 0.0 +priscilla young 0.0 +priscilla young 0.0 +priscilla xylophone 0.0 +priscilla xylophone 0.0 +priscilla xylophone 0.0 +priscilla white 1.0 +priscilla van buren 0.0 +priscilla van buren 0.0 +priscilla van buren 0.0 +priscilla underhill 0.0 +priscilla underhill 0.0 +priscilla thompson 0.0 +priscilla quirinius 0.0 +priscilla polk 0.0 +priscilla ovid 0.0 +priscilla ovid 0.0 +priscilla nixon 0.0 +priscilla nixon 0.0 +priscilla king 0.0 +priscilla johnson 0.0 +priscilla johnson 0.0 +priscilla johnson 0.0 +priscilla johnson 0.0 +priscilla johnson 0.0 +priscilla ichabod 0.0 +priscilla ichabod 0.0 +priscilla carson 0.0 +priscilla carson 0.0 +priscilla carson 0.0 +priscilla brown 0.0 +priscilla brown 0.0 +priscilla brown 0.0 +oscar zipper 0.0 +oscar zipper 0.0 +oscar zipper 0.0 +oscar xylophone 0.0 +oscar xylophone 0.0 +oscar xylophone 0.0 +oscar white 0.0 +oscar white 0.0 +oscar white 0.0 +oscar white 0.0 +oscar van buren 1.0 +oscar van buren 0.0 +oscar van buren 0.0 +oscar underhill 0.0 +oscar thompson 0.0 +oscar thompson 0.0 +oscar thompson 0.0 +oscar thompson 0.0 +oscar steinbeck 0.0 +oscar robinson 0.0 +oscar robinson 0.0 +oscar robinson 0.0 +oscar robinson 0.0 +oscar quirinius 0.0 +oscar quirinius 0.0 +oscar quirinius 0.0 +oscar quirinius 0.0 +oscar polk 1.0 +oscar polk 0.0 +oscar ovid 0.0 +oscar ovid 0.0 +oscar ovid 0.0 +oscar nixon 0.0 +oscar laertes 0.0 +oscar laertes 0.0 +oscar laertes 0.0 +oscar laertes 0.0 +oscar king 0.0 +oscar king 0.0 +oscar king 0.0 +oscar johnson 0.0 +oscar johnson 0.0 +oscar ichabod 0.0 +oscar ichabod 0.0 +oscar ichabod 0.0 +oscar ichabod 0.0 +oscar hernandez 0.0 +oscar hernandez 0.0 +oscar garcia 0.0 +oscar falkner 1.0 +oscar ellison 0.0 +oscar ellison 0.0 +oscar davidson 0.0 +oscar carson 0.0 +oscar carson 0.0 +oscar carson 0.0 +oscar carson 0.0 +oscar carson 0.0 +oscar brown 0.0 +oscar allen 0.0 +nick zipper 0.0 +nick zipper 0.0 +nick young 1.0 +nick young 0.0 +nick xylophone 0.0 +nick van buren 0.0 +nick underhill 0.0 +nick thompson 0.0 +nick steinbeck 0.0 +nick robinson 0.0 +nick robinson 0.0 +nick quirinius 0.0 +nick quirinius 0.0 +nick polk 0.0 +nick ovid 0.0 +nick nixon 0.0 +nick miller 0.0 +nick laertes 0.0 +nick johnson 0.0 +nick johnson 0.0 +nick ichabod 0.0 +nick ichabod 0.0 +nick ichabod 0.0 +nick garcia 0.0 +nick garcia 0.0 +nick garcia 0.0 +nick falkner 0.0 +nick falkner 0.0 +nick ellison 0.0 +nick ellison 0.0 +nick davidson 0.0 +nick brown 0.0 +nick allen 0.0 +nick allen 0.0 +mike zipper 0.0 +mike zipper 0.0 +mike zipper 0.0 +mike young 0.0 +mike young 0.0 +mike young 0.0 +mike white 0.0 +mike white 0.0 +mike white 0.0 +mike white 0.0 +mike van buren 0.0 +mike van buren 0.0 +mike steinbeck 0.0 +mike steinbeck 0.0 +mike steinbeck 0.0 +mike steinbeck 0.0 +mike quirinius 0.0 +mike polk 0.0 +mike polk 0.0 +mike polk 0.0 +mike nixon 0.0 +mike nixon 0.0 +mike miller 0.0 +mike king 0.0 +mike king 0.0 +mike king 0.0 +mike king 0.0 +mike king 0.0 +mike king 0.0 +mike ichabod 0.0 +mike hernandez 0.0 +mike hernandez 0.0 +mike garcia 0.0 +mike garcia 0.0 +mike garcia 0.0 +mike falkner 0.0 +mike ellison 0.0 +mike ellison 0.0 +mike ellison 0.0 +mike ellison 0.0 +mike ellison 0.0 +mike davidson 0.0 +mike davidson 0.0 +mike carson 0.0 +mike carson 0.0 +mike carson 0.0 +mike brown 0.0 +mike allen 0.0 +luke zipper 0.0 +luke xylophone 0.0 +luke white 0.0 +luke van buren 0.0 +luke underhill 1.0 +luke underhill 0.0 +luke underhill 0.0 +luke thompson 0.0 +luke robinson 0.0 +luke robinson 0.0 +luke quirinius 0.0 +luke polk 0.0 +luke polk 0.0 +luke ovid 0.0 +luke ovid 0.0 +luke miller 0.0 +luke laertes 0.0 +luke laertes 0.0 +luke laertes 0.0 +luke laertes 0.0 +luke laertes 0.0 +luke johnson 0.0 +luke johnson 0.0 +luke johnson 0.0 +luke ichabod 0.0 +luke ichabod 0.0 +luke garcia 0.0 +luke garcia 0.0 +luke falkner 0.0 +luke falkner 0.0 +luke ellison 0.0 +luke ellison 0.0 +luke ellison 0.0 +luke davidson 0.0 +luke davidson 0.0 +luke brown 0.0 +luke allen 1.0 +luke allen 0.0 +luke allen 0.0 +luke allen 0.0 +luke allen 0.0 +katie zipper 1.0 +katie zipper 0.0 +katie young 1.0 +katie young 0.0 +katie young 0.0 +katie xylophone 0.0 +katie white 0.0 +katie white 0.0 +katie van buren 0.0 +katie van buren 0.0 +katie robinson 0.0 +katie polk 0.0 +katie polk 0.0 +katie ovid 0.0 +katie nixon 1.0 +katie miller 0.0 +katie miller 0.0 +katie king 0.0 +katie king 0.0 +katie king 0.0 +katie ichabod 0.0 +katie ichabod 0.0 +katie ichabod 0.0 +katie hernandez 0.0 +katie garcia 0.0 +katie garcia 0.0 +katie falkner 0.0 +katie ellison 0.0 +katie ellison 0.0 +katie davidson 0.0 +katie brown 0.0 +katie allen 0.0 +jessica zipper 0.0 +jessica zipper 0.0 +jessica zipper 0.0 +jessica young 0.0 +jessica young 0.0 +jessica xylophone 0.0 +jessica white 0.0 +jessica white 0.0 +jessica white 0.0 +jessica white 0.0 +jessica white 0.0 +jessica van buren 0.0 +jessica underhill 1.0 +jessica underhill 0.0 +jessica underhill 0.0 +jessica thompson 0.0 +jessica thompson 0.0 +jessica robinson 0.0 +jessica quirinius 1.0 +jessica quirinius 0.0 +jessica quirinius 0.0 +jessica quirinius 0.0 +jessica polk 0.0 +jessica ovid 0.0 +jessica ovid 0.0 +jessica nixon 0.0 +jessica nixon 0.0 +jessica miller 0.0 +jessica johnson 0.0 +jessica johnson 0.0 +jessica ichabod 1.0 +jessica garcia 0.0 +jessica garcia 0.0 +jessica falkner 0.0 +jessica ellison 0.0 +jessica ellison 0.0 +jessica davidson 0.0 +jessica davidson 0.0 +jessica davidson 0.0 +jessica davidson 0.0 +jessica carson 0.0 +jessica carson 0.0 +jessica carson 0.0 +jessica brown 0.0 +irene xylophone 0.0 +irene van buren 0.0 +irene van buren 0.0 +irene underhill 0.0 +irene underhill 0.0 +irene thompson 0.0 +irene steinbeck 0.0 +irene robinson 0.0 +irene quirinius 0.5 +irene quirinius 0.0 +irene quirinius 0.0 +irene polk 0.0 +irene polk 0.0 +irene polk 0.0 +irene polk 0.0 +irene polk 0.0 +irene ovid 0.0 +irene ovid 0.0 +irene ovid 0.0 +irene nixon 0.0 +irene nixon 0.0 +irene nixon 0.0 +irene miller 0.0 +irene laertes 0.0 +irene laertes 0.0 +irene laertes 0.0 +irene johnson 0.0 +irene ichabod 0.0 +irene ichabod 0.0 +irene garcia 0.0 +irene garcia 0.0 +irene garcia 0.0 +irene falkner 0.0 +irene falkner 0.0 +irene ellison 0.0 +irene ellison 0.0 +irene carson 0.0 +irene brown 0.0 +irene brown 0.0 +irene brown 0.0 +irene allen 0.0 +holly zipper 1.0 +holly zipper 0.0 +holly young 0.0 +holly young 0.0 +holly xylophone 0.0 +holly white 0.0 +holly white 0.0 +holly van buren 0.0 +holly underhill 1.0 +holly underhill 0.0 +holly underhill 0.0 +holly underhill 0.0 +holly thompson 1.0 +holly thompson 0.0 +holly thompson 0.0 +holly robinson 0.0 +holly polk 0.0 +holly polk 0.0 +holly nixon 0.0 +holly nixon 0.0 +holly miller 1.0 +holly laertes 0.0 +holly king 0.0 +holly king 0.0 +holly johnson 0.0 +holly johnson 0.0 +holly johnson 0.0 +holly ichabod 0.0 +holly ichabod 0.0 +holly ichabod 0.0 +holly hernandez 0.0 +holly hernandez 0.0 +holly hernandez 0.0 +holly hernandez 0.0 +holly falkner 0.0 +holly brown 0.0 +holly brown 0.0 +holly allen 0.0 +gabriella zipper 0.0 +gabriella zipper 0.0 +gabriella young 0.0 +gabriella young 0.0 +gabriella white 0.0 +gabriella van buren 0.0 +gabriella van buren 0.0 +gabriella thompson 0.0 +gabriella thompson 0.0 +gabriella thompson 0.0 +gabriella steinbeck 0.0 +gabriella steinbeck 0.0 +gabriella polk 0.0 +gabriella polk 0.0 +gabriella ovid 0.0 +gabriella ovid 0.0 +gabriella miller 0.0 +gabriella laertes 0.0 +gabriella king 0.0 +gabriella king 0.0 +gabriella ichabod 1.0 +gabriella ichabod 0.0 +gabriella ichabod 0.0 +gabriella ichabod 0.0 +gabriella ichabod 0.0 +gabriella hernandez 1.0 +gabriella hernandez 0.0 +gabriella garcia 0.0 +gabriella falkner 0.0 +gabriella falkner 0.0 +gabriella falkner 0.0 +gabriella ellison 0.0 +gabriella ellison 0.0 +gabriella davidson 0.0 +gabriella carson 0.0 +gabriella brown 0.0 +gabriella brown 0.0 +gabriella allen 0.0 +gabriella allen 0.0 +fred zipper 0.0 +fred young 0.0 +fred young 0.0 +fred white 0.0 +fred van buren 0.0 +fred van buren 0.0 +fred van buren 0.0 +fred van buren 0.0 +fred underhill 0.0 +fred steinbeck 0.0 +fred steinbeck 0.0 +fred steinbeck 0.0 +fred robinson 1.0 +fred quirinius 0.0 +fred quirinius 0.0 +fred polk 0.0 +fred polk 0.0 +fred polk 0.0 +fred polk 0.0 +fred nixon 0.0 +fred nixon 0.0 +fred nixon 0.0 +fred nixon 0.0 +fred miller 0.0 +fred laertes 0.0 +fred king 0.0 +fred king 0.0 +fred johnson 1.0 +fred ichabod 0.0 +fred ichabod 0.0 +fred hernandez 0.0 +fred falkner 1.0 +fred falkner 0.0 +fred falkner 0.0 +fred ellison 0.0 +fred ellison 0.0 +fred ellison 0.0 +fred davidson 0.0 +fred davidson 0.0 +fred davidson 0.0 +ethan zipper 0.0 +ethan zipper 0.0 +ethan xylophone 0.0 +ethan white 0.0 +ethan white 0.0 +ethan van buren 0.0 +ethan underhill 0.5 +ethan robinson 0.0 +ethan robinson 0.0 +ethan quirinius 0.0 +ethan quirinius 0.0 +ethan quirinius 0.0 +ethan polk 1.0 +ethan polk 0.0 +ethan polk 0.0 +ethan polk 0.0 +ethan ovid 0.0 +ethan nixon 0.0 +ethan miller 0.0 +ethan laertes 0.0 +ethan laertes 0.0 +ethan laertes 0.0 +ethan laertes 0.0 +ethan laertes 0.0 +ethan laertes 0.0 +ethan laertes 0.0 +ethan king 0.0 +ethan johnson 0.0 +ethan hernandez 0.0 +ethan garcia 0.0 +ethan falkner 0.0 +ethan falkner 0.0 +ethan ellison 0.0 +ethan ellison 0.0 +ethan carson 0.0 +ethan brown 1.0 +ethan brown 0.0 +ethan brown 0.0 +ethan brown 0.0 +ethan brown 0.0 +ethan brown 0.0 +ethan allen 0.0 +david young 0.0 +david young 0.0 +david xylophone 1.0 +david xylophone 0.0 +david xylophone 0.0 +david white 0.0 +david van buren 0.0 +david van buren 0.0 +david underhill 0.0 +david underhill 0.0 +david underhill 0.0 +david thompson 1.0 +david robinson 0.0 +david robinson 0.0 +david quirinius 0.0 +david quirinius 0.0 +david quirinius 0.0 +david ovid 0.0 +david ovid 0.0 +david nixon 0.0 +david laertes 0.0 +david ichabod 1.0 +david ichabod 0.0 +david hernandez 1.0 +david ellison 0.0 +david ellison 0.0 +david ellison 0.0 +david davidson 0.0 +david davidson 0.0 +david davidson 0.0 +david davidson 0.0 +david brown 0.0 +david brown 0.0 +david allen 0.0 +david allen 0.0 +calvin zipper 0.0 +calvin zipper 0.0 +calvin young 0.0 +calvin young 0.0 +calvin xylophone 0.0 +calvin xylophone 0.0 +calvin xylophone 0.0 +calvin white 0.0 +calvin white 0.0 +calvin van buren 1.0 +calvin van buren 0.0 +calvin underhill 0.0 +calvin thompson 0.0 +calvin thompson 0.0 +calvin steinbeck 0.0 +calvin steinbeck 0.0 +calvin steinbeck 0.0 +calvin robinson 0.0 +calvin quirinius 0.0 +calvin quirinius 0.0 +calvin polk 0.0 +calvin ovid 0.0 +calvin ovid 0.0 +calvin ovid 0.0 +calvin ovid 0.0 +calvin nixon 0.0 +calvin nixon 0.0 +calvin nixon 0.0 +calvin laertes 0.0 +calvin laertes 0.0 +calvin johnson 0.0 +calvin hernandez 0.0 +calvin garcia 0.0 +calvin falkner 0.0 +calvin falkner 0.0 +calvin falkner 0.0 +calvin falkner 0.0 +calvin falkner 0.0 +calvin falkner 0.0 +calvin ellison 0.0 +calvin davidson 0.0 +calvin davidson 0.0 +calvin carson 0.0 +calvin brown 0.0 +calvin brown 0.0 +calvin brown 0.0 +calvin allen 0.0 +bob zipper 0.0 +bob zipper 0.0 +bob zipper 0.0 +bob young 0.0 +bob xylophone 0.0 +bob xylophone 0.0 +bob white 0.0 +bob white 0.0 +bob van buren 0.0 +bob steinbeck 0.0 +bob quirinius 0.0 +bob polk 0.0 +bob ovid 0.0 +bob ovid 0.0 +bob ovid 0.0 +bob ovid 0.0 +bob miller 0.0 +bob laertes 0.0 +bob laertes 0.0 +bob king 1.0 +bob king 0.0 +bob king 0.0 +bob ichabod 0.0 +bob hernandez 1.0 +bob garcia 0.0 +bob garcia 0.0 +bob garcia 0.0 +bob garcia 0.0 +bob garcia 0.0 +bob falkner 0.0 +bob ellison 1.0 +bob ellison 0.0 +bob ellison 0.0 +bob ellison 0.0 +bob davidson 0.0 +bob davidson 0.0 +bob davidson 0.0 +bob carson 0.0 +bob brown 0.0 +bob brown 0.0 +bob brown 0.0 +alice zipper 0.0 +alice zipper 0.0 +alice zipper 0.0 +alice xylophone 0.0 +alice xylophone 0.0 +alice xylophone 0.0 +alice van buren 0.0 +alice underhill 0.0 +alice steinbeck 0.0 +alice steinbeck 0.0 +alice steinbeck 0.0 +alice robinson 0.0 +alice robinson 0.0 +alice quirinius 0.0 +alice quirinius 0.0 +alice polk 1.0 +alice ovid 0.0 +alice nixon 0.0 +alice nixon 0.0 +alice nixon 0.0 +alice miller 0.0 +alice laertes 0.0 +alice laertes 0.0 +alice king 0.0 +alice king 0.0 +alice king 0.0 +alice johnson 0.0 +alice hernandez 0.0 +alice hernandez 0.0 +alice garcia 0.0 +alice falkner 0.0 +alice davidson 0.0 +alice carson 0.0 +alice brown 0.0 +alice allen 0.0 +alice allen 0.0 +alice allen 0.0 + 0.0 + 0.0 + 0.0 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 2-0-81bb7f49a55385878637c8aac4d08e5 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 2-0-81bb7f49a55385878637c8aac4d08e5 new file mode 100644 index 000000000000..9091a9156134 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 2-0-81bb7f49a55385878637c8aac4d08e5 @@ -0,0 +1,1294 @@ +2013-03-01 09:11:58.70307 52.64 1 +2013-03-01 09:11:58.70307 52.64 1 +2013-03-01 09:11:58.70307 52.64 1 +2013-03-01 09:11:58.70307 52.64 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703071 71.07 1 +2013-03-01 09:11:58.703072 2.96 1 +2013-03-01 09:11:58.703072 2.96 1 +2013-03-01 09:11:58.703072 2.96 1 +2013-03-01 09:11:58.703072 2.96 1 +2013-03-01 09:11:58.703072 2.96 1 +2013-03-01 09:11:58.703073 10.07 1 +2013-03-01 09:11:58.703073 10.07 1 +2013-03-01 09:11:58.703073 10.07 1 +2013-03-01 09:11:58.703073 10.07 1 +2013-03-01 09:11:58.703074 37.8 1 +2013-03-01 09:11:58.703074 37.8 1 +2013-03-01 09:11:58.703074 37.8 1 +2013-03-01 09:11:58.703074 37.8 1 +2013-03-01 09:11:58.703074 37.8 1 +2013-03-01 09:11:58.703074 37.8 1 +2013-03-01 09:11:58.703075 5.64 1 +2013-03-01 09:11:58.703075 5.64 1 +2013-03-01 09:11:58.703075 5.64 1 +2013-03-01 09:11:58.703075 5.64 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703077 10.16 1 +2013-03-01 09:11:58.703077 10.16 1 +2013-03-01 09:11:58.703077 10.16 1 +2013-03-01 09:11:58.703077 10.16 1 +2013-03-01 09:11:58.703077 10.16 1 +2013-03-01 09:11:58.703077 10.16 1 +2013-03-01 09:11:58.703078 61.52 1 +2013-03-01 09:11:58.703078 61.52 1 +2013-03-01 09:11:58.703078 61.52 1 +2013-03-01 09:11:58.703078 61.52 1 +2013-03-01 09:11:58.703078 61.52 1 +2013-03-01 09:11:58.703078 61.52 1 +2013-03-01 09:11:58.703079 27.32 1 +2013-03-01 09:11:58.703079 27.32 1 +2013-03-01 09:11:58.703079 27.32 1 +2013-03-01 09:11:58.703079 27.32 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.70308 1.76 1 +2013-03-01 09:11:58.703081 67.9 1 +2013-03-01 09:11:58.703081 67.9 1 +2013-03-01 09:11:58.703081 67.9 1 +2013-03-01 09:11:58.703081 67.9 1 +2013-03-01 09:11:58.703081 67.9 1 +2013-03-01 09:11:58.703082 37.25 1 +2013-03-01 09:11:58.703082 37.25 1 +2013-03-01 09:11:58.703082 37.25 1 +2013-03-01 09:11:58.703082 37.25 1 +2013-03-01 09:11:58.703082 37.25 1 +2013-03-01 09:11:58.703082 37.25 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703083 20.52 1 +2013-03-01 09:11:58.703084 1.76 1 +2013-03-01 09:11:58.703084 1.76 1 +2013-03-01 09:11:58.703084 1.76 1 +2013-03-01 09:11:58.703084 1.76 1 +2013-03-01 09:11:58.703084 1.76 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703085 1.01 1 +2013-03-01 09:11:58.703086 9.96 1 +2013-03-01 09:11:58.703086 9.96 1 +2013-03-01 09:11:58.703086 9.96 1 +2013-03-01 09:11:58.703086 9.96 1 +2013-03-01 09:11:58.703086 9.96 1 +2013-03-01 09:11:58.703087 10.63 1 +2013-03-01 09:11:58.703087 10.63 1 +2013-03-01 09:11:58.703087 10.63 1 +2013-03-01 09:11:58.703087 10.63 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703089 8.76 1 +2013-03-01 09:11:58.703089 8.76 1 +2013-03-01 09:11:58.703089 8.76 1 +2013-03-01 09:11:58.703089 8.76 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.70309 50.99 1 +2013-03-01 09:11:58.703091 15.85 1 +2013-03-01 09:11:58.703091 15.85 1 +2013-03-01 09:11:58.703091 15.85 1 +2013-03-01 09:11:58.703091 15.85 1 +2013-03-01 09:11:58.703092 36.84 1 +2013-03-01 09:11:58.703092 36.84 1 +2013-03-01 09:11:58.703092 36.84 1 +2013-03-01 09:11:58.703092 36.84 1 +2013-03-01 09:11:58.703092 36.84 1 +2013-03-01 09:11:58.703092 36.84 1 +2013-03-01 09:11:58.703093 14.85 1 +2013-03-01 09:11:58.703093 14.85 1 +2013-03-01 09:11:58.703093 14.85 1 +2013-03-01 09:11:58.703094 57.11 1 +2013-03-01 09:11:58.703094 57.11 1 +2013-03-01 09:11:58.703094 57.11 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703095 9.77 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703096 11.64 1 +2013-03-01 09:11:58.703097 0.9 1 +2013-03-01 09:11:58.703097 0.9 1 +2013-03-01 09:11:58.703097 0.9 1 +2013-03-01 09:11:58.703098 1.35 1 +2013-03-01 09:11:58.703098 1.35 1 +2013-03-01 09:11:58.703098 1.35 1 +2013-03-01 09:11:58.703098 1.35 1 +2013-03-01 09:11:58.703098 1.35 1 +2013-03-01 09:11:58.703099 11.69 1 +2013-03-01 09:11:58.703099 11.69 1 +2013-03-01 09:11:58.703099 11.69 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703101 8.72 1 +2013-03-01 09:11:58.703102 63.65 1 +2013-03-01 09:11:58.703102 63.65 1 +2013-03-01 09:11:58.703102 63.65 1 +2013-03-01 09:11:58.703102 63.65 1 +2013-03-01 09:11:58.703102 63.65 1 +2013-03-01 09:11:58.703102 63.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703103 8.65 1 +2013-03-01 09:11:58.703104 2.04 1 +2013-03-01 09:11:58.703104 2.04 1 +2013-03-01 09:11:58.703104 2.04 1 +2013-03-01 09:11:58.703104 2.04 1 +2013-03-01 09:11:58.703104 2.04 1 +2013-03-01 09:11:58.703104 2.04 1 +2013-03-01 09:11:58.703105 28.47 1 +2013-03-01 09:11:58.703105 28.47 1 +2013-03-01 09:11:58.703106 11.81 1 +2013-03-01 09:11:58.703106 11.81 1 +2013-03-01 09:11:58.703106 11.81 1 +2013-03-01 09:11:58.703106 11.81 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703107 16.26 1 +2013-03-01 09:11:58.703108 28.47 1 +2013-03-01 09:11:58.703108 28.47 1 +2013-03-01 09:11:58.703108 28.47 1 +2013-03-01 09:11:58.703108 28.47 1 +2013-03-01 09:11:58.703109 38.98 1 +2013-03-01 09:11:58.703109 38.98 1 +2013-03-01 09:11:58.703109 38.98 1 +2013-03-01 09:11:58.703109 38.98 1 +2013-03-01 09:11:58.70311 8.16 1 +2013-03-01 09:11:58.70311 8.16 1 +2013-03-01 09:11:58.70311 8.16 1 +2013-03-01 09:11:58.70311 8.16 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703111 18.8 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703113 21.8 1 +2013-03-01 09:11:58.703114 73.94 1 +2013-03-01 09:11:58.703114 73.94 1 +2013-03-01 09:11:58.703114 73.94 1 +2013-03-01 09:11:58.703114 73.94 1 +2013-03-01 09:11:58.703114 73.94 1 +2013-03-01 09:11:58.703114 73.94 1 +2013-03-01 09:11:58.703115 27.52 1 +2013-03-01 09:11:58.703115 27.52 1 +2013-03-01 09:11:58.703115 27.52 1 +2013-03-01 09:11:58.703115 27.52 1 +2013-03-01 09:11:58.703115 27.52 1 +2013-03-01 09:11:58.703116 33.45 1 +2013-03-01 09:11:58.703116 33.45 1 +2013-03-01 09:11:58.703116 33.45 1 +2013-03-01 09:11:58.703116 33.45 1 +2013-03-01 09:11:58.703117 21.81 1 +2013-03-01 09:11:58.703117 21.81 1 +2013-03-01 09:11:58.703117 21.81 1 +2013-03-01 09:11:58.703117 21.81 1 +2013-03-01 09:11:58.703117 21.81 1 +2013-03-01 09:11:58.703117 21.81 1 +2013-03-01 09:11:58.703118 8.69 1 +2013-03-01 09:11:58.703118 8.69 1 +2013-03-01 09:11:58.703119 58.02 1 +2013-03-01 09:11:58.703119 58.02 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.703121 96.9 1 +2013-03-01 09:11:58.703121 96.9 1 +2013-03-01 09:11:58.703121 96.9 1 +2013-03-01 09:11:58.703121 96.9 1 +2013-03-01 09:11:58.703122 53.56 1 +2013-03-01 09:11:58.703122 53.56 1 +2013-03-01 09:11:58.703122 53.56 1 +2013-03-01 09:11:58.703122 53.56 1 +2013-03-01 09:11:58.703122 53.56 1 +2013-03-01 09:11:58.703123 94.35 1 +2013-03-01 09:11:58.703123 94.35 1 +2013-03-01 09:11:58.703123 94.35 1 +2013-03-01 09:11:58.703123 94.35 1 +2013-03-01 09:11:58.703124 8.93 1 +2013-03-01 09:11:58.703124 8.93 1 +2013-03-01 09:11:58.703124 8.93 1 +2013-03-01 09:11:58.703125 14.94 1 +2013-03-01 09:11:58.703125 14.94 1 +2013-03-01 09:11:58.703125 14.94 1 +2013-03-01 09:11:58.703126 5.49 1 +2013-03-01 09:11:58.703126 5.49 1 +2013-03-01 09:11:58.703126 5.49 1 +2013-03-01 09:11:58.703126 5.49 1 +2013-03-01 09:11:58.703127 3.98 1 +2013-03-01 09:11:58.703127 3.98 1 +2013-03-01 09:11:58.703127 3.98 1 +2013-03-01 09:11:58.703127 3.98 1 +2013-03-01 09:11:58.703128 11.45 1 +2013-03-01 09:11:58.703128 11.45 1 +2013-03-01 09:11:58.703128 11.45 1 +2013-03-01 09:11:58.703128 11.45 1 +2013-03-01 09:11:58.70313 5.83 1 +2013-03-01 09:11:58.70313 5.83 1 +2013-03-01 09:11:58.70313 5.83 1 +2013-03-01 09:11:58.70313 5.83 1 +2013-03-01 09:11:58.70313 5.83 1 +2013-03-01 09:11:58.70313 5.83 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703131 1.75 1 +2013-03-01 09:11:58.703132 1.86 1 +2013-03-01 09:11:58.703132 1.86 1 +2013-03-01 09:11:58.703133 27.34 1 +2013-03-01 09:11:58.703133 27.34 1 +2013-03-01 09:11:58.703133 27.34 1 +2013-03-01 09:11:58.703133 27.34 1 +2013-03-01 09:11:58.703134 98.9 1 +2013-03-01 09:11:58.703134 98.9 1 +2013-03-01 09:11:58.703134 98.9 1 +2013-03-01 09:11:58.703134 98.9 1 +2013-03-01 09:11:58.703134 98.9 1 +2013-03-01 09:11:58.703135 29.14 1 +2013-03-01 09:11:58.703135 29.14 1 +2013-03-01 09:11:58.703135 29.14 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703136 11.87 1 +2013-03-01 09:11:58.703137 18.11 1 +2013-03-01 09:11:58.703137 18.11 1 +2013-03-01 09:11:58.703137 18.11 1 +2013-03-01 09:11:58.703137 18.11 1 +2013-03-01 09:11:58.703137 18.11 1 +2013-03-01 09:11:58.703138 55.68 1 +2013-03-01 09:11:58.703138 55.68 1 +2013-03-01 09:11:58.703138 55.68 1 +2013-03-01 09:11:58.703138 55.68 1 +2013-03-01 09:11:58.703139 12.67 1 +2013-03-01 09:11:58.703139 12.67 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.70314 2.83 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703141 76.06 1 +2013-03-01 09:11:58.703142 24.25 1 +2013-03-01 09:11:58.703142 24.25 1 +2013-03-01 09:11:58.703142 24.25 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703143 26.86 1 +2013-03-01 09:11:58.703144 3.43 1 +2013-03-01 09:11:58.703144 3.43 1 +2013-03-01 09:11:58.703144 3.43 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703145 8.46 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703146 89.12 1 +2013-03-01 09:11:58.703147 54.94 1 +2013-03-01 09:11:58.703147 54.94 1 +2013-03-01 09:11:58.703147 54.94 1 +2013-03-01 09:11:58.703147 54.94 1 +2013-03-01 09:11:58.703147 54.94 1 +2013-03-01 09:11:58.703148 26.97 1 +2013-03-01 09:11:58.703148 26.97 1 +2013-03-01 09:11:58.703148 26.97 1 +2013-03-01 09:11:58.703148 26.97 1 +2013-03-01 09:11:58.703148 26.97 1 +2013-03-01 09:11:58.703148 26.97 1 +2013-03-01 09:11:58.703149 58.05 1 +2013-03-01 09:11:58.703149 58.05 1 +2013-03-01 09:11:58.703149 58.05 1 +2013-03-01 09:11:58.703149 58.05 1 +2013-03-01 09:11:58.703149 58.05 1 +2013-03-01 09:11:58.703149 58.05 1 +2013-03-01 09:11:58.70315 33.01 1 +2013-03-01 09:11:58.70315 33.01 1 +2013-03-01 09:11:58.70315 33.01 1 +2013-03-01 09:11:58.70315 33.01 1 +2013-03-01 09:11:58.703151 95.69 1 +2013-03-01 09:11:58.703151 95.69 1 +2013-03-01 09:11:58.703151 95.69 1 +2013-03-01 09:11:58.703151 95.69 1 +2013-03-01 09:11:58.703151 95.69 1 +2013-03-01 09:11:58.703152 6.85 1 +2013-03-01 09:11:58.703152 6.85 1 +2013-03-01 09:11:58.703152 6.85 1 +2013-03-01 09:11:58.703152 6.85 1 +2013-03-01 09:11:58.703152 6.85 1 +2013-03-01 09:11:58.703153 4.11 1 +2013-03-01 09:11:58.703153 4.11 1 +2013-03-01 09:11:58.703153 4.11 1 +2013-03-01 09:11:58.703153 4.11 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703155 6.93 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703156 21.79 1 +2013-03-01 09:11:58.703157 1.29 1 +2013-03-01 09:11:58.703157 1.29 1 +2013-03-01 09:11:58.703157 1.29 1 +2013-03-01 09:11:58.703157 1.29 1 +2013-03-01 09:11:58.703157 1.29 1 +2013-03-01 09:11:58.703157 1.29 1 +2013-03-01 09:11:58.703158 71.89 1 +2013-03-01 09:11:58.703158 71.89 1 +2013-03-01 09:11:58.703158 71.89 1 +2013-03-01 09:11:58.703159 64.06 1 +2013-03-01 09:11:58.703159 64.06 1 +2013-03-01 09:11:58.703159 64.06 1 +2013-03-01 09:11:58.703159 64.06 1 +2013-03-01 09:11:58.703159 64.06 1 +2013-03-01 09:11:58.703159 64.06 1 +2013-03-01 09:11:58.70316 11.93 1 +2013-03-01 09:11:58.70316 11.93 1 +2013-03-01 09:11:58.70316 11.93 1 +2013-03-01 09:11:58.70316 11.93 1 +2013-03-01 09:11:58.703161 15.82 1 +2013-03-01 09:11:58.703161 15.82 1 +2013-03-01 09:11:58.703161 15.82 1 +2013-03-01 09:11:58.703161 15.82 1 +2013-03-01 09:11:58.703161 15.82 1 +2013-03-01 09:11:58.703162 3.51 1 +2013-03-01 09:11:58.703162 3.51 1 +2013-03-01 09:11:58.703162 3.51 1 +2013-03-01 09:11:58.703162 3.51 1 +2013-03-01 09:11:58.703162 3.51 1 +2013-03-01 09:11:58.703163 15.7 1 +2013-03-01 09:11:58.703163 15.7 1 +2013-03-01 09:11:58.703163 15.7 1 +2013-03-01 09:11:58.703163 15.7 1 +2013-03-01 09:11:58.703163 15.7 1 +2013-03-01 09:11:58.703163 15.7 1 +2013-03-01 09:11:58.703164 30.27 1 +2013-03-01 09:11:58.703164 30.27 1 +2013-03-01 09:11:58.703164 30.27 1 +2013-03-01 09:11:58.703164 30.27 1 +2013-03-01 09:11:58.703164 30.27 1 +2013-03-01 09:11:58.703164 30.27 1 +2013-03-01 09:11:58.703165 8.38 1 +2013-03-01 09:11:58.703165 8.38 1 +2013-03-01 09:11:58.703165 8.38 1 +2013-03-01 09:11:58.703166 16.6 1 +2013-03-01 09:11:58.703166 16.6 1 +2013-03-01 09:11:58.703166 16.6 1 +2013-03-01 09:11:58.703167 17.66 1 +2013-03-01 09:11:58.703167 17.66 1 +2013-03-01 09:11:58.703167 17.66 1 +2013-03-01 09:11:58.703167 17.66 1 +2013-03-01 09:11:58.703167 17.66 1 +2013-03-01 09:11:58.703167 17.66 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703168 32.03 1 +2013-03-01 09:11:58.703169 39.96 1 +2013-03-01 09:11:58.703169 39.96 1 +2013-03-01 09:11:58.703169 39.96 1 +2013-03-01 09:11:58.703169 39.96 1 +2013-03-01 09:11:58.703169 39.96 1 +2013-03-01 09:11:58.70317 11.44 1 +2013-03-01 09:11:58.70317 11.44 1 +2013-03-01 09:11:58.70317 11.44 1 +2013-03-01 09:11:58.70317 11.44 1 +2013-03-01 09:11:58.70317 11.44 1 +2013-03-01 09:11:58.703171 24.94 1 +2013-03-01 09:11:58.703171 24.94 1 +2013-03-01 09:11:58.703171 24.94 1 +2013-03-01 09:11:58.703171 24.94 1 +2013-03-01 09:11:58.703171 24.94 1 +2013-03-01 09:11:58.703171 24.94 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703172 3.44 1 +2013-03-01 09:11:58.703173 8.77 1 +2013-03-01 09:11:58.703173 8.77 1 +2013-03-01 09:11:58.703173 8.77 1 +2013-03-01 09:11:58.703173 8.77 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703175 33.37 1 +2013-03-01 09:11:58.703175 33.37 1 +2013-03-01 09:11:58.703175 33.37 1 +2013-03-01 09:11:58.703175 33.37 1 +2013-03-01 09:11:58.703175 33.37 1 +2013-03-01 09:11:58.703175 33.37 1 +2013-03-01 09:11:58.703176 28.2 1 +2013-03-01 09:11:58.703176 28.2 1 +2013-03-01 09:11:58.703176 28.2 1 +2013-03-01 09:11:58.703176 28.2 1 +2013-03-01 09:11:58.703176 28.2 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703177 11.43 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703178 9.12 1 +2013-03-01 09:11:58.703179 10.82 1 +2013-03-01 09:11:58.703179 10.82 1 +2013-03-01 09:11:58.703179 10.82 1 +2013-03-01 09:11:58.703179 10.82 1 +2013-03-01 09:11:58.70318 10.28 1 +2013-03-01 09:11:58.70318 10.28 1 +2013-03-01 09:11:58.70318 10.28 1 +2013-03-01 09:11:58.70318 10.28 1 +2013-03-01 09:11:58.70318 10.28 1 +2013-03-01 09:11:58.70318 10.28 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703182 1.23 1 +2013-03-01 09:11:58.703182 1.23 1 +2013-03-01 09:11:58.703182 1.23 1 +2013-03-01 09:11:58.703182 1.23 1 +2013-03-01 09:11:58.703182 1.23 1 +2013-03-01 09:11:58.703183 36.74 1 +2013-03-01 09:11:58.703183 36.74 1 +2013-03-01 09:11:58.703183 36.74 1 +2013-03-01 09:11:58.703183 36.74 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703184 8.95 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703186 13.81 1 +2013-03-01 09:11:58.703186 13.81 1 +2013-03-01 09:11:58.703186 13.81 1 +2013-03-01 09:11:58.703186 13.81 1 +2013-03-01 09:11:58.703187 64.89 1 +2013-03-01 09:11:58.703187 64.89 1 +2013-03-01 09:11:58.703187 64.89 1 +2013-03-01 09:11:58.703187 64.89 1 +2013-03-01 09:11:58.703187 64.89 1 +2013-03-01 09:11:58.703187 64.89 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.703189 36.96 1 +2013-03-01 09:11:58.70319 90.09 1 +2013-03-01 09:11:58.70319 90.09 1 +2013-03-01 09:11:58.70319 90.09 1 +2013-03-01 09:11:58.703192 2.63 1 +2013-03-01 09:11:58.703192 2.63 1 +2013-03-01 09:11:58.703192 2.63 1 +2013-03-01 09:11:58.703192 2.63 1 +2013-03-01 09:11:58.703193 28.42 1 +2013-03-01 09:11:58.703193 28.42 1 +2013-03-01 09:11:58.703193 28.42 1 +2013-03-01 09:11:58.703194 15.07 1 +2013-03-01 09:11:58.703194 15.07 1 +2013-03-01 09:11:58.703194 15.07 1 +2013-03-01 09:11:58.703194 15.07 1 +2013-03-01 09:11:58.703194 15.07 1 +2013-03-01 09:11:58.703194 15.07 1 +2013-03-01 09:11:58.703195 3.81 1 +2013-03-01 09:11:58.703195 3.81 1 +2013-03-01 09:11:58.703195 3.81 1 +2013-03-01 09:11:58.703195 3.81 1 +2013-03-01 09:11:58.703195 3.81 1 +2013-03-01 09:11:58.703195 3.81 1 +2013-03-01 09:11:58.703196 0.08 1 +2013-03-01 09:11:58.703196 0.08 1 +2013-03-01 09:11:58.703196 0.08 1 +2013-03-01 09:11:58.703197 16.01 1 +2013-03-01 09:11:58.703197 16.01 1 +2013-03-01 09:11:58.703197 16.01 1 +2013-03-01 09:11:58.703197 16.01 1 +2013-03-01 09:11:58.703197 16.01 1 +2013-03-01 09:11:58.703197 16.01 1 +2013-03-01 09:11:58.703198 30.6 1 +2013-03-01 09:11:58.703198 30.6 1 +2013-03-01 09:11:58.703198 30.6 1 +2013-03-01 09:11:58.703198 30.6 1 +2013-03-01 09:11:58.703199 45.69 1 +2013-03-01 09:11:58.703199 45.69 1 +2013-03-01 09:11:58.703199 45.69 1 +2013-03-01 09:11:58.703199 45.69 1 +2013-03-01 09:11:58.7032 12.72 1 +2013-03-01 09:11:58.7032 12.72 1 +2013-03-01 09:11:58.703201 35.15 1 +2013-03-01 09:11:58.703201 35.15 1 +2013-03-01 09:11:58.703202 31.41 1 +2013-03-01 09:11:58.703202 31.41 1 +2013-03-01 09:11:58.703202 31.41 1 +2013-03-01 09:11:58.703202 31.41 1 +2013-03-01 09:11:58.703202 31.41 1 +2013-03-01 09:11:58.703203 11.63 1 +2013-03-01 09:11:58.703203 11.63 1 +2013-03-01 09:11:58.703203 11.63 1 +2013-03-01 09:11:58.703203 11.63 1 +2013-03-01 09:11:58.703203 11.63 1 +2013-03-01 09:11:58.703205 35.8 1 +2013-03-01 09:11:58.703205 35.8 1 +2013-03-01 09:11:58.703205 35.8 1 +2013-03-01 09:11:58.703205 35.8 1 +2013-03-01 09:11:58.703205 35.8 1 +2013-03-01 09:11:58.703206 6.61 1 +2013-03-01 09:11:58.703206 6.61 1 +2013-03-01 09:11:58.703206 6.61 1 +2013-03-01 09:11:58.703206 6.61 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703207 21.14 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703208 1.23 1 +2013-03-01 09:11:58.703209 25.92 1 +2013-03-01 09:11:58.703209 25.92 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.703211 5.24 1 +2013-03-01 09:11:58.703211 5.24 1 +2013-03-01 09:11:58.703211 5.24 1 +2013-03-01 09:11:58.703211 5.24 1 +2013-03-01 09:11:58.703211 5.24 1 +2013-03-01 09:11:58.703211 5.24 1 +2013-03-01 09:11:58.703212 10.52 1 +2013-03-01 09:11:58.703212 10.52 1 +2013-03-01 09:11:58.703212 10.52 1 +2013-03-01 09:11:58.703212 10.52 1 +2013-03-01 09:11:58.703212 10.52 1 +2013-03-01 09:11:58.703212 10.52 1 +2013-03-01 09:11:58.703213 38.71 1 +2013-03-01 09:11:58.703213 38.71 1 +2013-03-01 09:11:58.703213 38.71 1 +2013-03-01 09:11:58.703213 38.71 1 +2013-03-01 09:11:58.703214 31.35 1 +2013-03-01 09:11:58.703214 31.35 1 +2013-03-01 09:11:58.703214 31.35 1 +2013-03-01 09:11:58.703215 18.78 1 +2013-03-01 09:11:58.703215 18.78 1 +2013-03-01 09:11:58.703215 18.78 1 +2013-03-01 09:11:58.703215 18.78 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703216 11.23 1 +2013-03-01 09:11:58.703217 23.57 1 +2013-03-01 09:11:58.703217 23.57 1 +2013-03-01 09:11:58.703217 23.57 1 +2013-03-01 09:11:58.703218 9.67 1 +2013-03-01 09:11:58.703218 9.67 1 +2013-03-01 09:11:58.703218 9.67 1 +2013-03-01 09:11:58.703218 9.67 1 +2013-03-01 09:11:58.703218 9.67 1 +2013-03-01 09:11:58.703219 1.42 1 +2013-03-01 09:11:58.703219 1.42 1 +2013-03-01 09:11:58.703219 1.42 1 +2013-03-01 09:11:58.703219 1.42 1 +2013-03-01 09:11:58.703219 1.42 1 +2013-03-01 09:11:58.703219 1.42 1 +2013-03-01 09:11:58.70322 7.37 1 +2013-03-01 09:11:58.70322 7.37 1 +2013-03-01 09:11:58.70322 7.37 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703222 38.23 1 +2013-03-01 09:11:58.703222 38.23 1 +2013-03-01 09:11:58.703222 38.23 1 +2013-03-01 09:11:58.703222 38.23 1 +2013-03-01 09:11:58.703222 38.23 1 +2013-03-01 09:11:58.703222 38.23 1 +2013-03-01 09:11:58.703223 3.43 1 +2013-03-01 09:11:58.703223 3.43 1 +2013-03-01 09:11:58.703223 3.43 1 +2013-03-01 09:11:58.703223 3.43 1 +2013-03-01 09:11:58.703223 3.43 1 +2013-03-01 09:11:58.703224 17.92 1 +2013-03-01 09:11:58.703224 17.92 1 +2013-03-01 09:11:58.703224 17.92 1 +2013-03-01 09:11:58.703224 17.92 1 +2013-03-01 09:11:58.703224 17.92 1 +2013-03-01 09:11:58.703224 17.92 1 +2013-03-01 09:11:58.703225 35.51 1 +2013-03-01 09:11:58.703225 35.51 1 +2013-03-01 09:11:58.703225 35.51 1 +2013-03-01 09:11:58.703225 35.51 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703226 34.31 1 +2013-03-01 09:11:58.703227 17.65 1 +2013-03-01 09:11:58.703227 17.65 1 +2013-03-01 09:11:58.703227 17.65 1 +2013-03-01 09:11:58.703227 17.65 1 +2013-03-01 09:11:58.703227 17.65 1 +2013-03-01 09:11:58.703228 4.19 1 +2013-03-01 09:11:58.703228 4.19 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.703229 88.52 1 +2013-03-01 09:11:58.70323 28.02 1 +2013-03-01 09:11:58.70323 28.02 1 +2013-03-01 09:11:58.70323 28.02 1 +2013-03-01 09:11:58.70323 28.02 1 +2013-03-01 09:11:58.70323 28.02 1 +2013-03-01 09:11:58.703231 11.99 1 +2013-03-01 09:11:58.703231 11.99 1 +2013-03-01 09:11:58.703231 11.99 1 +2013-03-01 09:11:58.703231 11.99 1 +2013-03-01 09:11:58.703231 11.99 1 +2013-03-01 09:11:58.703231 11.99 1 +2013-03-01 09:11:58.703232 61.96 1 +2013-03-01 09:11:58.703232 61.96 1 +2013-03-01 09:11:58.703232 61.96 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703234 44.1 1 +2013-03-01 09:11:58.703234 44.1 1 +2013-03-01 09:11:58.703234 44.1 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703235 6.35 1 +2013-03-01 09:11:58.703236 37.8 1 +2013-03-01 09:11:58.703236 37.8 1 +2013-03-01 09:11:58.703236 37.8 1 +2013-03-01 09:11:58.703236 37.8 1 +2013-03-01 09:11:58.703236 37.8 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703237 0.24 1 +2013-03-01 09:11:58.703238 6 1 +2013-03-01 09:11:58.703238 6 1 +2013-03-01 09:11:58.703238 6 1 +2013-03-01 09:11:58.703238 6 1 +2013-03-01 09:11:58.703239 24.8 1 +2013-03-01 09:11:58.703239 24.8 1 +2013-03-01 09:11:58.703239 24.8 1 +2013-03-01 09:11:58.703239 24.8 1 +2013-03-01 09:11:58.703239 24.8 1 +2013-03-01 09:11:58.70324 5.1 1 +2013-03-01 09:11:58.70324 5.1 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703241 19.33 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703243 6.15 1 +2013-03-01 09:11:58.703243 6.15 1 +2013-03-01 09:11:58.703243 6.15 1 +2013-03-01 09:11:58.703243 6.15 1 +2013-03-01 09:11:58.703244 4.47 1 +2013-03-01 09:11:58.703244 4.47 1 +2013-03-01 09:11:58.703244 4.47 1 +2013-03-01 09:11:58.703245 0.72 1 +2013-03-01 09:11:58.703245 0.72 1 +2013-03-01 09:11:58.703245 0.72 1 +2013-03-01 09:11:58.703245 0.72 1 +2013-03-01 09:11:58.703246 45.94 1 +2013-03-01 09:11:58.703246 45.94 1 +2013-03-01 09:11:58.703247 1.29 1 +2013-03-01 09:11:58.703247 1.29 1 +2013-03-01 09:11:58.703247 1.29 1 +2013-03-01 09:11:58.703247 1.29 1 +2013-03-01 09:11:58.703247 1.29 1 +2013-03-01 09:11:58.703247 1.29 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703248 14.36 1 +2013-03-01 09:11:58.703249 19.42 1 +2013-03-01 09:11:58.703249 19.42 1 +2013-03-01 09:11:58.70325 25.89 1 +2013-03-01 09:11:58.70325 25.89 1 +2013-03-01 09:11:58.70325 25.89 1 +2013-03-01 09:11:58.70325 25.89 1 +2013-03-01 09:11:58.70325 25.89 1 +2013-03-01 09:11:58.70325 25.89 1 +2013-03-01 09:11:58.703251 68.98 1 +2013-03-01 09:11:58.703251 68.98 1 +2013-03-01 09:11:58.703251 68.98 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703252 49.85 1 +2013-03-01 09:11:58.703253 55.75 1 +2013-03-01 09:11:58.703253 55.75 1 +2013-03-01 09:11:58.703253 55.75 1 +2013-03-01 09:11:58.703253 55.75 1 +2013-03-01 09:11:58.703253 55.75 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703255 5.42 1 +2013-03-01 09:11:58.703255 5.42 1 +2013-03-01 09:11:58.703255 5.42 1 +2013-03-01 09:11:58.703255 5.42 1 +2013-03-01 09:11:58.703255 5.42 1 +2013-03-01 09:11:58.703255 5.42 1 +2013-03-01 09:11:58.703256 23.78 1 +2013-03-01 09:11:58.703256 23.78 1 +2013-03-01 09:11:58.703256 23.78 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703257 14.96 1 +2013-03-01 09:11:58.703258 19.65 1 +2013-03-01 09:11:58.703258 19.65 1 +2013-03-01 09:11:58.703258 19.65 1 +2013-03-01 09:11:58.703258 19.65 1 +2013-03-01 09:11:58.703258 19.65 1 +2013-03-01 09:11:58.703258 19.65 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.703259 11.37 1 +2013-03-01 09:11:58.70326 3.81 1 +2013-03-01 09:11:58.70326 3.81 1 +2013-03-01 09:11:58.70326 3.81 1 +2013-03-01 09:11:58.70326 3.81 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703261 8.66 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703262 1.81 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703265 6.88 1 +2013-03-01 09:11:58.703266 47.71 1 +2013-03-01 09:11:58.703266 47.71 1 +2013-03-01 09:11:58.703266 47.71 1 +2013-03-01 09:11:58.703266 47.71 1 +2013-03-01 09:11:58.703267 12.22 1 +2013-03-01 09:11:58.703267 12.22 1 +2013-03-01 09:11:58.703267 12.22 1 +2013-03-01 09:11:58.703267 12.22 1 +2013-03-01 09:11:58.703267 12.22 1 +2013-03-01 09:11:58.703267 12.22 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703268 94.73 1 +2013-03-01 09:11:58.703269 43.84 1 +2013-03-01 09:11:58.703269 43.84 1 +2013-03-01 09:11:58.703269 43.84 1 +2013-03-01 09:11:58.70327 5.01 1 +2013-03-01 09:11:58.70327 5.01 1 +2013-03-01 09:11:58.70327 5.01 1 +2013-03-01 09:11:58.70327 5.01 1 +2013-03-01 09:11:58.703271 61.16 1 +2013-03-01 09:11:58.703271 61.16 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703273 10.94 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703274 10.79 1 +2013-03-01 09:11:58.703275 20.57 1 +2013-03-01 09:11:58.703275 20.57 1 +2013-03-01 09:11:58.703275 20.57 1 +2013-03-01 09:11:58.703276 63.54 1 +2013-03-01 09:11:58.703276 63.54 1 +2013-03-01 09:11:58.703276 63.54 1 +2013-03-01 09:11:58.703276 63.54 1 +2013-03-01 09:11:58.703277 3.37 1 +2013-03-01 09:11:58.703277 3.37 1 +2013-03-01 09:11:58.703277 3.37 1 +2013-03-01 09:11:58.703277 3.37 1 +2013-03-01 09:11:58.703278 9.74 1 +2013-03-01 09:11:58.703278 9.74 1 +2013-03-01 09:11:58.703278 9.74 1 +2013-03-01 09:11:58.703278 9.74 1 +2013-03-01 09:11:58.703278 9.74 1 +2013-03-01 09:11:58.703278 9.74 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.703279 20.85 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.70328 40.68 1 +2013-03-01 09:11:58.703281 19.95 1 +2013-03-01 09:11:58.703281 19.95 1 +2013-03-01 09:11:58.703281 19.95 1 +2013-03-01 09:11:58.703281 19.95 1 +2013-03-01 09:11:58.703282 7.5 1 +2013-03-01 09:11:58.703282 7.5 1 +2013-03-01 09:11:58.703282 7.5 1 +2013-03-01 09:11:58.703282 7.5 1 +2013-03-01 09:11:58.703282 7.5 1 +2013-03-01 09:11:58.703282 7.5 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703284 99.88 1 +2013-03-01 09:11:58.703284 99.88 1 +2013-03-01 09:11:58.703284 99.88 1 +2013-03-01 09:11:58.703284 99.88 1 +2013-03-01 09:11:58.703284 99.88 1 +2013-03-01 09:11:58.703284 99.88 1 +2013-03-01 09:11:58.703285 58.66 1 +2013-03-01 09:11:58.703285 58.66 1 +2013-03-01 09:11:58.703285 58.66 1 +2013-03-01 09:11:58.703285 58.66 1 +2013-03-01 09:11:58.703285 58.66 1 +2013-03-01 09:11:58.703286 9.53 1 +2013-03-01 09:11:58.703286 9.53 1 +2013-03-01 09:11:58.703286 9.53 1 +2013-03-01 09:11:58.703286 9.53 1 +2013-03-01 09:11:58.703287 0.89 1 +2013-03-01 09:11:58.703287 0.89 1 +2013-03-01 09:11:58.703287 0.89 1 +2013-03-01 09:11:58.703288 60.57 1 +2013-03-01 09:11:58.703288 60.57 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.70329 16.89 1 +2013-03-01 09:11:58.70329 16.89 1 +2013-03-01 09:11:58.70329 16.89 1 +2013-03-01 09:11:58.70329 16.89 1 +2013-03-01 09:11:58.70329 16.89 1 +2013-03-01 09:11:58.70329 16.89 1 +2013-03-01 09:11:58.703291 1.15 1 +2013-03-01 09:11:58.703291 1.15 1 +2013-03-01 09:11:58.703291 1.15 1 +2013-03-01 09:11:58.703291 1.15 1 +2013-03-01 09:11:58.703292 4.24 1 +2013-03-01 09:11:58.703292 4.24 1 +2013-03-01 09:11:58.703292 4.24 1 +2013-03-01 09:11:58.703292 4.24 1 +2013-03-01 09:11:58.703293 42.86 1 +2013-03-01 09:11:58.703293 42.86 1 +2013-03-01 09:11:58.703293 42.86 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703295 8.58 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703296 30.43 1 +2013-03-01 09:11:58.703297 25.67 1 +2013-03-01 09:11:58.703297 25.67 1 +2013-03-01 09:11:58.703297 25.67 1 +2013-03-01 09:11:58.703297 25.67 1 +2013-03-01 09:11:58.703297 25.67 1 +2013-03-01 09:11:58.703297 25.67 1 +2013-03-01 09:11:58.703298 8.8 1 +2013-03-01 09:11:58.703298 8.8 1 +2013-03-01 09:11:58.703298 8.8 1 +2013-03-01 09:11:58.703298 8.8 1 +2013-03-01 09:11:58.703299 9 1 +2013-03-01 09:11:58.703299 9 1 +2013-03-01 09:11:58.7033 7.51 1 +2013-03-01 09:11:58.7033 7.51 1 +2013-03-01 09:11:58.7033 7.51 1 +2013-03-01 09:11:58.7033 7.51 1 +2013-03-01 09:11:58.7033 7.51 1 +2013-03-01 09:11:58.703301 39.32 1 +2013-03-01 09:11:58.703301 39.32 1 +2013-03-01 09:11:58.703301 39.32 1 +2013-03-01 09:11:58.703301 39.32 1 +2013-03-01 09:11:58.703301 39.32 1 +2013-03-01 09:11:58.703301 39.32 1 +2013-03-01 09:11:58.703302 23.55 1 +2013-03-01 09:11:58.703302 23.55 1 +2013-03-01 09:11:58.703302 23.55 1 +2013-03-01 09:11:58.703302 23.55 1 +2013-03-01 09:11:58.703302 23.55 1 +2013-03-01 09:11:58.703303 88.64 1 +2013-03-01 09:11:58.703303 88.64 1 +2013-03-01 09:11:58.703303 88.64 1 +2013-03-01 09:11:58.703303 88.64 1 +2013-03-01 09:11:58.703303 88.64 1 +2013-03-01 09:11:58.703304 9.04 1 +2013-03-01 09:11:58.703304 9.04 1 +2013-03-01 09:11:58.703304 9.04 1 +2013-03-01 09:11:58.703304 9.04 1 +2013-03-01 09:11:58.703305 18.68 1 +2013-03-01 09:11:58.703305 18.68 1 +2013-03-01 09:11:58.703305 18.68 1 +2013-03-01 09:11:58.703305 18.68 1 +2013-03-01 09:11:58.703306 3.95 1 +2013-03-01 09:11:58.703306 3.95 1 +2013-03-01 09:11:58.703306 3.95 1 +2013-03-01 09:11:58.703306 3.95 1 +2013-03-01 09:11:58.703307 31.28 1 +2013-03-01 09:11:58.703307 31.28 1 +2013-03-01 09:11:58.703307 31.28 1 +2013-03-01 09:11:58.703308 16.95 1 +2013-03-01 09:11:58.703308 16.95 1 +2013-03-01 09:11:58.703308 16.95 1 +2013-03-01 09:11:58.703308 16.95 1 +2013-03-01 09:11:58.703309 11.16 1 +2013-03-01 09:11:58.703309 11.16 1 +2013-03-01 09:11:58.703309 11.16 1 +2013-03-01 09:11:58.703309 11.16 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.70331 9.24 1 +2013-03-01 09:11:58.703311 7.38 1 +2013-03-01 09:11:58.703311 7.38 1 +2013-03-01 09:11:58.703311 7.38 1 +2013-03-01 09:11:58.703311 7.38 1 +2013-03-01 09:11:58.703311 7.38 1 +2013-03-01 09:11:58.703311 7.38 1 +2013-03-01 09:11:58.703312 18.2 1 +2013-03-01 09:11:58.703312 18.2 1 +2013-03-01 09:11:58.703312 18.2 1 +2013-03-01 09:11:58.703312 18.2 1 +2013-03-01 09:11:58.703312 18.2 1 +2013-03-01 09:11:58.703312 18.2 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703313 9.35 1 +2013-03-01 09:11:58.703314 39.12 1 +2013-03-01 09:11:58.703314 39.12 1 +2013-03-01 09:11:58.703314 39.12 1 +2013-03-01 09:11:58.703314 39.12 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703315 4.93 1 +2013-03-01 09:11:58.703316 16.86 1 +2013-03-01 09:11:58.703316 16.86 1 +2013-03-01 09:11:58.703316 16.86 1 +2013-03-01 09:11:58.703316 16.86 1 +2013-03-01 09:11:58.703316 16.86 1 +2013-03-01 09:11:58.703317 1.45 1 +2013-03-01 09:11:58.703317 1.45 1 +2013-03-01 09:11:58.703317 1.45 1 +2013-03-01 09:11:58.703317 1.45 1 +2013-03-01 09:11:58.703318 21.81 1 +2013-03-01 09:11:58.703318 21.81 1 +2013-03-01 09:11:58.703318 21.81 1 +2013-03-01 09:11:58.703318 21.81 1 +2013-03-01 09:11:58.703319 83.21 1 +2013-03-01 09:11:58.703319 83.21 1 +2013-03-01 09:11:58.703319 83.21 1 +2013-03-01 09:11:58.703319 83.21 1 +2013-03-01 09:11:58.703319 83.21 1 +2013-03-01 09:11:58.70332 77.09 1 +2013-03-01 09:11:58.70332 77.09 1 +2013-03-01 09:11:58.70332 77.09 1 +2013-03-01 09:11:58.70332 77.09 1 +2013-03-01 09:11:58.70332 77.09 1 +2013-03-01 09:11:58.70332 77.09 1 +2013-03-01 09:11:58.703321 3.91 1 +2013-03-01 09:11:58.703321 3.91 1 +2013-03-01 09:11:58.703321 3.91 1 +2013-03-01 09:11:58.703321 3.91 1 +2013-03-01 09:11:58.703322 2.48 1 +2013-03-01 09:11:58.703322 2.48 1 +2013-03-01 09:11:58.703322 2.48 1 +2013-03-01 09:11:58.703322 2.48 1 +2013-03-01 09:11:58.703322 2.48 1 +2013-03-01 09:11:58.703322 2.48 1 +2013-03-01 09:11:58.703323 36.22 1 +2013-03-01 09:11:58.703323 36.22 1 +2013-03-01 09:11:58.703323 36.22 1 +2013-03-01 09:11:58.703324 14.08 1 +2013-03-01 09:11:58.703324 14.08 1 +2013-03-01 09:11:58.703324 14.08 1 +2013-03-01 09:11:58.703324 14.08 1 +2013-03-01 09:11:58.703324 14.08 1 +2013-03-01 09:11:58.703324 14.08 1 +2013-03-01 09:11:58.703325 9.24 1 +2013-03-01 09:11:58.703325 9.24 1 +2013-03-01 09:11:58.703325 9.24 1 +2013-03-01 09:11:58.703325 9.24 1 +2013-03-01 09:11:58.703325 9.24 1 +2013-03-01 09:11:58.703325 9.24 1 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 3-0-58a982694ba2b1e34de82b1de54936a0 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 3-0-58a982694ba2b1e34de82b1de54936a0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 4-0-12cc78f3953c3e6b5411ddc729541bf0 b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 4-0-12cc78f3953c3e6b5411ddc729541bf0 new file mode 100644 index 000000000000..d02ca48857b5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_rank.q (deterministic) 4-0-12cc78f3953c3e6b5411ddc729541bf0 @@ -0,0 +1,474 @@ +2013-03-01 09:11:58.703074 58.47 1 +2013-03-01 09:11:58.703074 58.47 1 +2013-03-01 09:11:58.703074 58.47 1 +2013-03-01 09:11:58.703074 58.47 1 +2013-03-01 09:11:58.703074 58.47 1 +2013-03-01 09:11:58.703074 58.47 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703076 18.51 1 +2013-03-01 09:11:58.703077 66.68 1 +2013-03-01 09:11:58.703077 66.68 1 +2013-03-01 09:11:58.703077 66.68 1 +2013-03-01 09:11:58.703077 66.68 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703087 25.19 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703088 1.97 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703089 41.57 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703091 68.85 1 +2013-03-01 09:11:58.703092 54.02 1 +2013-03-01 09:11:58.703092 54.02 1 +2013-03-01 09:11:58.703092 54.02 1 +2013-03-01 09:11:58.703096 87.84 1 +2013-03-01 09:11:58.703097 0.9 1 +2013-03-01 09:11:58.703097 0.9 1 +2013-03-01 09:11:58.703097 0.9 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703098 21.29 1 +2013-03-01 09:11:58.703104 75.85 1 +2013-03-01 09:11:58.703104 75.85 1 +2013-03-01 09:11:58.703104 75.85 1 +2013-03-01 09:11:58.703104 75.85 1 +2013-03-01 09:11:58.703104 75.85 1 +2013-03-01 09:11:58.70311 65.88 1 +2013-03-01 09:11:58.70311 65.88 1 +2013-03-01 09:11:58.70311 65.88 1 +2013-03-01 09:11:58.70311 65.88 1 +2013-03-01 09:11:58.703111 85.94 1 +2013-03-01 09:11:58.703111 85.94 1 +2013-03-01 09:11:58.703111 85.94 1 +2013-03-01 09:11:58.703111 85.94 1 +2013-03-01 09:11:58.703111 85.94 1 +2013-03-01 09:11:58.703111 85.94 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703112 13.29 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703113 58.65 1 +2013-03-01 09:11:58.703118 8.69 1 +2013-03-01 09:11:58.703118 8.69 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.70312 52.6 1 +2013-03-01 09:11:58.703125 78.52 1 +2013-03-01 09:11:58.703125 78.52 1 +2013-03-01 09:11:58.703125 78.52 1 +2013-03-01 09:11:58.703125 78.52 1 +2013-03-01 09:11:58.703125 78.52 1 +2013-03-01 09:11:58.703125 78.52 1 +2013-03-01 09:11:58.703131 63.81 1 +2013-03-01 09:11:58.703131 63.81 1 +2013-03-01 09:11:58.703131 63.81 1 +2013-03-01 09:11:58.703131 63.81 1 +2013-03-01 09:11:58.703131 63.81 1 +2013-03-01 09:11:58.703131 63.81 1 +2013-03-01 09:11:58.703132 1.86 1 +2013-03-01 09:11:58.703132 1.86 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703135 88.38 1 +2013-03-01 09:11:58.703136 27.89 1 +2013-03-01 09:11:58.703136 27.89 1 +2013-03-01 09:11:58.703136 27.89 1 +2013-03-01 09:11:58.703136 27.89 1 +2013-03-01 09:11:58.703136 27.89 1 +2013-03-01 09:11:58.703138 86.7 1 +2013-03-01 09:11:58.703138 86.7 1 +2013-03-01 09:11:58.703138 86.7 1 +2013-03-01 09:11:58.703138 86.7 1 +2013-03-01 09:11:58.703138 86.7 1 +2013-03-01 09:11:58.703139 43.53 1 +2013-03-01 09:11:58.703139 43.53 1 +2013-03-01 09:11:58.703139 43.53 1 +2013-03-01 09:11:58.703139 43.53 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703144 21.59 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703154 16.93 1 +2013-03-01 09:11:58.703156 62.42 1 +2013-03-01 09:11:58.703156 62.42 1 +2013-03-01 09:11:58.703156 62.42 1 +2013-03-01 09:11:58.703156 62.42 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703157 8.99 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703174 36.33 1 +2013-03-01 09:11:58.703178 93.29 1 +2013-03-01 09:11:58.703178 93.29 1 +2013-03-01 09:11:58.703178 93.29 1 +2013-03-01 09:11:58.703178 93.29 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703179 60.94 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703181 26.6 1 +2013-03-01 09:11:58.703184 73.93 1 +2013-03-01 09:11:58.703184 73.93 1 +2013-03-01 09:11:58.703184 73.93 1 +2013-03-01 09:11:58.703184 73.93 1 +2013-03-01 09:11:58.703184 73.93 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703185 8.91 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703186 91.46 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703188 32.85 1 +2013-03-01 09:11:58.703189 37.74 1 +2013-03-01 09:11:58.703189 37.74 1 +2013-03-01 09:11:58.703189 37.74 1 +2013-03-01 09:11:58.703189 37.74 1 +2013-03-01 09:11:58.703189 37.74 1 +2013-03-01 09:11:58.703189 37.74 1 +2013-03-01 09:11:58.703195 82.5 1 +2013-03-01 09:11:58.703195 82.5 1 +2013-03-01 09:11:58.703195 82.5 1 +2013-03-01 09:11:58.703195 82.5 1 +2013-03-01 09:11:58.703195 82.5 1 +2013-03-01 09:11:58.703195 82.5 1 +2013-03-01 09:11:58.703198 97.18 1 +2013-03-01 09:11:58.703198 97.18 1 +2013-03-01 09:11:58.703198 97.18 1 +2013-03-01 09:11:58.703198 97.18 1 +2013-03-01 09:11:58.703206 80.94 1 +2013-03-01 09:11:58.703206 80.94 1 +2013-03-01 09:11:58.703206 80.94 1 +2013-03-01 09:11:58.703206 80.94 1 +2013-03-01 09:11:58.703206 80.94 1 +2013-03-01 09:11:58.703206 80.94 1 +2013-03-01 09:11:58.703207 55.06 1 +2013-03-01 09:11:58.703207 55.06 1 +2013-03-01 09:11:58.703207 55.06 1 +2013-03-01 09:11:58.703207 55.06 1 +2013-03-01 09:11:58.703207 55.06 1 +2013-03-01 09:11:58.703209 25.92 1 +2013-03-01 09:11:58.703209 25.92 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.70321 37.12 1 +2013-03-01 09:11:58.703213 48.8 1 +2013-03-01 09:11:58.703213 48.8 1 +2013-03-01 09:11:58.703213 48.8 1 +2013-03-01 09:11:58.703213 48.8 1 +2013-03-01 09:11:58.703219 32.73 1 +2013-03-01 09:11:58.703219 32.73 1 +2013-03-01 09:11:58.703219 32.73 1 +2013-03-01 09:11:58.703219 32.73 1 +2013-03-01 09:11:58.703219 32.73 1 +2013-03-01 09:11:58.703219 32.73 1 +2013-03-01 09:11:58.70322 7.37 1 +2013-03-01 09:11:58.70322 7.37 1 +2013-03-01 09:11:58.70322 7.37 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703221 26.64 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703223 57.1 1 +2013-03-01 09:11:58.703224 42.93 1 +2013-03-01 09:11:58.703224 42.93 1 +2013-03-01 09:11:58.703224 42.93 1 +2013-03-01 09:11:58.703224 42.93 1 +2013-03-01 09:11:58.703226 68.3 1 +2013-03-01 09:11:58.703226 68.3 1 +2013-03-01 09:11:58.703226 68.3 1 +2013-03-01 09:11:58.703226 68.3 1 +2013-03-01 09:11:58.703226 68.3 1 +2013-03-01 09:11:58.703226 68.3 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703231 18.7 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703233 40.81 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703242 31.23 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703244 25.67 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703245 32.3 1 +2013-03-01 09:11:58.703246 72.87 1 +2013-03-01 09:11:58.703246 72.87 1 +2013-03-01 09:11:58.703248 81.28 1 +2013-03-01 09:11:58.703248 81.28 1 +2013-03-01 09:11:58.703248 81.28 1 +2013-03-01 09:11:58.703249 93.3 1 +2013-03-01 09:11:58.703249 93.3 1 +2013-03-01 09:11:58.703249 93.3 1 +2013-03-01 09:11:58.703249 93.3 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.70325 93.79 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703254 0.32 1 +2013-03-01 09:11:58.703256 43.8 1 +2013-03-01 09:11:58.703256 43.8 1 +2013-03-01 09:11:58.703256 43.8 1 +2013-03-01 09:11:58.703256 43.8 1 +2013-03-01 09:11:58.703256 43.8 1 +2013-03-01 09:11:58.703256 43.8 1 +2013-03-01 09:11:58.703258 21.21 1 +2013-03-01 09:11:58.703258 21.21 1 +2013-03-01 09:11:58.703258 21.21 1 +2013-03-01 09:11:58.703259 52.28 1 +2013-03-01 09:11:58.703259 52.28 1 +2013-03-01 09:11:58.703259 52.28 1 +2013-03-01 09:11:58.703259 52.28 1 +2013-03-01 09:11:58.703259 52.28 1 +2013-03-01 09:11:58.703259 52.28 1 +2013-03-01 09:11:58.703262 78.56 1 +2013-03-01 09:11:58.703262 78.56 1 +2013-03-01 09:11:58.703262 78.56 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703263 14.4 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703264 52.49 1 +2013-03-01 09:11:58.703265 11.46 1 +2013-03-01 09:11:58.703265 11.46 1 +2013-03-01 09:11:58.703265 11.46 1 +2013-03-01 09:11:58.703265 11.46 1 +2013-03-01 09:11:58.703266 83.67 1 +2013-03-01 09:11:58.703266 83.67 1 +2013-03-01 09:11:58.703266 83.67 1 +2013-03-01 09:11:58.703266 83.67 1 +2013-03-01 09:11:58.703266 83.67 1 +2013-03-01 09:11:58.703269 61.06 1 +2013-03-01 09:11:58.703269 61.06 1 +2013-03-01 09:11:58.703269 61.06 1 +2013-03-01 09:11:58.703269 61.06 1 +2013-03-01 09:11:58.703269 61.06 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703272 7.49 1 +2013-03-01 09:11:58.703273 30.49 1 +2013-03-01 09:11:58.703273 30.49 1 +2013-03-01 09:11:58.703273 30.49 1 +2013-03-01 09:11:58.703275 90.69 1 +2013-03-01 09:11:58.703275 90.69 1 +2013-03-01 09:11:58.703275 90.69 1 +2013-03-01 09:11:58.703275 90.69 1 +2013-03-01 09:11:58.703275 90.69 1 +2013-03-01 09:11:58.703275 90.69 1 +2013-03-01 09:11:58.703276 88.46 1 +2013-03-01 09:11:58.703276 88.46 1 +2013-03-01 09:11:58.703276 88.46 1 +2013-03-01 09:11:58.703276 88.46 1 +2013-03-01 09:11:58.703278 69.42 1 +2013-03-01 09:11:58.703278 69.42 1 +2013-03-01 09:11:58.703278 69.42 1 +2013-03-01 09:11:58.70328 45.81 1 +2013-03-01 09:11:58.70328 45.81 1 +2013-03-01 09:11:58.70328 45.81 1 +2013-03-01 09:11:58.70328 45.81 1 +2013-03-01 09:11:58.70328 45.81 1 +2013-03-01 09:11:58.70328 45.81 1 +2013-03-01 09:11:58.703281 62.11 1 +2013-03-01 09:11:58.703281 62.11 1 +2013-03-01 09:11:58.703281 62.11 1 +2013-03-01 09:11:58.703281 62.11 1 +2013-03-01 09:11:58.703281 62.11 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703283 17.62 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703289 0.56 1 +2013-03-01 09:11:58.703293 42.86 1 +2013-03-01 09:11:58.703293 42.86 1 +2013-03-01 09:11:58.703293 42.86 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703294 29.74 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703296 43.66 1 +2013-03-01 09:11:58.703299 23.19 1 +2013-03-01 09:11:58.703299 23.19 1 +2013-03-01 09:11:58.703299 23.19 1 +2013-03-01 09:11:58.703299 23.19 1 +2013-03-01 09:11:58.703299 23.19 1 +2013-03-01 09:11:58.703309 89.4 1 +2013-03-01 09:11:58.703309 89.4 1 +2013-03-01 09:11:58.703309 89.4 1 +2013-03-01 09:11:58.703309 89.4 1 +2013-03-01 09:11:58.70331 69.26 1 +2013-03-01 09:11:58.70331 69.26 1 +2013-03-01 09:11:58.70331 69.26 1 +2013-03-01 09:11:58.703313 20.69 1 +2013-03-01 09:11:58.703313 20.69 1 +2013-03-01 09:11:58.703313 20.69 1 +2013-03-01 09:11:58.703313 20.69 1 +2013-03-01 09:11:58.703315 53.04 1 +2013-03-01 09:11:58.703315 53.04 1 +2013-03-01 09:11:58.703315 53.04 1 +2013-03-01 09:11:58.703315 53.04 1 +2013-03-01 09:11:58.703318 85.62 1 +2013-03-01 09:11:58.703318 85.62 1 +2013-03-01 09:11:58.703318 85.62 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703323 65.61 1 +2013-03-01 09:11:58.703324 98.36 1 +2013-03-01 09:11:58.703324 98.36 1 +2013-03-01 09:11:58.703324 98.36 1 +2013-03-01 09:11:58.703324 98.36 1 +2013-03-01 09:11:58.703325 65.81 1 +2013-03-01 09:11:58.703325 65.81 1 +2013-03-01 09:11:58.703325 65.81 1 +2013-03-01 09:11:58.703325 65.81 1 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-0-6642a21d87e0401ba1a668ea8b244f0c b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-0-6642a21d87e0401ba1a668ea8b244f0c new file mode 100644 index 000000000000..119dd71df142 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-0-6642a21d87e0401ba1a668ea8b244f0c @@ -0,0 +1,1049 @@ + 65560 + 65560 + 65560 +alice allen 65662 +alice allen 65662 +alice allen 65662 +alice brown 65696 +alice carson 65559 +alice davidson 65547 +alice falkner 65669 +alice garcia 65613 +alice hernandez 65737 +alice hernandez 65737 +alice johnson 65739 +alice king 65660 +alice king 65660 +alice king 65660 +alice laertes 65669 +alice laertes 65669 +alice miller 65590 +alice nixon 65586 +alice nixon 65586 +alice nixon 65586 +alice ovid 65737 +alice polk 65548 +alice quirinius 65636 +alice quirinius 65636 +alice robinson 65606 +alice robinson 65606 +alice steinbeck 65578 +alice steinbeck 65578 +alice steinbeck 65578 +alice underhill 65750 +alice van buren 65562 +alice xylophone 65578 +alice xylophone 65578 +alice xylophone 65578 +alice zipper 65553 +alice zipper 65553 +alice zipper 65553 +bob brown 65584 +bob brown 65584 +bob brown 65584 +bob carson 65713 +bob davidson 65664 +bob davidson 65664 +bob davidson 65664 +bob ellison 65591 +bob ellison 65591 +bob ellison 65591 +bob ellison 65591 +bob falkner 65789 +bob garcia 65585 +bob garcia 65585 +bob garcia 65585 +bob garcia 65585 +bob garcia 65585 +bob hernandez 65557 +bob ichabod 65549 +bob king 65715 +bob king 65715 +bob king 65715 +bob laertes 65602 +bob laertes 65602 +bob miller 65608 +bob ovid 65564 +bob ovid 65564 +bob ovid 65564 +bob ovid 65564 +bob polk 65594 +bob quirinius 65700 +bob steinbeck 65637 +bob van buren 65778 +bob white 65543 +bob white 65543 +bob xylophone 65574 +bob xylophone 65574 +bob young 65556 +bob zipper 65559 +bob zipper 65559 +bob zipper 65559 +calvin allen 65669 +calvin brown 65537 +calvin brown 65537 +calvin brown 65537 +calvin carson 65637 +calvin davidson 65541 +calvin davidson 65541 +calvin ellison 65667 +calvin falkner 65573 +calvin falkner 65573 +calvin falkner 65573 +calvin falkner 65573 +calvin falkner 65573 +calvin falkner 65573 +calvin garcia 65664 +calvin hernandez 65578 +calvin johnson 65731 +calvin laertes 65570 +calvin laertes 65570 +calvin nixon 65654 +calvin nixon 65654 +calvin nixon 65654 +calvin ovid 65554 +calvin ovid 65554 +calvin ovid 65554 +calvin ovid 65554 +calvin polk 65731 +calvin quirinius 65741 +calvin quirinius 65741 +calvin robinson 65581 +calvin steinbeck 65680 +calvin steinbeck 65680 +calvin steinbeck 65680 +calvin thompson 65560 +calvin thompson 65560 +calvin underhill 65732 +calvin van buren 65552 +calvin van buren 65552 +calvin white 65553 +calvin white 65553 +calvin xylophone 65575 +calvin xylophone 65575 +calvin xylophone 65575 +calvin young 65574 +calvin young 65574 +calvin zipper 65669 +calvin zipper 65669 +david allen 65588 +david allen 65588 +david brown 65637 +david brown 65637 +david davidson 65559 +david davidson 65559 +david davidson 65559 +david davidson 65559 +david ellison 65634 +david ellison 65634 +david ellison 65634 +david hernandez 65763 +david ichabod 65699 +david ichabod 65699 +david laertes 65762 +david nixon 65536 +david ovid 65623 +david ovid 65623 +david quirinius 65697 +david quirinius 65697 +david quirinius 65697 +david robinson 65762 +david robinson 65762 +david thompson 65550 +david underhill 65602 +david underhill 65602 +david underhill 65602 +david van buren 65625 +david van buren 65625 +david white 65678 +david xylophone 65537 +david xylophone 65537 +david xylophone 65537 +david young 65551 +david young 65551 +ethan allen 65747 +ethan brown 65539 +ethan brown 65539 +ethan brown 65539 +ethan brown 65539 +ethan brown 65539 +ethan brown 65539 +ethan carson 65742 +ethan ellison 65714 +ethan ellison 65714 +ethan falkner 65577 +ethan falkner 65577 +ethan garcia 65736 +ethan hernandez 65618 +ethan johnson 65536 +ethan king 65614 +ethan laertes 65562 +ethan laertes 65562 +ethan laertes 65562 +ethan laertes 65562 +ethan laertes 65562 +ethan laertes 65562 +ethan laertes 65562 +ethan miller 65712 +ethan nixon 65766 +ethan ovid 65697 +ethan polk 65589 +ethan polk 65589 +ethan polk 65589 +ethan polk 65589 +ethan quirinius 65542 +ethan quirinius 65542 +ethan quirinius 65542 +ethan robinson 65547 +ethan robinson 65547 +ethan underhill 65570 +ethan van buren 65572 +ethan white 65677 +ethan white 65677 +ethan xylophone 65595 +ethan zipper 65593 +ethan zipper 65593 +fred davidson 65552 +fred davidson 65552 +fred davidson 65552 +fred ellison 65548 +fred ellison 65548 +fred ellison 65548 +fred falkner 65637 +fred falkner 65637 +fred falkner 65637 +fred hernandez 65541 +fred ichabod 65572 +fred ichabod 65572 +fred johnson 65758 +fred king 65694 +fred king 65694 +fred laertes 65769 +fred miller 65536 +fred nixon 65560 +fred nixon 65560 +fred nixon 65560 +fred nixon 65560 +fred polk 65603 +fred polk 65603 +fred polk 65603 +fred polk 65603 +fred quirinius 65697 +fred quirinius 65697 +fred robinson 65623 +fred steinbeck 65544 +fred steinbeck 65544 +fred steinbeck 65544 +fred underhill 65629 +fred van buren 65537 +fred van buren 65537 +fred van buren 65537 +fred van buren 65537 +fred white 65589 +fred young 65594 +fred young 65594 +fred zipper 65553 +gabriella allen 65646 +gabriella allen 65646 +gabriella brown 65704 +gabriella brown 65704 +gabriella carson 65586 +gabriella davidson 65565 +gabriella ellison 65706 +gabriella ellison 65706 +gabriella falkner 65623 +gabriella falkner 65623 +gabriella falkner 65623 +gabriella garcia 65571 +gabriella hernandez 65587 +gabriella hernandez 65587 +gabriella ichabod 65559 +gabriella ichabod 65559 +gabriella ichabod 65559 +gabriella ichabod 65559 +gabriella ichabod 65559 +gabriella king 65651 +gabriella king 65651 +gabriella laertes 65781 +gabriella miller 65646 +gabriella ovid 65556 +gabriella ovid 65556 +gabriella polk 65701 +gabriella polk 65701 +gabriella steinbeck 65582 +gabriella steinbeck 65582 +gabriella thompson 65682 +gabriella thompson 65682 +gabriella thompson 65682 +gabriella van buren 65581 +gabriella van buren 65581 +gabriella white 65638 +gabriella young 65699 +gabriella young 65699 +gabriella zipper 65540 +gabriella zipper 65540 +holly allen 65596 +holly brown 65599 +holly brown 65599 +holly falkner 65720 +holly hernandez 65602 +holly hernandez 65602 +holly hernandez 65602 +holly hernandez 65602 +holly ichabod 65711 +holly ichabod 65711 +holly ichabod 65711 +holly johnson 65655 +holly johnson 65655 +holly johnson 65655 +holly king 65549 +holly king 65549 +holly laertes 65664 +holly miller 65653 +holly nixon 65539 +holly nixon 65539 +holly polk 65743 +holly polk 65743 +holly robinson 65564 +holly thompson 65538 +holly thompson 65538 +holly thompson 65538 +holly underhill 65634 +holly underhill 65634 +holly underhill 65634 +holly underhill 65634 +holly van buren 65727 +holly white 65536 +holly white 65536 +holly xylophone 65544 +holly young 65606 +holly young 65606 +holly zipper 65607 +holly zipper 65607 +irene allen 65556 +irene brown 65633 +irene brown 65633 +irene brown 65633 +irene carson 65590 +irene ellison 65659 +irene ellison 65659 +irene falkner 65620 +irene falkner 65620 +irene garcia 65660 +irene garcia 65660 +irene garcia 65660 +irene ichabod 65645 +irene ichabod 65645 +irene johnson 65583 +irene laertes 65664 +irene laertes 65664 +irene laertes 65664 +irene miller 65730 +irene nixon 65631 +irene nixon 65631 +irene nixon 65631 +irene ovid 65691 +irene ovid 65691 +irene ovid 65691 +irene polk 65551 +irene polk 65551 +irene polk 65551 +irene polk 65551 +irene polk 65551 +irene quirinius 65724 +irene quirinius 65724 +irene quirinius 65724 +irene robinson 65554 +irene steinbeck 65683 +irene thompson 65688 +irene underhill 65591 +irene underhill 65591 +irene van buren 65579 +irene van buren 65579 +irene xylophone 65775 +jessica brown 65588 +jessica carson 65553 +jessica carson 65553 +jessica carson 65553 +jessica davidson 65549 +jessica davidson 65549 +jessica davidson 65549 +jessica davidson 65549 +jessica ellison 65567 +jessica ellison 65567 +jessica falkner 65584 +jessica garcia 65676 +jessica garcia 65676 +jessica ichabod 65704 +jessica johnson 65607 +jessica johnson 65607 +jessica miller 65733 +jessica nixon 65590 +jessica nixon 65590 +jessica ovid 65582 +jessica ovid 65582 +jessica polk 65637 +jessica quirinius 65562 +jessica quirinius 65562 +jessica quirinius 65562 +jessica quirinius 65562 +jessica robinson 65576 +jessica thompson 65581 +jessica thompson 65581 +jessica underhill 65656 +jessica underhill 65656 +jessica underhill 65656 +jessica van buren 65615 +jessica white 65544 +jessica white 65544 +jessica white 65544 +jessica white 65544 +jessica white 65544 +jessica xylophone 65562 +jessica young 65623 +jessica young 65623 +jessica zipper 65600 +jessica zipper 65600 +jessica zipper 65600 +katie allen 65542 +katie brown 65590 +katie davidson 65619 +katie ellison 65675 +katie ellison 65675 +katie falkner 65728 +katie garcia 65625 +katie garcia 65625 +katie hernandez 65550 +katie ichabod 65658 +katie ichabod 65658 +katie ichabod 65658 +katie king 65629 +katie king 65629 +katie king 65629 +katie miller 65541 +katie miller 65541 +katie nixon 65669 +katie ovid 65681 +katie polk 65746 +katie polk 65746 +katie robinson 65697 +katie van buren 65643 +katie van buren 65643 +katie white 65620 +katie white 65620 +katie xylophone 65585 +katie young 65644 +katie young 65644 +katie young 65644 +katie zipper 65568 +katie zipper 65568 +luke allen 65547 +luke allen 65547 +luke allen 65547 +luke allen 65547 +luke allen 65547 +luke brown 65719 +luke davidson 65656 +luke davidson 65656 +luke ellison 65582 +luke ellison 65582 +luke ellison 65582 +luke falkner 65589 +luke falkner 65589 +luke garcia 65687 +luke garcia 65687 +luke ichabod 65629 +luke ichabod 65629 +luke johnson 65545 +luke johnson 65545 +luke johnson 65545 +luke laertes 65608 +luke laertes 65608 +luke laertes 65608 +luke laertes 65608 +luke laertes 65608 +luke miller 65752 +luke ovid 65569 +luke ovid 65569 +luke polk 65645 +luke polk 65645 +luke quirinius 65655 +luke robinson 65634 +luke robinson 65634 +luke thompson 65626 +luke underhill 65553 +luke underhill 65553 +luke underhill 65553 +luke van buren 65678 +luke white 65693 +luke xylophone 65597 +luke zipper 65641 +mike allen 65706 +mike brown 65654 +mike carson 65698 +mike carson 65698 +mike carson 65698 +mike davidson 65658 +mike davidson 65658 +mike ellison 65598 +mike ellison 65598 +mike ellison 65598 +mike ellison 65598 +mike ellison 65598 +mike falkner 65609 +mike garcia 65571 +mike garcia 65571 +mike garcia 65571 +mike hernandez 65548 +mike hernandez 65548 +mike ichabod 65621 +mike king 65563 +mike king 65563 +mike king 65563 +mike king 65563 +mike king 65563 +mike king 65563 +mike miller 65549 +mike nixon 65619 +mike nixon 65619 +mike polk 65619 +mike polk 65619 +mike polk 65619 +mike quirinius 65717 +mike steinbeck 65550 +mike steinbeck 65550 +mike steinbeck 65550 +mike steinbeck 65550 +mike van buren 65620 +mike van buren 65620 +mike white 65648 +mike white 65648 +mike white 65648 +mike white 65648 +mike young 65545 +mike young 65545 +mike young 65545 +mike zipper 65552 +mike zipper 65552 +mike zipper 65552 +nick allen 65641 +nick allen 65641 +nick brown 65724 +nick davidson 65601 +nick ellison 65691 +nick ellison 65691 +nick falkner 65583 +nick falkner 65583 +nick garcia 65695 +nick garcia 65695 +nick garcia 65695 +nick ichabod 65572 +nick ichabod 65572 +nick ichabod 65572 +nick johnson 65585 +nick johnson 65585 +nick laertes 65624 +nick miller 65757 +nick nixon 65650 +nick ovid 65719 +nick polk 65716 +nick quirinius 65588 +nick quirinius 65588 +nick robinson 65547 +nick robinson 65547 +nick steinbeck 65689 +nick thompson 65610 +nick underhill 65619 +nick van buren 65603 +nick xylophone 65644 +nick young 65654 +nick young 65654 +nick zipper 65757 +nick zipper 65757 +oscar allen 65644 +oscar brown 65614 +oscar carson 65537 +oscar carson 65537 +oscar carson 65537 +oscar carson 65537 +oscar carson 65537 +oscar davidson 65556 +oscar ellison 65630 +oscar ellison 65630 +oscar falkner 65692 +oscar garcia 65751 +oscar hernandez 65683 +oscar hernandez 65683 +oscar ichabod 65536 +oscar ichabod 65536 +oscar ichabod 65536 +oscar ichabod 65536 +oscar johnson 65645 +oscar johnson 65645 +oscar king 65541 +oscar king 65541 +oscar king 65541 +oscar laertes 65625 +oscar laertes 65625 +oscar laertes 65625 +oscar laertes 65625 +oscar nixon 65596 +oscar ovid 65536 +oscar ovid 65536 +oscar ovid 65536 +oscar polk 65541 +oscar polk 65541 +oscar quirinius 65541 +oscar quirinius 65541 +oscar quirinius 65541 +oscar quirinius 65541 +oscar robinson 65537 +oscar robinson 65537 +oscar robinson 65537 +oscar robinson 65537 +oscar steinbeck 65709 +oscar thompson 65542 +oscar thompson 65542 +oscar thompson 65542 +oscar thompson 65542 +oscar underhill 65626 +oscar van buren 65581 +oscar van buren 65581 +oscar van buren 65581 +oscar white 65552 +oscar white 65552 +oscar white 65552 +oscar white 65552 +oscar xylophone 65773 +oscar xylophone 65773 +oscar xylophone 65773 +oscar zipper 65568 +oscar zipper 65568 +oscar zipper 65568 +priscilla brown 65670 +priscilla brown 65670 +priscilla brown 65670 +priscilla carson 65658 +priscilla carson 65658 +priscilla carson 65658 +priscilla ichabod 65627 +priscilla ichabod 65627 +priscilla johnson 65543 +priscilla johnson 65543 +priscilla johnson 65543 +priscilla johnson 65543 +priscilla johnson 65543 +priscilla king 65646 +priscilla nixon 65564 +priscilla nixon 65564 +priscilla ovid 65541 +priscilla ovid 65541 +priscilla polk 65747 +priscilla quirinius 65672 +priscilla thompson 65654 +priscilla underhill 65715 +priscilla underhill 65715 +priscilla van buren 65607 +priscilla van buren 65607 +priscilla van buren 65607 +priscilla white 65652 +priscilla xylophone 65538 +priscilla xylophone 65538 +priscilla xylophone 65538 +priscilla young 65585 +priscilla young 65585 +priscilla zipper 65622 +priscilla zipper 65622 +quinn allen 65657 +quinn allen 65657 +quinn brown 65691 +quinn brown 65691 +quinn brown 65691 +quinn davidson 65549 +quinn davidson 65549 +quinn davidson 65549 +quinn davidson 65549 +quinn ellison 65705 +quinn ellison 65705 +quinn garcia 65568 +quinn garcia 65568 +quinn garcia 65568 +quinn garcia 65568 +quinn ichabod 65564 +quinn king 65558 +quinn king 65558 +quinn laertes 65542 +quinn laertes 65542 +quinn laertes 65542 +quinn nixon 65659 +quinn ovid 65699 +quinn quirinius 65747 +quinn robinson 65627 +quinn steinbeck 65578 +quinn steinbeck 65578 +quinn thompson 65643 +quinn thompson 65643 +quinn underhill 65549 +quinn underhill 65549 +quinn underhill 65549 +quinn van buren 65725 +quinn young 65647 +quinn zipper 65579 +quinn zipper 65579 +rachel allen 65661 +rachel allen 65661 +rachel brown 65586 +rachel brown 65586 +rachel brown 65586 +rachel brown 65586 +rachel brown 65586 +rachel carson 65677 +rachel carson 65677 +rachel davidson 65755 +rachel ellison 65761 +rachel falkner 65616 +rachel falkner 65616 +rachel falkner 65616 +rachel falkner 65616 +rachel johnson 65658 +rachel king 65604 +rachel king 65604 +rachel laertes 65562 +rachel laertes 65562 +rachel ovid 65721 +rachel ovid 65721 +rachel polk 65686 +rachel quirinius 65787 +rachel robinson 65544 +rachel robinson 65544 +rachel robinson 65544 +rachel thompson 65648 +rachel thompson 65648 +rachel thompson 65648 +rachel underhill 65667 +rachel white 65615 +rachel white 65615 +rachel young 65727 +rachel zipper 65757 +rachel zipper 65757 +sarah carson 65679 +sarah carson 65679 +sarah carson 65679 +sarah ellison 65611 +sarah falkner 65606 +sarah falkner 65606 +sarah garcia 65563 +sarah garcia 65563 +sarah garcia 65563 +sarah ichabod 65667 +sarah ichabod 65667 +sarah johnson 65659 +sarah johnson 65659 +sarah johnson 65659 +sarah johnson 65659 +sarah king 65650 +sarah king 65650 +sarah miller 65557 +sarah ovid 65550 +sarah robinson 65677 +sarah robinson 65677 +sarah steinbeck 65721 +sarah white 65622 +sarah white 65622 +sarah xylophone 65678 +sarah young 65595 +sarah zipper 65550 +tom brown 65593 +tom brown 65593 +tom carson 65539 +tom carson 65539 +tom carson 65539 +tom davidson 65780 +tom ellison 65578 +tom ellison 65578 +tom ellison 65578 +tom falkner 65574 +tom falkner 65574 +tom hernandez 65575 +tom hernandez 65575 +tom ichabod 65588 +tom johnson 65536 +tom johnson 65536 +tom king 65576 +tom laertes 65617 +tom laertes 65617 +tom miller 65594 +tom miller 65594 +tom miller 65594 +tom nixon 65672 +tom ovid 65628 +tom polk 65652 +tom polk 65652 +tom quirinius 65563 +tom quirinius 65563 +tom robinson 65626 +tom robinson 65626 +tom robinson 65626 +tom robinson 65626 +tom steinbeck 65666 +tom van buren 65621 +tom van buren 65621 +tom van buren 65621 +tom white 65548 +tom young 65544 +tom young 65544 +tom zipper 65789 +ulysses brown 65735 +ulysses carson 65602 +ulysses carson 65602 +ulysses carson 65602 +ulysses carson 65602 +ulysses davidson 65750 +ulysses ellison 65575 +ulysses garcia 65666 +ulysses hernandez 65651 +ulysses hernandez 65651 +ulysses hernandez 65651 +ulysses ichabod 65551 +ulysses ichabod 65551 +ulysses johnson 65776 +ulysses king 65649 +ulysses laertes 65691 +ulysses laertes 65691 +ulysses laertes 65691 +ulysses miller 65610 +ulysses miller 65610 +ulysses nixon 65603 +ulysses ovid 65656 +ulysses polk 65563 +ulysses polk 65563 +ulysses polk 65563 +ulysses polk 65563 +ulysses quirinius 65786 +ulysses robinson 65744 +ulysses steinbeck 65611 +ulysses steinbeck 65611 +ulysses thompson 65788 +ulysses underhill 65570 +ulysses underhill 65570 +ulysses underhill 65570 +ulysses underhill 65570 +ulysses underhill 65570 +ulysses underhill 65570 +ulysses underhill 65570 +ulysses van buren 65684 +ulysses white 65654 +ulysses white 65654 +ulysses xylophone 65623 +ulysses xylophone 65623 +ulysses xylophone 65623 +ulysses young 65675 +ulysses young 65675 +ulysses young 65675 +victor allen 65684 +victor allen 65684 +victor brown 65550 +victor brown 65550 +victor brown 65550 +victor brown 65550 +victor davidson 65579 +victor davidson 65579 +victor davidson 65579 +victor ellison 65641 +victor ellison 65641 +victor hernandez 65571 +victor hernandez 65571 +victor hernandez 65571 +victor hernandez 65571 +victor hernandez 65571 +victor johnson 65606 +victor johnson 65606 +victor johnson 65606 +victor king 65721 +victor king 65721 +victor laertes 65638 +victor laertes 65638 +victor miller 65570 +victor nixon 65709 +victor nixon 65709 +victor ovid 65649 +victor polk 65625 +victor quirinius 65620 +victor quirinius 65620 +victor robinson 65596 +victor robinson 65596 +victor steinbeck 65618 +victor steinbeck 65618 +victor steinbeck 65618 +victor thompson 65548 +victor van buren 65664 +victor van buren 65664 +victor white 65548 +victor white 65548 +victor xylophone 65549 +victor xylophone 65549 +victor xylophone 65549 +victor xylophone 65549 +victor xylophone 65549 +victor young 65628 +victor zipper 65743 +wendy allen 65628 +wendy allen 65628 +wendy allen 65628 +wendy brown 65580 +wendy brown 65580 +wendy ellison 65545 +wendy ellison 65545 +wendy falkner 65595 +wendy falkner 65595 +wendy falkner 65595 +wendy garcia 65659 +wendy garcia 65659 +wendy garcia 65659 +wendy garcia 65659 +wendy hernandez 65650 +wendy ichabod 65730 +wendy king 65586 +wendy king 65586 +wendy king 65586 +wendy laertes 65566 +wendy laertes 65566 +wendy laertes 65566 +wendy miller 65582 +wendy miller 65582 +wendy nixon 65611 +wendy nixon 65611 +wendy ovid 65589 +wendy ovid 65589 +wendy polk 65656 +wendy polk 65656 +wendy quirinius 65766 +wendy quirinius 65766 +wendy robinson 65622 +wendy robinson 65622 +wendy robinson 65622 +wendy steinbeck 65612 +wendy thompson 65650 +wendy thompson 65650 +wendy underhill 65662 +wendy underhill 65662 +wendy underhill 65662 +wendy van buren 65680 +wendy van buren 65680 +wendy white 65705 +wendy xylophone 65687 +wendy xylophone 65687 +wendy young 65674 +wendy young 65674 +xavier allen 65611 +xavier allen 65611 +xavier allen 65611 +xavier brown 65600 +xavier brown 65600 +xavier brown 65600 +xavier carson 65731 +xavier carson 65731 +xavier davidson 65644 +xavier davidson 65644 +xavier davidson 65644 +xavier ellison 65541 +xavier ellison 65541 +xavier garcia 65672 +xavier hernandez 65541 +xavier hernandez 65541 +xavier hernandez 65541 +xavier ichabod 65597 +xavier ichabod 65597 +xavier johnson 65654 +xavier johnson 65654 +xavier king 65590 +xavier king 65590 +xavier laertes 65743 +xavier ovid 65788 +xavier polk 65587 +xavier polk 65587 +xavier polk 65587 +xavier polk 65587 +xavier quirinius 65599 +xavier quirinius 65599 +xavier quirinius 65599 +xavier quirinius 65599 +xavier thompson 65608 +xavier underhill 65710 +xavier white 65703 +xavier white 65703 +xavier xylophone 65572 +xavier zipper 65561 +yuri allen 65565 +yuri allen 65565 +yuri brown 65538 +yuri brown 65538 +yuri carson 65670 +yuri carson 65670 +yuri ellison 65570 +yuri ellison 65570 +yuri falkner 65658 +yuri falkner 65658 +yuri garcia 65639 +yuri hernandez 65706 +yuri johnson 65587 +yuri johnson 65587 +yuri johnson 65587 +yuri king 65721 +yuri laertes 65637 +yuri laertes 65637 +yuri nixon 65635 +yuri nixon 65635 +yuri polk 65607 +yuri polk 65607 +yuri polk 65607 +yuri quirinius 65544 +yuri quirinius 65544 +yuri quirinius 65544 +yuri steinbeck 65592 +yuri steinbeck 65592 +yuri thompson 65676 +yuri underhill 65718 +yuri underhill 65718 +yuri white 65659 +yuri xylophone 65714 +zach allen 65667 +zach brown 65559 +zach brown 65559 +zach brown 65559 +zach brown 65559 +zach brown 65559 +zach carson 65572 +zach ellison 65748 +zach falkner 65620 +zach falkner 65620 +zach garcia 65544 +zach garcia 65544 +zach garcia 65544 +zach garcia 65544 +zach ichabod 65599 +zach ichabod 65599 +zach king 65556 +zach king 65556 +zach king 65556 +zach miller 65584 +zach miller 65584 +zach miller 65584 +zach ovid 65578 +zach ovid 65578 +zach ovid 65578 +zach ovid 65578 +zach quirinius 65691 +zach robinson 65599 +zach steinbeck 65602 +zach steinbeck 65602 +zach thompson 65636 +zach thompson 65636 +zach underhill 65573 +zach white 65733 +zach xylophone 65542 +zach xylophone 65542 +zach young 65576 +zach zipper 65579 +zach zipper 65579 +zach zipper 65579 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-1-2bf20f39e6ffef258858f7943a974e7e b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-1-2bf20f39e6ffef258858f7943a974e7e new file mode 100644 index 000000000000..657e81a94f4c --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-1-2bf20f39e6ffef258858f7943a974e7e @@ -0,0 +1,1049 @@ + 9.220000267028809 + 43.72999954223633 + 89.52999877929688 +alice allen 2.7899999618530273 +alice allen 21.450000762939453 +alice allen 73.62999725341797 +alice brown 71.30999755859375 +alice carson 39.029998779296875 +alice davidson 70.3499984741211 +alice falkner 90.25 +alice garcia 48.45000076293945 +alice hernandez 88.16999816894531 +alice hernandez 90.55999755859375 +alice johnson 47.359999895095825 +alice king 19.139999389648438 +alice king 23.170000076293945 +alice king 52.22999954223633 +alice laertes 68.94999694824219 +alice laertes 69.52999877929688 +alice miller 68.95999908447266 +alice nixon 40.0 +alice nixon 48.150001525878906 +alice nixon 79.83000183105469 +alice ovid 9.039999961853027 +alice polk 62.900001525878906 +alice quirinius 37.13999938964844 +alice quirinius 62.29999923706055 +alice robinson 3.934999942779541 +alice robinson 56.099998474121094 +alice steinbeck 38.619998931884766 +alice steinbeck 63.40999794006348 +alice steinbeck 92.37000274658203 +alice underhill 93.17499923706055 +alice van buren 38.939998626708984 +alice xylophone 13.816667238871256 +alice xylophone 43.15999984741211 +alice xylophone 78.20999908447266 +alice zipper 26.43000030517578 +alice zipper 42.47999954223633 +alice zipper 89.93000030517578 +bob brown 8.069999694824219 +bob brown 70.93000030517578 +bob brown 93.08999633789062 +bob carson 50.09000015258789 +bob davidson 1.2899999618530273 +bob davidson 71.93000030517578 +bob davidson 74.72000122070312 +bob ellison 41.34000015258789 +bob ellison 56.584999084472656 +bob ellison 75.02999877929688 +bob ellison 80.30000305175781 +bob falkner 16.989999771118164 +bob garcia 4.460000038146973 +bob garcia 5.400000095367432 +bob garcia 43.46500015258789 +bob garcia 80.30000305175781 +bob garcia 87.56999969482422 +bob hernandez 55.51333363850912 +bob ichabod 82.55999755859375 +bob king 8.789999961853027 +bob king 12.539999961853027 +bob king 39.0099983215332 +bob laertes 0.7900000214576721 +bob laertes 10.670000076293945 +bob miller 61.91999816894531 +bob ovid 46.86000061035156 +bob ovid 62.849998474121094 +bob ovid 88.77999877929688 +bob ovid 97.08999633789062 +bob polk 7.980000019073486 +bob quirinius 48.09499931335449 +bob steinbeck 9.699999809265137 +bob van buren 33.66999816894531 +bob white 45.34000015258789 +bob white 45.349998474121094 +bob xylophone 27.6299991607666 +bob xylophone 29.359999656677246 +bob young 35.16999816894531 +bob zipper 3.819999933242798 +bob zipper 32.07500076293945 +bob zipper 52.63999938964844 +calvin allen 63.119998931884766 +calvin brown 28.110000610351562 +calvin brown 85.9000015258789 +calvin brown 90.19999694824219 +calvin carson 59.42500114440918 +calvin davidson 20.40500020980835 +calvin davidson 86.54000091552734 +calvin ellison 26.489999771118164 +calvin falkner 2.9700000286102295 +calvin falkner 56.040000915527344 +calvin falkner 56.33000183105469 +calvin falkner 80.5999984741211 +calvin falkner 93.61000061035156 +calvin falkner 94.30999755859375 +calvin garcia 41.849998474121094 +calvin hernandez 33.869998931884766 +calvin johnson 66.61000061035156 +calvin laertes 23.1299991607666 +calvin laertes 62.670000076293945 +calvin nixon 9.8100004196167 +calvin nixon 41.20000076293945 +calvin nixon 69.73999786376953 +calvin ovid 69.95999908447266 +calvin ovid 71.26000213623047 +calvin ovid 79.12000274658203 +calvin ovid 84.72000122070312 +calvin polk 65.72000122070312 +calvin quirinius 29.540000915527344 +calvin quirinius 60.98499870300293 +calvin robinson 40.439998626708984 +calvin steinbeck 15.220000267028809 +calvin steinbeck 22.850000381469727 +calvin steinbeck 52.06666819254557 +calvin thompson 8.90999984741211 +calvin thompson 89.22500228881836 +calvin underhill 59.70000076293945 +calvin van buren 34.209999084472656 +calvin van buren 64.0 +calvin white 35.864999771118164 +calvin white 90.69000244140625 +calvin xylophone 21.700000762939453 +calvin xylophone 25.420000076293945 +calvin xylophone 56.810001373291016 +calvin young 39.810001373291016 +calvin young 70.27999941507976 +calvin zipper 9.1899995803833 +calvin zipper 95.37999725341797 +david allen 51.005001068115234 +david allen 51.25 +david brown 32.56499910354614 +david brown 93.63999938964844 +david davidson 1.0800000429153442 +david davidson 54.17499876022339 +david davidson 62.720001220703125 +david davidson 74.1500015258789 +david ellison 47.689998626708984 +david ellison 62.587501525878906 +david ellison 85.2300033569336 +david hernandez 75.0833346048991 +david ichabod 35.763334115346275 +david ichabod 82.55000305175781 +david laertes 76.70999908447266 +david nixon 34.72999954223633 +david ovid 43.915000915527344 +david ovid 58.89999961853027 +david quirinius 23.5649995803833 +david quirinius 29.239999771118164 +david quirinius 79.97000122070312 +david robinson 37.703334172566734 +david robinson 70.55999755859375 +david thompson 43.619998931884766 +david underhill 49.719998359680176 +david underhill 60.36499881744385 +david underhill 60.46666717529297 +david van buren 25.500000476837158 +david van buren 50.27999973297119 +david white 45.189998626708984 +david xylophone 30.465000867843628 +david xylophone 30.465000867843628 +david xylophone 54.34000015258789 +david young 10.25 +david young 19.310000777244568 +ethan allen 23.790000518163044 +ethan brown 15.630000114440918 +ethan brown 22.93666648864746 +ethan brown 27.78000044822693 +ethan brown 52.19333346684774 +ethan brown 73.18000030517578 +ethan brown 82.30000305175781 +ethan carson 57.635000228881836 +ethan ellison 0.2800000011920929 +ethan ellison 81.47000122070312 +ethan falkner 50.02000045776367 +ethan falkner 59.43000030517578 +ethan garcia 26.44499921798706 +ethan hernandez 32.30333264668783 +ethan johnson 90.05000305175781 +ethan king 4.349999904632568 +ethan laertes 54.75 +ethan laertes 54.87999868392944 +ethan laertes 59.209999084472656 +ethan laertes 75.31500053405762 +ethan laertes 76.94499969482422 +ethan laertes 80.4749984741211 +ethan laertes 95.06999969482422 +ethan miller 25.3700008392334 +ethan nixon 50.88999938964844 +ethan ovid 58.4950008392334 +ethan polk 2.3499999046325684 +ethan polk 21.31999969482422 +ethan polk 23.440000534057617 +ethan polk 59.869998931884766 +ethan quirinius 13.650000214576721 +ethan quirinius 70.94500160217285 +ethan quirinius 88.18000030517578 +ethan robinson 67.94000244140625 +ethan robinson 80.58500289916992 +ethan underhill 55.630001068115234 +ethan van buren 37.85499954223633 +ethan white 58.70666631062826 +ethan white 63.41999816894531 +ethan xylophone 57.11000061035156 +ethan zipper 5.914999961853027 +ethan zipper 97.51000213623047 +fred davidson 26.435000896453857 +fred davidson 28.144000816345216 +fred davidson 78.30999755859375 +fred ellison 46.65999984741211 +fred ellison 65.40666516621907 +fred ellison 71.98499870300293 +fred falkner 25.144999980926514 +fred falkner 37.62000068028768 +fred falkner 75.125 +fred hernandez 55.9900016784668 +fred ichabod 37.06999969482422 +fred ichabod 75.83499908447266 +fred johnson 96.08999633789062 +fred king 10.220000267028809 +fred king 68.40500068664551 +fred laertes 42.68750047683716 +fred miller 70.02999877929688 +fred nixon 30.589999198913574 +fred nixon 32.71666653951009 +fred nixon 70.5199966430664 +fred nixon 93.02999877929688 +fred polk 23.959999084472656 +fred polk 47.31999969482422 +fred polk 63.97999954223633 +fred polk 90.12000274658203 +fred quirinius 15.300000190734863 +fred quirinius 27.40999984741211 +fred robinson 64.42000007629395 +fred steinbeck 21.239999771118164 +fred steinbeck 21.30000066757202 +fred steinbeck 65.44333394368489 +fred underhill 85.36999893188477 +fred van buren 22.37499976158142 +fred van buren 44.49000096321106 +fred van buren 45.94999980926514 +fred van buren 56.88999938964844 +fred white 42.329999923706055 +fred young 46.79999923706055 +fred young 72.69999885559082 +fred zipper 50.14000129699707 +gabriella allen 28.040000438690186 +gabriella allen 79.64500045776367 +gabriella brown 26.164999961853027 +gabriella brown 84.83000183105469 +gabriella carson 42.7599983215332 +gabriella davidson 18.946666717529297 +gabriella ellison 48.08000183105469 +gabriella ellison 71.54000091552734 +gabriella falkner 31.609999974568684 +gabriella falkner 51.720001220703125 +gabriella falkner 87.61000061035156 +gabriella garcia 43.0099983215332 +gabriella hernandez 62.22666517893473 +gabriella hernandez 92.9800033569336 +gabriella ichabod 10.729999542236328 +gabriella ichabod 26.639999389648438 +gabriella ichabod 66.36000061035156 +gabriella ichabod 71.12999725341797 +gabriella ichabod 83.76666514078777 +gabriella king 20.670000076293945 +gabriella king 62.89999961853027 +gabriella laertes 62.62499809265137 +gabriella miller 26.043334086736042 +gabriella ovid 44.78000068664551 +gabriella ovid 92.4000015258789 +gabriella polk 35.68000030517578 +gabriella polk 90.22500228881836 +gabriella steinbeck 46.45000076293945 +gabriella steinbeck 66.86999893188477 +gabriella thompson 72.34500122070312 +gabriella thompson 75.73666636149089 +gabriella thompson 94.25 +gabriella van buren 38.935001373291016 +gabriella van buren 48.349998474121094 +gabriella white 55.18000030517578 +gabriella young 32.16333262125651 +gabriella young 59.709999084472656 +gabriella zipper 36.2599983215332 +gabriella zipper 91.62999725341797 +holly allen 63.435001373291016 +holly brown 68.51666641235352 +holly brown 86.08666737874348 +holly falkner 83.82666524251302 +holly hernandez 21.190000534057617 +holly hernandez 24.790000915527344 +holly hernandez 28.649999300638836 +holly hernandez 50.22999954223633 +holly ichabod 48.86333433787028 +holly ichabod 73.77000045776367 +holly ichabod 82.91499710083008 +holly johnson 23.447500228881836 +holly johnson 64.36000061035156 +holly johnson 65.62000274658203 +holly king 35.34499979019165 +holly king 42.310001373291016 +holly laertes 70.59666697184245 +holly miller 35.86000061035156 +holly nixon 43.82499885559082 +holly nixon 78.80749893188477 +holly polk 30.365000247955322 +holly polk 75.96499919891357 +holly robinson 82.70499801635742 +holly thompson 0.07999999821186066 +holly thompson 65.52499842643738 +holly thompson 86.69000244140625 +holly underhill 42.54999923706055 +holly underhill 53.02000045776367 +holly underhill 56.13333384195963 +holly underhill 65.84000015258789 +holly van buren 48.893333435058594 +holly white 26.5633331934611 +holly white 40.5 +holly xylophone 49.55666716893514 +holly young 41.698571750095915 +holly young 75.20999908447266 +holly zipper 79.72999827067058 +holly zipper 81.08666610717773 +irene allen 29.75999927520752 +irene brown 28.596666653951008 +irene brown 47.189998626708984 +irene brown 49.46666622161865 +irene carson 86.64999898274739 +irene ellison 38.255001068115234 +irene ellison 45.71333376566569 +irene falkner 22.079999923706055 +irene falkner 83.44666544596355 +irene garcia 38.93499994277954 +irene garcia 42.84666601816813 +irene garcia 58.43000030517578 +irene ichabod 60.7079984664917 +irene ichabod 64.58000183105469 +irene johnson 26.165000438690186 +irene laertes 21.02999997138977 +irene laertes 40.04499816894531 +irene laertes 47.04333241780599 +irene miller 65.44000244140625 +irene nixon 46.03999996185303 +irene nixon 46.96666653951009 +irene nixon 67.09499931335449 +irene ovid 35.130001068115234 +irene ovid 42.535000801086426 +irene ovid 79.75 +irene polk 0.9800000190734863 +irene polk 35.17500114440918 +irene polk 48.94666576385498 +irene polk 49.6 +irene polk 51.885000228881836 +irene quirinius 33.78000005086263 +irene quirinius 42.610000133514404 +irene quirinius 53.4800001780192 +irene robinson 92.19499969482422 +irene steinbeck 73.28000068664551 +irene thompson 46.27375066280365 +irene underhill 24.75 +irene underhill 57.349998474121094 +irene van buren 50.8799991607666 +irene van buren 74.5625 +irene xylophone 83.53499984741211 +jessica brown 26.185000479221344 +jessica carson 56.22999954223633 +jessica carson 62.20000076293945 +jessica carson 62.2400016784668 +jessica davidson 50.01666768391927 +jessica davidson 63.59499931335449 +jessica davidson 69.26666514078777 +jessica davidson 94.53333282470703 +jessica ellison 11.180000305175781 +jessica ellison 64.2060001373291 +jessica falkner 61.57333437601725 +jessica garcia 38.55250036716461 +jessica garcia 57.00999975204468 +jessica ichabod 32.63250035047531 +jessica johnson 9.5600004196167 +jessica johnson 51.959999084472656 +jessica miller 77.83999633789062 +jessica nixon 40.72249960899353 +jessica nixon 90.06999969482422 +jessica ovid 37.45250064134598 +jessica ovid 59.68000030517578 +jessica polk 49.68000030517578 +jessica quirinius 25.65750002861023 +jessica quirinius 37.64200019836426 +jessica quirinius 54.25500011444092 +jessica quirinius 58.019999186197914 +jessica robinson 42.66333484649658 +jessica thompson 30.40666675567627 +jessica thompson 43.87500023841858 +jessica underhill 43.33333269755045 +jessica underhill 45.639999866485596 +jessica underhill 57.584999084472656 +jessica van buren 67.00000047683716 +jessica white 6.170000106096268 +jessica white 63.32500076293945 +jessica white 65.1450023651123 +jessica white 73.93000030517578 +jessica white 96.62000274658203 +jessica xylophone 69.87500190734863 +jessica young 11.1899995803833 +jessica young 43.369998931884766 +jessica zipper 42.43833335240682 +jessica zipper 46.7450008392334 +jessica zipper 56.97999954223633 +katie allen 55.47666549682617 +katie brown 31.699999809265137 +katie davidson 93.22000122070312 +katie ellison 48.31999933719635 +katie ellison 64.08499892552693 +katie falkner 51.665000915527344 +katie garcia 57.71000099182129 +katie garcia 61.21000051498413 +katie hernandez 41.150000381469724 +katie ichabod 44.243333180745445 +katie ichabod 51.800000508626304 +katie ichabod 69.18799896240235 +katie king 39.83000183105469 +katie king 46.80333296457926 +katie king 51.85000038146973 +katie miller 31.399999618530273 +katie miller 74.77999877929688 +katie nixon 23.190000534057617 +katie ovid 67.94500160217285 +katie polk 26.62750005722046 +katie polk 33.9350004196167 +katie robinson 13.890000343322754 +katie van buren 44.434998512268066 +katie van buren 65.41999816894531 +katie white 37.96500015258789 +katie white 59.223333517710365 +katie xylophone 39.30000019073486 +katie young 36.660000801086426 +katie young 67.78333282470703 +katie young 72.76666577657063 +katie zipper 23.766667087872822 +katie zipper 58.75 +luke allen 50.959999084472656 +luke allen 53.36666742960612 +luke allen 54.63249969482422 +luke allen 57.670000076293945 +luke allen 70.39500045776367 +luke brown 49.595000982284546 +luke davidson 7.050000190734863 +luke davidson 18.87000036239624 +luke ellison 16.25 +luke ellison 32.9519996881485 +luke ellison 71.93500137329102 +luke falkner 21.71999979019165 +luke falkner 31.81250023841858 +luke garcia 18.65499973297119 +luke garcia 41.2300001780192 +luke ichabod 41.25750005245209 +luke ichabod 73.55000114440918 +luke johnson 31.670000076293945 +luke johnson 32.84499979019165 +luke johnson 39.54500102996826 +luke laertes 11.819999694824219 +luke laertes 21.184999227523804 +luke laertes 21.993332862854004 +luke laertes 26.696666717529297 +luke laertes 45.9900016784668 +luke miller 52.350000858306885 +luke ovid 23.804999828338623 +luke ovid 64.30000305175781 +luke polk 41.02499961853027 +luke polk 58.4566650390625 +luke quirinius 40.41999816894531 +luke robinson 48.559998750686646 +luke robinson 56.76499938964844 +luke thompson 78.04333368937175 +luke underhill 34.0166662534078 +luke underhill 47.28999996185303 +luke underhill 59.32000160217285 +luke van buren 59.91999944051107 +luke white 74.19599990844726 +luke xylophone 64.77999925613403 +luke zipper 30.434999465942383 +mike allen 30.539999961853027 +mike brown 69.86833318074544 +mike carson 30.25333309173584 +mike carson 61.33799934387207 +mike carson 89.375 +mike davidson 32.55333391825358 +mike davidson 66.74333318074544 +mike ellison 35.905999755859376 +mike ellison 39.82499885559082 +mike ellison 58.56399993896484 +mike ellison 64.52999877929688 +mike ellison 66.93749856948853 +mike falkner 48.53750002384186 +mike garcia 51.02999973297119 +mike garcia 67.93000030517578 +mike garcia 70.8499984741211 +mike hernandez 37.900001525878906 +mike hernandez 59.45000076293945 +mike ichabod 64.7699966430664 +mike king 36.17800045013428 +mike king 41.69500136375427 +mike king 49.57000017166138 +mike king 59.654998779296875 +mike king 71.57000122070312 +mike king 78.50999927520752 +mike miller 29.570000171661377 +mike nixon 45.029999828338624 +mike nixon 48.429999669392906 +mike polk 30.864000129699708 +mike polk 46.95499873161316 +mike polk 79.55500030517578 +mike quirinius 85.0699971516927 +mike steinbeck 24.267500042915344 +mike steinbeck 43.52500021457672 +mike steinbeck 61.426666577657066 +mike steinbeck 68.46000022888184 +mike van buren 27.639999389648438 +mike van buren 56.16333134969076 +mike white 34.8924994468689 +mike white 43.5566660563151 +mike white 53.689998626708984 +mike white 77.54499864578247 +mike young 34.3319993019104 +mike young 52.8100004196167 +mike young 55.64333359400431 +mike zipper 56.86666742960612 +mike zipper 63.3149995803833 +mike zipper 83.91999816894531 +nick allen 57.086001586914065 +nick allen 60.15400094985962 +nick brown 42.939998626708984 +nick davidson 63.07499885559082 +nick ellison 45.34000015258789 +nick ellison 65.88500022888184 +nick falkner 41.87999868392944 +nick falkner 64.05666732788086 +nick garcia 34.34499979019165 +nick garcia 51.08666737874349 +nick garcia 62.88600044250488 +nick ichabod 20.253333409627277 +nick ichabod 53.635000228881836 +nick ichabod 77.36000061035156 +nick johnson 20.114999175071716 +nick johnson 81.91666666666667 +nick laertes 91.56666819254558 +nick miller 71.5500005086263 +nick nixon 77.04249954223633 +nick ovid 74.62666702270508 +nick polk 39.27500009536743 +nick quirinius 60.79499816894531 +nick quirinius 67.44999694824219 +nick robinson 31.672499418258667 +nick robinson 57.66999816894531 +nick steinbeck 59.15999984741211 +nick thompson 18.88666645685832 +nick underhill 43.009998893737794 +nick van buren 34.720000902811684 +nick xylophone 75.3499984741211 +nick young 0.27000001072883606 +nick young 47.813334147135414 +nick zipper 46.22333272298177 +nick zipper 52.54333209991455 +oscar allen 37.396666844685875 +oscar brown 13.100000381469727 +oscar carson 31.91333230336507 +oscar carson 41.77333331108093 +oscar carson 57.3149995803833 +oscar carson 73.59500122070312 +oscar carson 95.44000244140625 +oscar davidson 75.18000030517578 +oscar ellison 34.04499959945679 +oscar ellison 34.04499959945679 +oscar falkner 61.72000050544739 +oscar garcia 67.4800033569336 +oscar hernandez 41.63333400090536 +oscar hernandez 47.93999986648559 +oscar ichabod 45.839999516805015 +oscar ichabod 68.62000274658203 +oscar ichabod 72.18249797821045 +oscar ichabod 76.69000244140625 +oscar johnson 23.880000114440918 +oscar johnson 65.04000091552734 +oscar king 36.69500017166138 +oscar king 49.7049994468689 +oscar king 67.98399925231934 +oscar laertes 43.616665522257485 +oscar laertes 44.755000591278076 +oscar laertes 45.26666831970215 +oscar laertes 53.710001945495605 +oscar nixon 36.56999937693278 +oscar ovid 45.89999961853027 +oscar ovid 46.93999934196472 +oscar ovid 55.277999591827395 +oscar polk 42.31999969482422 +oscar polk 63.900001525878906 +oscar quirinius 63.81500053405762 +oscar quirinius 66.28428527287075 +oscar quirinius 70.24000295003255 +oscar quirinius 81.26249980926514 +oscar robinson 11.34000015258789 +oscar robinson 47.845001220703125 +oscar robinson 59.74333349863688 +oscar robinson 63.346666971842446 +oscar steinbeck 42.49999976158142 +oscar thompson 38.23500061035156 +oscar thompson 51.469999154408775 +oscar thompson 60.029999542236325 +oscar thompson 63.079999923706055 +oscar underhill 66.97666676839192 +oscar van buren 24.085000872612 +oscar van buren 61.880001068115234 +oscar van buren 72.9533322652181 +oscar white 44.72333272298177 +oscar white 46.60999870300293 +oscar white 54.7599983215332 +oscar white 60.85500144958496 +oscar xylophone 34.946666399637856 +oscar xylophone 39.8299994468689 +oscar xylophone 57.119998931884766 +oscar zipper 28.499999046325684 +oscar zipper 47.46750068664551 +oscar zipper 59.1933339436849 +priscilla brown 47.40400066375732 +priscilla brown 77.1479995727539 +priscilla brown 80.5199966430664 +priscilla carson 7.960000038146973 +priscilla carson 28.480000153183937 +priscilla carson 45.92750144004822 +priscilla ichabod 38.95666758219401 +priscilla ichabod 62.32999928792318 +priscilla johnson 44.04499912261963 +priscilla johnson 50.53750038146973 +priscilla johnson 55.98333485921224 +priscilla johnson 59.64499855041504 +priscilla johnson 89.1500015258789 +priscilla king 50.44666735331217 +priscilla nixon 44.32222270965576 +priscilla nixon 45.267999792099 +priscilla ovid 44.78333361943563 +priscilla ovid 52.72999954223633 +priscilla polk 34.89399948120117 +priscilla quirinius 35.609999895095825 +priscilla thompson 35.16249918937683 +priscilla underhill 68.22000122070312 +priscilla underhill 73.97200012207031 +priscilla van buren 50.47000026702881 +priscilla van buren 51.39500045776367 +priscilla van buren 53.541999435424806 +priscilla white 50.47599992752075 +priscilla xylophone 0.15000000596046448 +priscilla xylophone 41.106666485468544 +priscilla xylophone 63.9574990272522 +priscilla young 0.2900000065565109 +priscilla young 19.866666316986084 +priscilla zipper 32.084999322891235 +priscilla zipper 43.90333366394043 +quinn allen 47.90333382288615 +quinn allen 83.33000183105469 +quinn brown 24.280000686645508 +quinn brown 53.98666508992513 +quinn brown 66.82500171661377 +quinn davidson 40.666666666666664 +quinn davidson 54.095001220703125 +quinn davidson 79.78333282470703 +quinn davidson 92.13000106811523 +quinn ellison 52.714999198913574 +quinn ellison 63.352500438690186 +quinn garcia 20.19000056385994 +quinn garcia 54.60000038146973 +quinn garcia 59.010000824928284 +quinn garcia 68.98999881744385 +quinn ichabod 48.60499930381775 +quinn king 61.27333450317383 +quinn king 81.46000289916992 +quinn laertes 32.08000040054321 +quinn laertes 44.45666694641113 +quinn laertes 49.85499858856201 +quinn nixon 72.2471422467913 +quinn ovid 34.423333168029785 +quinn quirinius 53.165000915527344 +quinn robinson 32.624999046325684 +quinn steinbeck 24.802499771118164 +quinn steinbeck 55.477500915527344 +quinn thompson 50.500000381469725 +quinn thompson 55.68600006103516 +quinn underhill 39.66600060462952 +quinn underhill 41.47666676839193 +quinn underhill 56.580000162124634 +quinn van buren 49.40333207448324 +quinn young 55.59000142415365 +quinn zipper 11.359999974568685 +quinn zipper 48.45000123977661 +rachel allen 45.940001249313354 +rachel allen 85.97999954223633 +rachel brown 33.01999984184901 +rachel brown 34.08250021934509 +rachel brown 37.999999046325684 +rachel brown 41.75000019868215 +rachel brown 53.679999669392906 +rachel carson 43.32400016784668 +rachel carson 66.2233320871989 +rachel davidson 14.220000267028809 +rachel ellison 17.549999833106995 +rachel falkner 56.883334159851074 +rachel falkner 57.5199998219808 +rachel falkner 58.80666637420654 +rachel falkner 70.69428443908691 +rachel johnson 36.22499990463257 +rachel king 50.970001220703125 +rachel king 83.53750133514404 +rachel laertes 42.29857151848929 +rachel laertes 71.65999984741211 +rachel ovid 42.25333329041799 +rachel ovid 47.01749947667122 +rachel polk 64.90333239237468 +rachel quirinius 53.2624990940094 +rachel robinson 40.712857246398926 +rachel robinson 53.092498898506165 +rachel robinson 64.94999694824219 +rachel thompson 24.555000439286232 +rachel thompson 31.460000038146973 +rachel thompson 46.804000282287596 +rachel underhill 47.22333272298177 +rachel white 39.87999979654948 +rachel white 41.83428575311388 +rachel young 75.7966677347819 +rachel zipper 45.794999519983925 +rachel zipper 56.909999179840085 +sarah carson 24.576666196187336 +sarah carson 36.33750060200691 +sarah carson 43.65749907493591 +sarah ellison 37.054999351501465 +sarah falkner 48.58285754067557 +sarah falkner 62.36500072479248 +sarah garcia 33.38000011444092 +sarah garcia 35.513333002726235 +sarah garcia 64.31333287556966 +sarah ichabod 36.10599975585937 +sarah ichabod 45.830000162124634 +sarah johnson 26.464999675750732 +sarah johnson 40.9300012588501 +sarah johnson 43.44000196456909 +sarah johnson 64.24333318074544 +sarah king 49.06999909877777 +sarah king 63.01333363850912 +sarah miller 41.709999084472656 +sarah ovid 63.682499408721924 +sarah robinson 39.196666399637856 +sarah robinson 66.88999938964844 +sarah steinbeck 66.89000034332275 +sarah white 41.42599945068359 +sarah white 52.95249938964844 +sarah xylophone 68.31999969482422 +sarah young 35.92750024795532 +sarah zipper 53.697500705718994 +tom brown 38.37000020345052 +tom brown 44.68000049591065 +tom carson 27.994999766349792 +tom carson 54.25250005722046 +tom carson 62.790000915527344 +tom davidson 38.679999113082886 +tom ellison 33.68600053787232 +tom ellison 46.00666618347168 +tom ellison 67.79666646321614 +tom falkner 55.61800079345703 +tom falkner 58.82500012715658 +tom hernandez 50.52250051498413 +tom hernandez 50.52250051498413 +tom ichabod 24.98399963378906 +tom johnson 34.83750009536743 +tom johnson 73.72399978637695 +tom king 69.98000106811523 +tom laertes 41.97285750934056 +tom laertes 70.40333429972331 +tom miller 43.885000586509705 +tom miller 57.10500144958496 +tom miller 76.20499992370605 +tom nixon 62.43000030517578 +tom ovid 38.096666971842446 +tom polk 51.26750087738037 +tom polk 68.22666676839192 +tom quirinius 37.720001220703125 +tom quirinius 53.20399913787842 +tom robinson 43.44333299001058 +tom robinson 54.637142998831614 +tom robinson 59.34250068664551 +tom robinson 99.1500015258789 +tom steinbeck 51.883334477742515 +tom van buren 28.380000829696655 +tom van buren 35.64999930063883 +tom van buren 54.59000015258789 +tom white 51.970001220703125 +tom young 44.7319995880127 +tom young 53.894999980926514 +tom zipper 55.44000116984049 +ulysses brown 48.72666708628336 +ulysses carson 38.742000579833984 +ulysses carson 45.513333002726235 +ulysses carson 48.75249934196472 +ulysses carson 74.64600067138672 +ulysses davidson 63.20857129778181 +ulysses ellison 68.52666759490967 +ulysses garcia 58.77250051498413 +ulysses hernandez 32.371999168395995 +ulysses hernandez 50.57000102996826 +ulysses hernandez 61.39999961853027 +ulysses ichabod 19.1299991607666 +ulysses ichabod 83.06666692097981 +ulysses johnson 51.485000451405845 +ulysses king 46.98333422342936 +ulysses laertes 29.046666741371155 +ulysses laertes 32.88599967956543 +ulysses laertes 60.12399845123291 +ulysses miller 44.552857535226 +ulysses miller 71.39249873161316 +ulysses nixon 51.300000286102296 +ulysses ovid 29.360000610351562 +ulysses polk 40.74399948120117 +ulysses polk 48.9800017674764 +ulysses polk 57.86249828338623 +ulysses polk 81.21333312988281 +ulysses quirinius 68.41500091552734 +ulysses robinson 69.53999853134155 +ulysses steinbeck 44.61833381652832 +ulysses steinbeck 48.362499713897705 +ulysses thompson 45.063334465026855 +ulysses underhill 30.829999764760334 +ulysses underhill 41.43857192993164 +ulysses underhill 44.08333269755045 +ulysses underhill 55.470001220703125 +ulysses underhill 58.9471435546875 +ulysses underhill 68.1900007724762 +ulysses underhill 78.83333333333333 +ulysses van buren 72.38428633553642 +ulysses white 36.17250043153763 +ulysses white 39.084000778198245 +ulysses xylophone 27.519999504089355 +ulysses xylophone 47.65999937057495 +ulysses xylophone 50.29999923706055 +ulysses young 23.308333079020183 +ulysses young 34.6339994430542 +ulysses young 88.06999969482422 +victor allen 49.43800010681152 +victor allen 56.7299998147147 +victor brown 40.80600037574768 +victor brown 63.5024995803833 +victor brown 71.03500080108643 +victor brown 81.71999931335449 +victor davidson 44.70333290100098 +victor davidson 59.070000076293944 +victor davidson 67.27199935913086 +victor ellison 31.28999964396159 +victor ellison 42.54999923706055 +victor hernandez 44.41333452860514 +victor hernandez 47.20249938964844 +victor hernandez 47.73333215713501 +victor hernandez 51.04999923706055 +victor hernandez 59.2399995803833 +victor johnson 54.868000626564026 +victor johnson 55.22999954223633 +victor johnson 57.41000175476074 +victor king 38.27999997138977 +victor king 49.993333180745445 +victor laertes 40.63500006993612 +victor laertes 80.5999984741211 +victor miller 71.00000190734863 +victor nixon 38.393332640329994 +victor nixon 52.920000076293945 +victor ovid 53.260000228881836 +victor polk 3.0 +victor quirinius 57.81666644414266 +victor quirinius 59.39999923706055 +victor robinson 24.614999771118164 +victor robinson 74.5049991607666 +victor steinbeck 28.862000381946565 +victor steinbeck 36.61000006539481 +victor steinbeck 43.09000015258789 +victor thompson 42.67599925994873 +victor van buren 44.669999877611794 +victor van buren 45.121999740600586 +victor white 53.67999887466431 +victor white 54.45000012715658 +victor xylophone 12.160000324249268 +victor xylophone 26.0 +victor xylophone 31.769999821980793 +victor xylophone 52.31499926249186 +victor xylophone 69.2899996439616 +victor young 64.25833320617676 +victor zipper 65.24999904632568 +wendy allen 34.04999955495199 +wendy allen 36.88199939727783 +wendy allen 44.96000012755394 +wendy brown 45.97833283742269 +wendy brown 52.73857225690569 +wendy ellison 42.91333246231079 +wendy ellison 53.56000073750814 +wendy falkner 47.602500438690186 +wendy falkner 64.9099988937378 +wendy falkner 77.5999984741211 +wendy garcia 37.38571425846645 +wendy garcia 48.76666768391927 +wendy garcia 53.225000858306885 +wendy garcia 63.93999926249186 +wendy hernandez 36.195000648498535 +wendy ichabod 10.56499981880188 +wendy king 37.57500076293945 +wendy king 53.44333457946777 +wendy king 56.319997787475586 +wendy laertes 38.39249920845032 +wendy laertes 60.19999885559082 +wendy laertes 65.30624961853027 +wendy miller 44.273332595825195 +wendy miller 53.5675014257431 +wendy nixon 54.995998764038085 +wendy nixon 64.28250026702881 +wendy ovid 43.80499863624573 +wendy ovid 61.64600105285645 +wendy polk 26.784999758005142 +wendy polk 35.21599998474121 +wendy quirinius 28.75666618347168 +wendy quirinius 60.70000092188517 +wendy robinson 42.5799994468689 +wendy robinson 42.90799944400787 +wendy robinson 43.426000237464905 +wendy steinbeck 43.42333388328552 +wendy thompson 40.352857317243306 +wendy thompson 75.93666712443034 +wendy underhill 24.459999561309814 +wendy underhill 33.440000693003334 +wendy underhill 45.51625019311905 +wendy van buren 43.30333296457926 +wendy van buren 65.58666737874348 +wendy white 39.015000104904175 +wendy xylophone 42.42500034968058 +wendy xylophone 53.981666247049965 +wendy young 27.929999828338623 +wendy young 59.609999656677246 +xavier allen 49.24500061571598 +xavier allen 67.39000034332275 +xavier allen 70.29800033569336 +xavier brown 19.772500306367874 +xavier brown 58.87000111171177 +xavier brown 74.6200008392334 +xavier carson 47.3199987411499 +xavier carson 61.52250027656555 +xavier davidson 48.41999936103821 +xavier davidson 52.70666758219401 +xavier davidson 52.78166747093201 +xavier ellison 41.84999942779541 +xavier ellison 62.80200090408325 +xavier garcia 35.0600004568696 +xavier hernandez 47.45200023651123 +xavier hernandez 49.676666259765625 +xavier hernandez 53.446667432785034 +xavier ichabod 56.70625042915344 +xavier ichabod 60.54799928665161 +xavier johnson 41.16333262125651 +xavier johnson 53.85333331425985 +xavier king 42.72800064086914 +xavier king 66.05333455403645 +xavier laertes 38.47999954223633 +xavier ovid 48.89250057935715 +xavier polk 37.05500018596649 +xavier polk 46.82666703065237 +xavier polk 55.385000705718994 +xavier polk 55.65000057220459 +xavier quirinius 59.62499964237213 +xavier quirinius 60.055998992919925 +xavier quirinius 62.52000045776367 +xavier quirinius 65.2933349609375 +xavier thompson 40.244998931884766 +xavier underhill 26.27800006866455 +xavier white 47.8671429497855 +xavier white 63.38428551810129 +xavier xylophone 49.072500228881836 +xavier zipper 8.204999923706055 +yuri allen 53.61250066757202 +yuri allen 64.86833254496257 +yuri brown 46.57500044504801 +yuri brown 66.75250005722046 +yuri carson 39.40750050544739 +yuri carson 49.01600036621094 +yuri ellison 27.49000017642975 +yuri ellison 70.5933329264323 +yuri falkner 47.23285675048828 +yuri falkner 62.807999801635745 +yuri garcia 43.967499017715454 +yuri hernandez 31.94000039100647 +yuri johnson 21.40666739145915 +yuri johnson 34.02333414554596 +yuri johnson 65.7750015258789 +yuri king 49.47333272298177 +yuri laertes 42.070000648498535 +yuri laertes 60.7549991607666 +yuri nixon 49.87142838750567 +yuri nixon 59.945000330607094 +yuri polk 37.56249952316284 +yuri polk 47.583333333333336 +yuri polk 72.60888735453288 +yuri quirinius 18.62000060081482 +yuri quirinius 51.217501401901245 +yuri quirinius 67.24000072479248 +yuri steinbeck 55.757999420166016 +yuri steinbeck 75.87999725341797 +yuri thompson 36.93499946594238 +yuri underhill 51.533334732055664 +yuri underhill 62.31888887617323 +yuri white 44.34999983651297 +yuri xylophone 25.117499828338623 +zach allen 25.92333350578944 +zach brown 38.3799991607666 +zach brown 47.404998779296875 +zach brown 54.30600090026856 +zach brown 58.970001220703125 +zach brown 65.22499942779541 +zach carson 60.783999633789065 +zach ellison 36.211428437914165 +zach falkner 41.225714683532715 +zach falkner 65.99499940872192 +zach garcia 42.8885714326586 +zach garcia 46.8870005607605 +zach garcia 47.5049991607666 +zach garcia 66.09399967193603 +zach ichabod 40.10166613260905 +zach ichabod 53.16749954223633 +zach king 39.137500405311584 +zach king 48.2825003862381 +zach king 61.18999965985616 +zach miller 44.82800054550171 +zach miller 48.52428477151053 +zach miller 53.593332608540855 +zach ovid 35.19399921447039 +zach ovid 38.35833342870077 +zach ovid 43.87200012207031 +zach ovid 83.01999918619792 +zach quirinius 42.638333320617676 +zach robinson 82.04999923706055 +zach steinbeck 55.86599960327148 +zach steinbeck 67.81428473336356 +zach thompson 29.303333282470703 +zach thompson 46.48999913533529 +zach underhill 48.681429045540945 +zach white 66.60250091552734 +zach xylophone 41.875 +zach xylophone 57.2416664759318 +zach young 73.5999984741211 +zach zipper 58.1480016708374 +zach zipper 60.1825008392334 +zach zipper 62.794999837875366 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-2-16239d2b069789ba99fbac50c4f0724f b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-2-16239d2b069789ba99fbac50c4f0724f new file mode 100644 index 000000000000..6cfa5ad413fa --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-2-16239d2b069789ba99fbac50c4f0724f @@ -0,0 +1,1049 @@ + 65560.0 + 65718.0 + 65740.0 +alice allen 65662.0 +alice allen 65720.0 +alice allen 65758.0 +alice brown 65696.0 +alice carson 65559.0 +alice davidson 65547.0 +alice falkner 65669.0 +alice garcia 65613.0 +alice hernandez 65737.0 +alice hernandez 65784.0 +alice johnson 65739.0 +alice king 65660.0 +alice king 65734.0 +alice king 65738.0 +alice laertes 65669.0 +alice laertes 65671.0 +alice miller 65590.0 +alice nixon 65586.0 +alice nixon 65595.0 +alice nixon 65604.0 +alice ovid 65737.0 +alice polk 65548.0 +alice quirinius 65636.0 +alice quirinius 65728.0 +alice robinson 65606.0 +alice robinson 65789.0 +alice steinbeck 65578.0 +alice steinbeck 65673.0 +alice steinbeck 65786.0 +alice underhill 65750.0 +alice van buren 65562.0 +alice xylophone 65578.0 +alice xylophone 65585.0 +alice xylophone 65599.0 +alice zipper 65553.0 +alice zipper 65662.0 +alice zipper 65766.0 +bob brown 65584.0 +bob brown 65777.0 +bob brown 65783.0 +bob carson 65713.0 +bob davidson 65664.0 +bob davidson 65693.0 +bob davidson 65768.0 +bob ellison 65591.0 +bob ellison 65624.0 +bob ellison 65721.0 +bob ellison 65760.0 +bob falkner 65789.0 +bob garcia 65585.0 +bob garcia 65598.0 +bob garcia 65673.0 +bob garcia 65754.0 +bob garcia 65782.0 +bob hernandez 65557.0 +bob ichabod 65549.0 +bob king 65715.0 +bob king 65757.0 +bob king 65783.0 +bob laertes 65602.0 +bob laertes 65663.0 +bob miller 65608.0 +bob ovid 65564.0 +bob ovid 65619.0 +bob ovid 65686.0 +bob ovid 65726.0 +bob polk 65594.0 +bob quirinius 65700.0 +bob steinbeck 65637.0 +bob van buren 65778.0 +bob white 65543.0 +bob white 65605.0 +bob xylophone 65574.0 +bob xylophone 65666.0 +bob young 65556.0 +bob zipper 65559.0 +bob zipper 65633.0 +bob zipper 65739.0 +calvin allen 65669.0 +calvin brown 65537.0 +calvin brown 65580.0 +calvin brown 65677.0 +calvin carson 65637.0 +calvin davidson 65541.0 +calvin davidson 65564.0 +calvin ellison 65667.0 +calvin falkner 65573.0 +calvin falkner 65596.0 +calvin falkner 65738.0 +calvin falkner 65762.0 +calvin falkner 65778.0 +calvin falkner 65784.0 +calvin garcia 65664.0 +calvin hernandez 65578.0 +calvin johnson 65731.0 +calvin laertes 65570.0 +calvin laertes 65684.0 +calvin nixon 65654.0 +calvin nixon 65724.0 +calvin nixon 65749.0 +calvin ovid 65554.0 +calvin ovid 65643.0 +calvin ovid 65663.0 +calvin ovid 65715.0 +calvin polk 65731.0 +calvin quirinius 65741.0 +calvin quirinius 65769.0 +calvin robinson 65581.0 +calvin steinbeck 65680.0 +calvin steinbeck 65762.0 +calvin steinbeck 65779.0 +calvin thompson 65560.0 +calvin thompson 65640.0 +calvin underhill 65732.0 +calvin van buren 65552.0 +calvin van buren 65771.0 +calvin white 65553.0 +calvin white 65561.0 +calvin xylophone 65575.0 +calvin xylophone 65614.5 +calvin xylophone 65713.0 +calvin young 65574.0 +calvin young 65746.0 +calvin zipper 65669.0 +calvin zipper 65739.0 +david allen 65588.0 +david allen 65617.0 +david brown 65637.0 +david brown 65760.0 +david davidson 65559.0 +david davidson 65756.0 +david davidson 65778.0 +david davidson 65779.0 +david ellison 65634.0 +david ellison 65724.0 +david ellison 65724.0 +david hernandez 65763.0 +david ichabod 65699.0 +david ichabod 65715.0 +david laertes 65762.0 +david nixon 65536.0 +david ovid 65623.0 +david ovid 65628.0 +david quirinius 65697.0 +david quirinius 65759.0 +david quirinius 65779.0 +david robinson 65762.0 +david robinson 65775.0 +david thompson 65550.0 +david underhill 65602.0 +david underhill 65662.0 +david underhill 65751.0 +david van buren 65625.0 +david van buren 65634.0 +david white 65678.0 +david xylophone 65537.0 +david xylophone 65670.0 +david xylophone 65764.0 +david young 65551.0 +david young 65694.0 +ethan allen 65747.0 +ethan brown 65539.0 +ethan brown 65617.0 +ethan brown 65685.0 +ethan brown 65685.0 +ethan brown 65722.0 +ethan brown 65733.0 +ethan carson 65742.0 +ethan ellison 65714.0 +ethan ellison 65732.0 +ethan falkner 65577.0 +ethan falkner 65614.0 +ethan garcia 65736.0 +ethan hernandez 65630.5 +ethan johnson 65536.0 +ethan king 65614.0 +ethan laertes 65562.0 +ethan laertes 65597.0 +ethan laertes 65628.0 +ethan laertes 65643.0 +ethan laertes 65680.0 +ethan laertes 65745.0 +ethan laertes 65760.0 +ethan miller 65712.0 +ethan nixon 65766.0 +ethan ovid 65697.0 +ethan polk 65589.0 +ethan polk 65615.0 +ethan polk 65622.0 +ethan polk 65622.0 +ethan quirinius 65542.0 +ethan quirinius 65591.0 +ethan quirinius 65706.0 +ethan robinson 65547.0 +ethan robinson 65659.0 +ethan underhill 65570.0 +ethan van buren 65572.0 +ethan white 65677.0 +ethan white 65788.0 +ethan xylophone 65595.0 +ethan zipper 65593.0 +ethan zipper 65680.0 +fred davidson 65552.0 +fred davidson 65595.0 +fred davidson 65721.0 +fred ellison 65548.0 +fred ellison 65691.0 +fred ellison 65771.0 +fred falkner 65637.0 +fred falkner 65648.0 +fred falkner 65783.0 +fred hernandez 65541.0 +fred ichabod 65572.0 +fred ichabod 65789.0 +fred johnson 65758.0 +fred king 65694.0 +fred king 65745.0 +fred laertes 65769.0 +fred miller 65536.0 +fred nixon 65560.0 +fred nixon 65612.0 +fred nixon 65703.0 +fred nixon 65705.0 +fred polk 65603.0 +fred polk 65656.0 +fred polk 65701.0 +fred polk 65706.0 +fred quirinius 65697.0 +fred quirinius 65782.0 +fred robinson 65623.0 +fred steinbeck 65544.0 +fred steinbeck 65651.0 +fred steinbeck 65755.0 +fred underhill 65629.0 +fred van buren 65537.0 +fred van buren 65561.0 +fred van buren 65745.0 +fred van buren 65789.0 +fred white 65589.0 +fred young 65594.0 +fred young 65773.0 +fred zipper 65553.0 +gabriella allen 65646.0 +gabriella allen 65677.0 +gabriella brown 65704.0 +gabriella brown 65753.0 +gabriella carson 65586.0 +gabriella davidson 65565.0 +gabriella ellison 65706.0 +gabriella ellison 65716.0 +gabriella falkner 65623.0 +gabriella falkner 65711.0 +gabriella falkner 65767.0 +gabriella garcia 65571.0 +gabriella hernandez 65587.0 +gabriella hernandez 65717.0 +gabriella ichabod 65559.0 +gabriella ichabod 65633.0 +gabriella ichabod 65702.0 +gabriella ichabod 65712.0 +gabriella ichabod 65717.0 +gabriella king 65651.0 +gabriella king 65657.0 +gabriella laertes 65781.0 +gabriella miller 65646.0 +gabriella ovid 65556.0 +gabriella ovid 65583.0 +gabriella polk 65701.0 +gabriella polk 65790.0 +gabriella steinbeck 65582.0 +gabriella steinbeck 65653.0 +gabriella thompson 65682.0 +gabriella thompson 65755.0 +gabriella thompson 65766.0 +gabriella van buren 65581.0 +gabriella van buren 65644.0 +gabriella white 65638.0 +gabriella young 65699.0 +gabriella young 65774.0 +gabriella zipper 65540.0 +gabriella zipper 65754.0 +holly allen 65596.0 +holly brown 65599.0 +holly brown 65619.0 +holly falkner 65720.0 +holly hernandez 65602.0 +holly hernandez 65686.0 +holly hernandez 65750.0 +holly hernandez 65788.0 +holly ichabod 65711.0 +holly ichabod 65749.0 +holly ichabod 65752.0 +holly johnson 65655.0 +holly johnson 65662.0 +holly johnson 65755.0 +holly king 65549.0 +holly king 65648.0 +holly laertes 65664.0 +holly miller 65653.0 +holly nixon 65539.0 +holly nixon 65658.0 +holly polk 65743.0 +holly polk 65751.0 +holly robinson 65564.0 +holly thompson 65538.0 +holly thompson 65578.0 +holly thompson 65713.0 +holly underhill 65634.0 +holly underhill 65654.0 +holly underhill 65721.0 +holly underhill 65759.0 +holly van buren 65727.0 +holly white 65536.0 +holly white 65602.0 +holly xylophone 65544.0 +holly young 65606.0 +holly young 65765.0 +holly zipper 65607.0 +holly zipper 65755.0 +irene allen 65556.0 +irene brown 65633.0 +irene brown 65664.0 +irene brown 65765.0 +irene carson 65590.0 +irene ellison 65659.0 +irene ellison 65696.0 +irene falkner 65620.0 +irene falkner 65661.0 +irene garcia 65660.0 +irene garcia 65711.0 +irene garcia 65787.0 +irene ichabod 65645.0 +irene ichabod 65722.0 +irene johnson 65583.0 +irene laertes 65664.0 +irene laertes 65710.0 +irene laertes 65722.0 +irene miller 65730.0 +irene nixon 65631.0 +irene nixon 65643.0 +irene nixon 65653.0 +irene ovid 65691.0 +irene ovid 65734.0 +irene ovid 65753.0 +irene polk 65551.0 +irene polk 65575.0 +irene polk 65579.0 +irene polk 65595.0 +irene polk 65610.0 +irene quirinius 65724.0 +irene quirinius 65769.0 +irene quirinius 65773.0 +irene robinson 65554.0 +irene steinbeck 65683.0 +irene thompson 65688.0 +irene underhill 65591.0 +irene underhill 65707.5 +irene van buren 65579.0 +irene van buren 65589.0 +irene xylophone 65775.0 +jessica brown 65588.0 +jessica carson 65553.0 +jessica carson 65672.0 +jessica carson 65747.0 +jessica davidson 65549.0 +jessica davidson 65606.0 +jessica davidson 65675.0 +jessica davidson 65727.0 +jessica ellison 65567.0 +jessica ellison 65663.0 +jessica falkner 65584.0 +jessica garcia 65676.0 +jessica garcia 65789.0 +jessica ichabod 65704.0 +jessica johnson 65607.0 +jessica johnson 65720.0 +jessica miller 65733.0 +jessica nixon 65590.0 +jessica nixon 65774.0 +jessica ovid 65582.0 +jessica ovid 65751.0 +jessica polk 65637.0 +jessica quirinius 65562.0 +jessica quirinius 65608.0 +jessica quirinius 65712.0 +jessica quirinius 65716.0 +jessica robinson 65576.0 +jessica thompson 65581.0 +jessica thompson 65675.0 +jessica underhill 65656.0 +jessica underhill 65702.0 +jessica underhill 65783.0 +jessica van buren 65615.0 +jessica white 65544.0 +jessica white 65570.0 +jessica white 65594.0 +jessica white 65673.0 +jessica white 65779.0 +jessica xylophone 65562.0 +jessica young 65623.0 +jessica young 65711.0 +jessica zipper 65600.0 +jessica zipper 65657.0 +jessica zipper 65778.0 +katie allen 65542.0 +katie brown 65590.0 +katie davidson 65619.0 +katie ellison 65675.0 +katie ellison 65699.0 +katie falkner 65728.0 +katie garcia 65625.0 +katie garcia 65747.0 +katie hernandez 65550.0 +katie ichabod 65658.0 +katie ichabod 65726.0 +katie ichabod 65757.0 +katie king 65629.0 +katie king 65647.0 +katie king 65776.0 +katie miller 65541.0 +katie miller 65661.0 +katie nixon 65669.0 +katie ovid 65681.0 +katie polk 65746.0 +katie polk 65784.0 +katie robinson 65697.0 +katie van buren 65643.0 +katie van buren 65730.0 +katie white 65620.0 +katie white 65719.0 +katie xylophone 65585.0 +katie young 65644.0 +katie young 65746.0 +katie young 65764.0 +katie zipper 65568.0 +katie zipper 65733.0 +luke allen 65547.0 +luke allen 65552.0 +luke allen 65576.0 +luke allen 65681.0 +luke allen 65776.0 +luke brown 65719.0 +luke davidson 65656.0 +luke davidson 65791.0 +luke ellison 65582.0 +luke ellison 65664.0 +luke ellison 65779.0 +luke falkner 65589.0 +luke falkner 65618.0 +luke garcia 65687.0 +luke garcia 65778.0 +luke ichabod 65629.0 +luke ichabod 65654.0 +luke johnson 65545.0 +luke johnson 65716.0 +luke johnson 65718.0 +luke laertes 65608.0 +luke laertes 65657.0 +luke laertes 65685.0 +luke laertes 65730.0 +luke laertes 65756.0 +luke miller 65752.0 +luke ovid 65569.0 +luke ovid 65693.0 +luke polk 65645.0 +luke polk 65658.0 +luke quirinius 65655.0 +luke robinson 65634.0 +luke robinson 65772.0 +luke thompson 65626.0 +luke underhill 65553.0 +luke underhill 65571.0 +luke underhill 65651.0 +luke van buren 65678.0 +luke white 65693.0 +luke xylophone 65597.0 +luke zipper 65641.0 +mike allen 65706.0 +mike brown 65654.0 +mike carson 65698.0 +mike carson 65700.0 +mike carson 65751.0 +mike davidson 65658.0 +mike davidson 65759.0 +mike ellison 65598.0 +mike ellison 65606.0 +mike ellison 65718.0 +mike ellison 65738.0 +mike ellison 65760.0 +mike falkner 65609.0 +mike garcia 65571.0 +mike garcia 65600.0 +mike garcia 65770.0 +mike hernandez 65548.0 +mike hernandez 65672.0 +mike ichabod 65621.0 +mike king 65563.0 +mike king 65586.0 +mike king 65591.0 +mike king 65642.0 +mike king 65769.0 +mike king 65776.0 +mike miller 65549.0 +mike nixon 65619.0 +mike nixon 65704.0 +mike polk 65619.0 +mike polk 65658.0 +mike polk 65704.0 +mike quirinius 65717.0 +mike steinbeck 65550.0 +mike steinbeck 65564.0 +mike steinbeck 65573.0 +mike steinbeck 65749.0 +mike van buren 65620.0 +mike van buren 65770.0 +mike white 65648.0 +mike white 65685.0 +mike white 65769.0 +mike white 65778.0 +mike young 65545.0 +mike young 65581.0 +mike young 65736.0 +mike zipper 65552.0 +mike zipper 65695.0 +mike zipper 65779.0 +nick allen 65641.0 +nick allen 65786.0 +nick brown 65724.0 +nick davidson 65601.0 +nick ellison 65691.0 +nick ellison 65745.0 +nick falkner 65583.0 +nick falkner 65676.0 +nick garcia 65712.0 +nick garcia 65720.0 +nick garcia 65723.0 +nick ichabod 65572.0 +nick ichabod 65681.0 +nick ichabod 65737.0 +nick johnson 65585.0 +nick johnson 65784.0 +nick laertes 65624.0 +nick miller 65757.0 +nick nixon 65650.0 +nick ovid 65719.0 +nick polk 65716.0 +nick quirinius 65588.0 +nick quirinius 65723.0 +nick robinson 65547.0 +nick robinson 65675.0 +nick steinbeck 65689.0 +nick thompson 65610.0 +nick underhill 65619.0 +nick van buren 65603.0 +nick xylophone 65644.0 +nick young 65654.0 +nick young 65660.0 +nick zipper 65757.0 +nick zipper 65765.0 +oscar allen 65644.0 +oscar brown 65614.0 +oscar carson 65537.0 +oscar carson 65548.0 +oscar carson 65549.0 +oscar carson 65624.0 +oscar carson 65697.0 +oscar davidson 65556.0 +oscar ellison 65630.0 +oscar ellison 65630.0 +oscar falkner 65692.0 +oscar garcia 65751.0 +oscar hernandez 65683.0 +oscar hernandez 65707.0 +oscar ichabod 65536.0 +oscar ichabod 65562.0 +oscar ichabod 65637.0 +oscar ichabod 65763.0 +oscar johnson 65645.0 +oscar johnson 65778.0 +oscar king 65541.0 +oscar king 65550.0 +oscar king 65787.0 +oscar laertes 65625.0 +oscar laertes 65690.0 +oscar laertes 65756.0 +oscar laertes 65790.0 +oscar nixon 65596.0 +oscar ovid 65536.0 +oscar ovid 65615.0 +oscar ovid 65665.5 +oscar polk 65541.0 +oscar polk 65643.0 +oscar quirinius 65541.0 +oscar quirinius 65560.0 +oscar quirinius 65689.0 +oscar quirinius 65720.0 +oscar robinson 65537.0 +oscar robinson 65658.0 +oscar robinson 65687.0 +oscar robinson 65782.0 +oscar steinbeck 65709.0 +oscar thompson 65542.0 +oscar thompson 65681.0 +oscar thompson 65727.0 +oscar thompson 65738.0 +oscar underhill 65626.0 +oscar van buren 65581.0 +oscar van buren 65635.0 +oscar van buren 65705.0 +oscar white 65552.0 +oscar white 65564.0 +oscar white 65671.0 +oscar white 65735.0 +oscar xylophone 65773.0 +oscar xylophone 65773.0 +oscar xylophone 65775.0 +oscar zipper 65568.0 +oscar zipper 65740.0 +oscar zipper 65777.0 +priscilla brown 65670.0 +priscilla brown 65690.0 +priscilla brown 65749.0 +priscilla carson 65658.0 +priscilla carson 65687.0 +priscilla carson 65755.0 +priscilla ichabod 65627.0 +priscilla ichabod 65759.0 +priscilla johnson 65543.0 +priscilla johnson 65668.0 +priscilla johnson 65674.5 +priscilla johnson 65681.0 +priscilla johnson 65755.0 +priscilla king 65646.0 +priscilla nixon 65564.0 +priscilla nixon 65600.0 +priscilla ovid 65541.0 +priscilla ovid 65790.0 +priscilla polk 65747.0 +priscilla quirinius 65672.0 +priscilla thompson 65654.0 +priscilla underhill 65715.0 +priscilla underhill 65729.0 +priscilla van buren 65607.0 +priscilla van buren 65685.0 +priscilla van buren 65749.0 +priscilla white 65652.0 +priscilla xylophone 65538.0 +priscilla xylophone 65763.0 +priscilla xylophone 65774.0 +priscilla young 65585.0 +priscilla young 65658.0 +priscilla zipper 65622.0 +priscilla zipper 65726.0 +quinn allen 65657.0 +quinn allen 65708.0 +quinn brown 65691.0 +quinn brown 65700.0 +quinn brown 65733.0 +quinn davidson 65549.0 +quinn davidson 65714.0 +quinn davidson 65776.0 +quinn davidson 65779.0 +quinn ellison 65705.0 +quinn ellison 65778.0 +quinn garcia 65568.0 +quinn garcia 65604.0 +quinn garcia 65610.0 +quinn garcia 65773.0 +quinn ichabod 65609.0 +quinn king 65558.0 +quinn king 65649.0 +quinn laertes 65542.0 +quinn laertes 65560.0 +quinn laertes 65627.0 +quinn nixon 65659.0 +quinn ovid 65699.0 +quinn quirinius 65747.0 +quinn robinson 65627.0 +quinn steinbeck 65578.0 +quinn steinbeck 65763.0 +quinn thompson 65643.0 +quinn thompson 65774.0 +quinn underhill 65549.0 +quinn underhill 65694.0 +quinn underhill 65767.0 +quinn van buren 65725.0 +quinn young 65647.0 +quinn zipper 65579.0 +quinn zipper 65693.0 +rachel allen 65661.0 +rachel allen 65709.0 +rachel brown 65586.0 +rachel brown 65587.0 +rachel brown 65587.0 +rachel brown 65610.0 +rachel brown 65693.0 +rachel carson 65677.0 +rachel carson 65682.0 +rachel davidson 65755.0 +rachel ellison 65761.0 +rachel falkner 65616.0 +rachel falkner 65681.0 +rachel falkner 65693.0 +rachel falkner 65764.0 +rachel johnson 65658.0 +rachel king 65604.0 +rachel king 65643.0 +rachel laertes 65562.0 +rachel laertes 65624.0 +rachel ovid 65721.0 +rachel ovid 65736.0 +rachel polk 65686.0 +rachel quirinius 65787.0 +rachel robinson 65544.0 +rachel robinson 65717.0 +rachel robinson 65724.0 +rachel thompson 65648.0 +rachel thompson 65662.0 +rachel thompson 65733.0 +rachel underhill 65667.0 +rachel white 65615.0 +rachel white 65717.0 +rachel young 65727.0 +rachel zipper 65757.0 +rachel zipper 65785.0 +sarah carson 65616.0 +sarah carson 65693.0 +sarah carson 65694.0 +sarah ellison 65611.0 +sarah falkner 65606.0 +sarah falkner 65680.0 +sarah garcia 65563.0 +sarah garcia 65638.0 +sarah garcia 65661.0 +sarah ichabod 65667.0 +sarah ichabod 65671.0 +sarah johnson 65659.0 +sarah johnson 65716.0 +sarah johnson 65731.0 +sarah johnson 65751.0 +sarah king 65650.0 +sarah king 65699.0 +sarah miller 65557.0 +sarah ovid 65550.0 +sarah robinson 65677.0 +sarah robinson 65763.0 +sarah steinbeck 65721.0 +sarah white 65622.0 +sarah white 65747.0 +sarah xylophone 65678.0 +sarah young 65595.0 +sarah zipper 65550.0 +tom brown 65593.0 +tom brown 65675.0 +tom carson 65539.0 +tom carson 65624.0 +tom carson 65780.0 +tom davidson 65780.0 +tom ellison 65578.0 +tom ellison 65670.0 +tom ellison 65756.0 +tom falkner 65574.0 +tom falkner 65625.0 +tom hernandez 65575.0 +tom hernandez 65632.0 +tom ichabod 65588.0 +tom johnson 65536.0 +tom johnson 65789.0 +tom king 65576.0 +tom laertes 65617.0 +tom laertes 65701.0 +tom miller 65594.0 +tom miller 65603.0 +tom miller 65704.0 +tom nixon 65672.0 +tom ovid 65628.0 +tom polk 65652.0 +tom polk 65742.0 +tom quirinius 65563.0 +tom quirinius 65783.0 +tom robinson 65626.0 +tom robinson 65632.0 +tom robinson 65691.0 +tom robinson 65758.0 +tom steinbeck 65666.0 +tom van buren 65621.0 +tom van buren 65652.0 +tom van buren 65669.0 +tom white 65548.0 +tom young 65544.0 +tom young 65546.0 +tom zipper 65789.0 +ulysses brown 65735.0 +ulysses carson 65602.0 +ulysses carson 65643.0 +ulysses carson 65703.0 +ulysses carson 65716.0 +ulysses davidson 65750.0 +ulysses ellison 65575.0 +ulysses garcia 65666.0 +ulysses hernandez 65651.0 +ulysses hernandez 65702.0 +ulysses hernandez 65786.0 +ulysses ichabod 65551.0 +ulysses ichabod 65566.0 +ulysses johnson 65776.0 +ulysses king 65649.0 +ulysses laertes 65691.0 +ulysses laertes 65711.0 +ulysses laertes 65781.0 +ulysses miller 65610.0 +ulysses miller 65637.0 +ulysses nixon 65603.0 +ulysses ovid 65656.0 +ulysses polk 65563.0 +ulysses polk 65580.0 +ulysses polk 65612.0 +ulysses polk 65777.0 +ulysses quirinius 65786.0 +ulysses robinson 65744.0 +ulysses steinbeck 65611.0 +ulysses steinbeck 65680.0 +ulysses thompson 65788.0 +ulysses underhill 65570.0 +ulysses underhill 65616.0 +ulysses underhill 65620.0 +ulysses underhill 65623.0 +ulysses underhill 65641.0 +ulysses underhill 65713.0 +ulysses underhill 65785.0 +ulysses van buren 65684.0 +ulysses white 65654.0 +ulysses white 65675.0 +ulysses xylophone 65623.0 +ulysses xylophone 65636.0 +ulysses xylophone 65781.0 +ulysses young 65675.0 +ulysses young 65736.0 +ulysses young 65748.0 +victor allen 65684.0 +victor allen 65707.0 +victor brown 65550.0 +victor brown 65555.0 +victor brown 65622.0 +victor brown 65673.0 +victor davidson 65579.0 +victor davidson 65628.0 +victor davidson 65783.0 +victor ellison 65641.0 +victor ellison 65782.0 +victor hernandez 65571.0 +victor hernandez 65659.0 +victor hernandez 65708.0 +victor hernandez 65735.0 +victor hernandez 65775.0 +victor johnson 65606.0 +victor johnson 65607.0 +victor johnson 65607.0 +victor king 65721.0 +victor king 65743.0 +victor laertes 65638.0 +victor laertes 65644.0 +victor miller 65570.0 +victor nixon 65709.0 +victor nixon 65791.0 +victor ovid 65649.0 +victor polk 65625.0 +victor quirinius 65620.0 +victor quirinius 65651.0 +victor robinson 65596.0 +victor robinson 65673.0 +victor steinbeck 65618.0 +victor steinbeck 65661.0 +victor steinbeck 65686.0 +victor thompson 65548.0 +victor van buren 65664.0 +victor van buren 65774.0 +victor white 65548.0 +victor white 65601.0 +victor xylophone 65549.0 +victor xylophone 65618.0 +victor xylophone 65644.0 +victor xylophone 65677.0 +victor xylophone 65755.0 +victor young 65628.0 +victor zipper 65743.0 +wendy allen 65628.0 +wendy allen 65711.0 +wendy allen 65782.0 +wendy brown 65580.0 +wendy brown 65657.0 +wendy ellison 65545.0 +wendy ellison 65603.0 +wendy falkner 65595.0 +wendy falkner 65604.0 +wendy falkner 65635.0 +wendy garcia 65659.0 +wendy garcia 65746.0 +wendy garcia 65747.0 +wendy garcia 65777.0 +wendy hernandez 65650.0 +wendy ichabod 65730.0 +wendy king 65586.0 +wendy king 65664.0 +wendy king 65670.0 +wendy laertes 65566.0 +wendy laertes 65683.0 +wendy laertes 65727.0 +wendy miller 65582.0 +wendy miller 65626.0 +wendy nixon 65611.0 +wendy nixon 65746.0 +wendy ovid 65589.0 +wendy ovid 65643.0 +wendy polk 65656.0 +wendy polk 65692.0 +wendy quirinius 65766.0 +wendy quirinius 65767.0 +wendy robinson 65622.0 +wendy robinson 65715.0 +wendy robinson 65774.0 +wendy steinbeck 65612.0 +wendy thompson 65650.0 +wendy thompson 65737.0 +wendy underhill 65662.0 +wendy underhill 65758.0 +wendy underhill 65775.0 +wendy van buren 65680.0 +wendy van buren 65699.0 +wendy white 65705.0 +wendy xylophone 65687.0 +wendy xylophone 65773.0 +wendy young 65674.0 +wendy young 65685.0 +xavier allen 65611.0 +xavier allen 65618.0 +xavier allen 65771.0 +xavier brown 65600.0 +xavier brown 65704.0 +xavier brown 65723.0 +xavier carson 65731.0 +xavier carson 65758.0 +xavier davidson 65644.0 +xavier davidson 65664.0 +xavier davidson 65755.0 +xavier ellison 65541.0 +xavier ellison 65622.0 +xavier garcia 65672.0 +xavier hernandez 65541.0 +xavier hernandez 65544.0 +xavier hernandez 65766.0 +xavier ichabod 65597.0 +xavier ichabod 65663.0 +xavier johnson 65655.0 +xavier johnson 65744.0 +xavier king 65590.0 +xavier king 65601.0 +xavier laertes 65743.0 +xavier ovid 65788.0 +xavier polk 65587.0 +xavier polk 65653.0 +xavier polk 65675.0 +xavier polk 65696.0 +xavier quirinius 65599.0 +xavier quirinius 65650.0 +xavier quirinius 65656.0 +xavier quirinius 65737.0 +xavier thompson 65608.0 +xavier underhill 65710.0 +xavier white 65703.0 +xavier white 65732.0 +xavier xylophone 65572.0 +xavier zipper 65561.0 +yuri allen 65565.0 +yuri allen 65682.0 +yuri brown 65538.0 +yuri brown 65688.0 +yuri carson 65670.0 +yuri carson 65769.0 +yuri ellison 65570.0 +yuri ellison 65581.0 +yuri falkner 65658.0 +yuri falkner 65681.0 +yuri garcia 65639.0 +yuri hernandez 65706.0 +yuri johnson 65587.0 +yuri johnson 65697.0 +yuri johnson 65712.0 +yuri king 65721.0 +yuri laertes 65637.0 +yuri laertes 65773.0 +yuri nixon 65635.0 +yuri nixon 65740.0 +yuri polk 65607.0 +yuri polk 65713.0 +yuri polk 65742.0 +yuri quirinius 65544.0 +yuri quirinius 65617.0 +yuri quirinius 65695.0 +yuri steinbeck 65592.0 +yuri steinbeck 65679.0 +yuri thompson 65676.0 +yuri underhill 65718.0 +yuri underhill 65750.0 +yuri white 65659.0 +yuri xylophone 65714.0 +zach allen 65667.0 +zach brown 65559.0 +zach brown 65588.0 +zach brown 65691.0 +zach brown 65759.0 +zach brown 65762.0 +zach carson 65572.0 +zach ellison 65748.0 +zach falkner 65620.0 +zach falkner 65627.0 +zach garcia 65544.0 +zach garcia 65623.0 +zach garcia 65629.0 +zach garcia 65764.5 +zach ichabod 65599.0 +zach ichabod 65612.0 +zach king 65556.0 +zach king 65702.0 +zach king 65773.0 +zach miller 65583.0 +zach miller 65665.0 +zach miller 65719.0 +zach ovid 65578.0 +zach ovid 65669.0 +zach ovid 65703.0 +zach ovid 65784.0 +zach quirinius 65691.0 +zach robinson 65599.0 +zach steinbeck 65602.0 +zach steinbeck 65695.0 +zach thompson 65636.0 +zach thompson 65696.0 +zach underhill 65573.0 +zach white 65733.0 +zach xylophone 65542.0 +zach xylophone 65780.0 +zach young 65576.0 +zach zipper 65579.0 +zach zipper 65649.0 +zach zipper 65676.0 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-3-d90b27fca067b0b3c48d873b3ef32af7 b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-3-d90b27fca067b0b3c48d873b3ef32af7 new file mode 100644 index 000000000000..072a8a891a83 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-3-d90b27fca067b0b3c48d873b3ef32af7 @@ -0,0 +1,1049 @@ +65536 +65536 +65536 +65536 +65536 +65536 +65537 +65537 +65537 +65537 +65538 +65538 +65538 +65539 +65539 +65539 +65540 +65541 +65541 +65541 +65541 +65541 +65541 +65541 +65541 +65542 +65542 +65542 +65542 +65543 +65543 +65544 +65544 +65544 +65544 +65544 +65544 +65544 +65544 +65545 +65545 +65545 +65547 +65547 +65547 +65547 +65548 +65548 +65548 +65548 +65548 +65548 +65548 +65549 +65549 +65549 +65549 +65549 +65549 +65549 +65549 +65550 +65550 +65550 +65550 +65550 +65550 +65550 +65551 +65551 +65551 +65552 +65552 +65552 +65552 +65552 +65553 +65553 +65553 +65553 +65554 +65554 +65555 +65556 +65556 +65556 +65557 +65557 +65558 +65559 +65559 +65559 +65559 +65560 +65560 +65560 +65560 +65561 +65562 +65562 +65562 +65562 +65562 +65562 +65563 +65563 +65563 +65563 +65563 +65564 +65564 +65564 +65564 +65564 +65564 +65564 +65564 +65565 +65565 +65566 +65566 +65567 +65568 +65568 +65568 +65569 +65570 +65570 +65570 +65570 +65570 +65571 +65571 +65571 +65572 +65572 +65572 +65572 +65572 +65573 +65573 +65573 +65574 +65574 +65574 +65575 +65575 +65575 +65575 +65576 +65576 +65576 +65576 +65577 +65578 +65578 +65578 +65578 +65578 +65578 +65579 +65579 +65579 +65579 +65579 +65580 +65580 +65580 +65581 +65581 +65581 +65581 +65581 +65582 +65582 +65582 +65582 +65582 +65583 +65583 +65583 +65583 +65584 +65584 +65584 +65584 +65585 +65585 +65585 +65585 +65585 +65586 +65586 +65586 +65586 +65586 +65587 +65587 +65587 +65587 +65588 +65588 +65588 +65588 +65588 +65589 +65589 +65589 +65589 +65589 +65590 +65590 +65590 +65590 +65590 +65591 +65591 +65591 +65591 +65592 +65593 +65593 +65594 +65594 +65594 +65594 +65595 +65595 +65595 +65595 +65595 +65595 +65596 +65596 +65596 +65596 +65596 +65597 +65597 +65597 +65598 +65599 +65599 +65599 +65600 +65600 +65600 +65600 +65601 +65601 +65601 +65602 +65602 +65602 +65602 +65603 +65603 +65603 +65603 +65603 +65604 +65604 +65604 +65605 +65606 +65606 +65606 +65606 +65606 +65607 +65607 +65607 +65607 +65607 +65607 +65608 +65608 +65608 +65608 +65609 +65610 +65610 +65610 +65610 +65610 +65610 +65611 +65611 +65611 +65612 +65612 +65612 +65614 +65614 +65614 +65615 +65615 +65615 +65616 +65616 +65617 +65617 +65617 +65617 +65618 +65618 +65618 +65618 +65618 +65619 +65619 +65619 +65619 +65619 +65619 +65620 +65620 +65620 +65620 +65620 +65620 +65621 +65621 +65622 +65622 +65622 +65622 +65622 +65622 +65623 +65623 +65623 +65623 +65623 +65623 +65623 +65624 +65624 +65624 +65624 +65625 +65625 +65625 +65625 +65626 +65626 +65626 +65626 +65627 +65627 +65627 +65627 +65628 +65628 +65628 +65628 +65628 +65628 +65629 +65629 +65629 +65630 +65630 +65631 +65632 +65632 +65632 +65633 +65633 +65633 +65633 +65634 +65634 +65634 +65634 +65635 +65635 +65635 +65636 +65636 +65636 +65636 +65637 +65637 +65637 +65637 +65637 +65637 +65637 +65638 +65638 +65638 +65639 +65640 +65641 +65641 +65641 +65641 +65642 +65643 +65643 +65643 +65643 +65643 +65643 +65643 +65643 +65643 +65643 +65644 +65644 +65644 +65644 +65644 +65645 +65645 +65645 +65646 +65646 +65646 +65647 +65647 +65648 +65648 +65648 +65648 +65649 +65649 +65649 +65650 +65650 +65650 +65650 +65650 +65650 +65650 +65651 +65651 +65651 +65651 +65651 +65651 +65652 +65652 +65652 +65653 +65653 +65653 +65653 +65654 +65654 +65654 +65654 +65654 +65654 +65654 +65654 +65654 +65655 +65655 +65656 +65656 +65656 +65656 +65656 +65656 +65657 +65657 +65657 +65657 +65657 +65658 +65658 +65658 +65658 +65658 +65658 +65658 +65658 +65658 +65658 +65658 +65659 +65659 +65659 +65659 +65659 +65659 +65659 +65659 +65660 +65660 +65660 +65661 +65661 +65661 +65661 +65661 +65662 +65662 +65662 +65662 +65662 +65662 +65663 +65663 +65663 +65663 +65664 +65664 +65664 +65664 +65664 +65664 +65664 +65665 +65666 +65666 +65667 +65667 +65667 +65667 +65667 +65667 +65668 +65669 +65669 +65669 +65669 +65669 +65669 +65670 +65670 +65670 +65670 +65670 +65671 +65671 +65671 +65672 +65672 +65672 +65672 +65672 +65672 +65673 +65673 +65673 +65673 +65673 +65674 +65674 +65674 +65674 +65675 +65675 +65675 +65675 +65675 +65675 +65675 +65675 +65676 +65676 +65676 +65677 +65677 +65677 +65677 +65677 +65677 +65678 +65678 +65678 +65678 +65679 +65679 +65680 +65680 +65680 +65680 +65680 +65680 +65680 +65681 +65681 +65681 +65681 +65681 +65682 +65682 +65682 +65683 +65683 +65683 +65683 +65684 +65684 +65684 +65684 +65685 +65685 +65685 +65685 +65685 +65685 +65686 +65686 +65686 +65687 +65687 +65687 +65687 +65688 +65688 +65689 +65689 +65690 +65690 +65691 +65691 +65691 +65691 +65691 +65691 +65691 +65691 +65692 +65692 +65693 +65693 +65693 +65693 +65693 +65693 +65694 +65694 +65694 +65695 +65695 +65695 +65695 +65695 +65696 +65696 +65696 +65696 +65697 +65697 +65697 +65697 +65697 +65697 +65698 +65698 +65698 +65699 +65699 +65699 +65699 +65699 +65699 +65700 +65700 +65700 +65701 +65701 +65701 +65702 +65702 +65702 +65702 +65702 +65703 +65703 +65703 +65703 +65703 +65704 +65704 +65704 +65704 +65704 +65704 +65705 +65705 +65705 +65705 +65706 +65706 +65706 +65706 +65706 +65706 +65707 +65707 +65708 +65708 +65709 +65709 +65709 +65710 +65711 +65711 +65711 +65711 +65711 +65711 +65712 +65712 +65712 +65712 +65712 +65713 +65713 +65713 +65713 +65713 +65714 +65714 +65714 +65715 +65715 +65715 +65715 +65715 +65716 +65716 +65716 +65716 +65716 +65716 +65717 +65717 +65717 +65717 +65717 +65718 +65718 +65718 +65718 +65719 +65719 +65719 +65719 +65720 +65720 +65720 +65720 +65720 +65720 +65721 +65721 +65721 +65721 +65721 +65721 +65721 +65722 +65722 +65722 +65722 +65723 +65723 +65724 +65724 +65724 +65724 +65724 +65724 +65725 +65726 +65726 +65726 +65726 +65727 +65727 +65727 +65727 +65727 +65728 +65728 +65729 +65730 +65730 +65730 +65730 +65731 +65731 +65731 +65731 +65732 +65732 +65732 +65733 +65733 +65733 +65733 +65733 +65733 +65734 +65734 +65735 +65735 +65735 +65736 +65736 +65736 +65736 +65737 +65737 +65737 +65737 +65737 +65738 +65738 +65738 +65738 +65739 +65739 +65739 +65740 +65740 +65740 +65741 +65742 +65742 +65742 +65743 +65743 +65743 +65743 +65744 +65744 +65745 +65745 +65745 +65745 +65746 +65746 +65746 +65746 +65747 +65747 +65747 +65747 +65747 +65747 +65747 +65748 +65748 +65749 +65749 +65749 +65749 +65749 +65750 +65750 +65750 +65750 +65750 +65751 +65751 +65751 +65751 +65751 +65752 +65752 +65753 +65753 +65754 +65754 +65755 +65755 +65755 +65755 +65755 +65755 +65755 +65755 +65755 +65756 +65756 +65756 +65756 +65756 +65757 +65757 +65757 +65757 +65757 +65758 +65758 +65758 +65758 +65758 +65758 +65759 +65759 +65759 +65759 +65759 +65760 +65760 +65760 +65760 +65760 +65761 +65762 +65762 +65762 +65762 +65762 +65763 +65763 +65763 +65763 +65763 +65764 +65764 +65764 +65765 +65765 +65765 +65766 +65766 +65766 +65766 +65766 +65767 +65767 +65767 +65768 +65769 +65769 +65769 +65769 +65769 +65769 +65769 +65770 +65770 +65771 +65771 +65771 +65772 +65773 +65773 +65773 +65773 +65773 +65773 +65773 +65773 +65774 +65774 +65774 +65774 +65774 +65774 +65775 +65775 +65775 +65775 +65775 +65775 +65776 +65776 +65776 +65776 +65776 +65776 +65776 +65777 +65777 +65777 +65777 +65777 +65777 +65778 +65778 +65778 +65778 +65778 +65778 +65778 +65778 +65778 +65779 +65779 +65779 +65779 +65779 +65779 +65779 +65780 +65780 +65780 +65781 +65781 +65781 +65782 +65782 +65782 +65782 +65782 +65783 +65783 +65783 +65783 +65783 +65783 +65783 +65784 +65784 +65784 +65784 +65784 +65785 +65785 +65786 +65786 +65786 +65786 +65786 +65787 +65787 +65787 +65787 +65787 +65788 +65788 +65788 +65788 +65789 +65789 +65789 +65789 +65789 +65789 +65789 +65789 +65789 +65789 +65790 +65790 +65790 +65791 +65791 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-4-f2e4d659b65a833e9281b6786d3d55c1 b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-4-f2e4d659b65a833e9281b6786d3d55c1 new file mode 100644 index 000000000000..9cc7e7ea6c2b --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_udaf.q (deterministic)-4-f2e4d659b65a833e9281b6786d3d55c1 @@ -0,0 +1,1049 @@ + 24.37875 + 27.900000000000002 + 43.64 +alice allen 16.919999999999998 +alice allen 20.39 +alice allen 23.59 +alice brown 6.91 +alice carson 41.74 +alice davidson 26.346000000000004 +alice falkner 32.166666666666664 +alice garcia 15.412 +alice hernandez 19.958181818181817 +alice hernandez 30.482857142857142 +alice johnson 25.51 +alice king 13.085 +alice king 25.616666666666664 +alice king 42.335 +alice laertes 20.549999999999997 +alice laertes 30.436 +alice miller 29.563333333333333 +alice nixon 19.28666666666667 +alice nixon 24.5625 +alice nixon 32.99 +alice ovid 31.35 +alice polk 17.863999999999997 +alice quirinius 19.032857142857143 +alice quirinius 23.9425 +alice robinson 23.338750000000005 +alice robinson 39.85 +alice steinbeck 22.862000000000002 +alice steinbeck 26.328000000000003 +alice steinbeck 27.08777777777778 +alice underhill 24.032222222222224 +alice van buren 19.642000000000003 +alice xylophone 24.438000000000002 +alice xylophone 28.739999999999995 +alice xylophone 30.0825 +alice zipper 26.3 +alice zipper 28.735000000000003 +alice zipper 31.05545454545455 +bob brown 12.902222222222223 +bob brown 13.945 +bob brown 33.843333333333334 +bob carson 28.627999999999997 +bob davidson 19.8525 +bob davidson 23.482 +bob davidson 24.67 +bob ellison 16.315714285714286 +bob ellison 18.4 +bob ellison 26.913999999999998 +bob ellison 27.59 +bob falkner 9.27 +bob garcia 11.63 +bob garcia 22.221249999999998 +bob garcia 23.59636363636364 +bob garcia 26.88857142857143 +bob garcia 28.715000000000003 +bob hernandez 37.23 +bob ichabod 28.33875 +bob king 8.615 +bob king 19.77 +bob king 26.7325 +bob laertes 21.33 +bob laertes 37.88 +bob miller 25.495 +bob ovid 25.675 +bob ovid 25.83 +bob ovid 28.37875 +bob ovid 32.5025 +bob polk 9.74 +bob quirinius 34.57 +bob steinbeck 9.725 +bob van buren 29.552857142857142 +bob white 17.685 +bob white 29.46285714285715 +bob xylophone 17.03 +bob xylophone 33.24 +bob young 19.824 +bob zipper 24.095 +bob zipper 33.36 +bob zipper 34.99 +calvin allen 21.3 +calvin brown 20.808 +calvin brown 24.16 +calvin brown 24.636666666666667 +calvin carson 22.815 +calvin davidson 22.116666666666664 +calvin davidson 22.364 +calvin ellison 24.92 +calvin falkner 18.343999999999998 +calvin falkner 19.56 +calvin falkner 22.946000000000005 +calvin falkner 23.327777777777776 +calvin falkner 23.974999999999998 +calvin falkner 33.382 +calvin garcia 17.285 +calvin hernandez 12.663333333333334 +calvin johnson 24.898571428571433 +calvin laertes 28.105 +calvin laertes 28.362000000000002 +calvin nixon 26.784285714285716 +calvin nixon 27.36 +calvin nixon 32.282 +calvin ovid 22.063333333333336 +calvin ovid 22.81500000000001 +calvin ovid 25.495714285714286 +calvin ovid 30.926666666666666 +calvin polk 27.820000000000004 +calvin quirinius 16.28 +calvin quirinius 25.552500000000002 +calvin robinson 31.814999999999998 +calvin steinbeck 12.85 +calvin steinbeck 14.939999999999998 +calvin steinbeck 17.535 +calvin thompson 28.592857142857145 +calvin thompson 40.79 +calvin underhill 24.062 +calvin van buren 26.525 +calvin van buren 28.865 +calvin white 28.256249999999998 +calvin white 43.275 +calvin xylophone 24.13111111111111 +calvin xylophone 25.27 +calvin xylophone 36.455 +calvin young 19.06 +calvin young 21.455999999999996 +calvin zipper 10.674999999999999 +calvin zipper 26.012857142857143 +david allen 25.134285714285713 +david allen 41.72333333333333 +david brown 8.52 +david brown 28.968181818181815 +david davidson 17.63 +david davidson 26.563333333333336 +david davidson 30.7325 +david davidson 33.33 +david ellison 23.79909090909091 +david ellison 24.74888888888889 +david ellison 26.198571428571427 +david hernandez 27.766 +david ichabod 16.66 +david ichabod 19.538 +david laertes 24.587500000000002 +david nixon 26.01375 +david ovid 24.131428571428575 +david ovid 32.72 +david quirinius 16.5 +david quirinius 25.08 +david quirinius 29.415 +david robinson 22.2175 +david robinson 30.99 +david thompson 25.38 +david underhill 1.17 +david underhill 21.546666666666667 +david underhill 28.26 +david van buren 26.45833333333334 +david van buren 35.7825 +david white 15.833333333333334 +david xylophone 10.71 +david xylophone 26.341428571428565 +david xylophone 33.224000000000004 +david young 9.64 +david young 21.22 +ethan allen 22.68 +ethan brown 19.37 +ethan brown 21.58666666666667 +ethan brown 21.799999999999997 +ethan brown 29.099999999999998 +ethan brown 32.43666666666667 +ethan brown 39.84 +ethan carson 24.15666666666667 +ethan ellison 27.80777777777778 +ethan ellison 48.71 +ethan falkner 17.993333333333336 +ethan falkner 26.775000000000002 +ethan garcia 19.15 +ethan hernandez 25.081111111111113 +ethan johnson 32.81875 +ethan king 19.51 +ethan laertes 16.463 +ethan laertes 17.625999999999998 +ethan laertes 25.020714285714288 +ethan laertes 26.697142857142858 +ethan laertes 28.14 +ethan laertes 29.668571428571425 +ethan laertes 36.589999999999996 +ethan miller 24.326666666666664 +ethan nixon 34.78666666666667 +ethan ovid 20.642857142857142 +ethan polk 6.98 +ethan polk 12.756666666666666 +ethan polk 30.324 +ethan polk 40.46 +ethan quirinius 23.419999999999998 +ethan quirinius 24.36 +ethan quirinius 29.068 +ethan robinson 24.463750000000005 +ethan robinson 31.630000000000003 +ethan underhill 19.86 +ethan van buren 22.241999999999997 +ethan white 31.3175 +ethan white 32.87 +ethan xylophone 30.996000000000002 +ethan zipper 22.728333333333335 +ethan zipper 29.66 +fred davidson 30.116666666666667 +fred davidson 33.55200000000001 +fred davidson 39.37 +fred ellison 16.72 +fred ellison 17.462 +fred ellison 35.1 +fred falkner 14.51 +fred falkner 27.207000000000004 +fred falkner 27.887500000000003 +fred hernandez 36.045 +fred ichabod 29.017000000000003 +fred ichabod 30.405000000000005 +fred johnson 16.9925 +fred king 20.024 +fred king 32.54666666666667 +fred laertes 25.610000000000003 +fred miller 25.92 +fred nixon 14.915 +fred nixon 21.830000000000002 +fred nixon 24.4125 +fred nixon 31.360000000000003 +fred polk 18.698 +fred polk 19.743000000000002 +fred polk 20.96 +fred polk 31.11 +fred quirinius 20.085 +fred quirinius 33.9 +fred robinson 22.502 +fred steinbeck 21.123749999999998 +fred steinbeck 25.572 +fred steinbeck 30.81 +fred underhill 29.198888888888884 +fred van buren 21.34 +fred van buren 23.285 +fred van buren 26.520000000000003 +fred van buren 33.6 +fred white 21.41 +fred young 16.876250000000002 +fred young 20.996666666666666 +fred zipper 23.627499999999998 +gabriella allen 24.113333333333333 +gabriella allen 28.4725 +gabriella brown 29.963333333333335 +gabriella brown 30.65222222222222 +gabriella carson 16.6325 +gabriella davidson 34.52 +gabriella ellison 20.18 +gabriella ellison 29.62 +gabriella falkner 14.37 +gabriella falkner 17.738333333333333 +gabriella falkner 28.61 +gabriella garcia 39.025 +gabriella hernandez 20.818333333333335 +gabriella hernandez 24.601666666666663 +gabriella ichabod 10.4925 +gabriella ichabod 20.686666666666667 +gabriella ichabod 23.185 +gabriella ichabod 23.43 +gabriella ichabod 27.44636363636364 +gabriella king 13.645 +gabriella king 22.23 +gabriella laertes 23.735 +gabriella miller 17.165 +gabriella ovid 22.884545454545453 +gabriella ovid 25.29 +gabriella polk 20.38714285714286 +gabriella polk 25.832000000000004 +gabriella steinbeck 6.226666666666667 +gabriella steinbeck 29.683333333333337 +gabriella thompson 25.565454545454546 +gabriella thompson 29.031 +gabriella thompson 29.122500000000006 +gabriella van buren 24.353 +gabriella van buren 34.21666666666667 +gabriella white 36.5175 +gabriella young 21.28142857142857 +gabriella young 21.32 +gabriella zipper 21.798461538461545 +gabriella zipper 28.676666666666666 +holly allen 27.18 +holly brown 22.76 +holly brown 30.950000000000003 +holly falkner 29.666666666666668 +holly hernandez 19.875 +holly hernandez 23.7075 +holly hernandez 24.5 +holly hernandez 26.50333333333333 +holly ichabod 23.262857142857143 +holly ichabod 25.85090909090909 +holly ichabod 29.521666666666665 +holly johnson 18.939999999999998 +holly johnson 23.2625 +holly johnson 26.49285714285714 +holly king 20.61333333333333 +holly king 30.95888888888889 +holly laertes 17.509999999999998 +holly miller 40.8975 +holly nixon 27.775714285714287 +holly nixon 30.642500000000002 +holly polk 21.02 +holly polk 24.446666666666665 +holly robinson 26.083750000000006 +holly thompson 18.801428571428573 +holly thompson 23.91 +holly thompson 29.97125 +holly underhill 18.19 +holly underhill 22.22888888888889 +holly underhill 22.813333333333336 +holly underhill 30.613999999999997 +holly van buren 20.113333333333333 +holly white 25.284999999999997 +holly white 41.0125 +holly xylophone 26.88571428571429 +holly young 30.8425 +holly young 33.24333333333334 +holly zipper 27.784000000000002 +holly zipper 28.384285714285713 +irene allen 34.605000000000004 +irene brown 18.740000000000002 +irene brown 28.974999999999998 +irene brown 32.230000000000004 +irene carson 25.665833333333335 +irene ellison 10.225000000000001 +irene ellison 26.119999999999997 +irene falkner 9.94 +irene falkner 19.41 +irene garcia 9.790000000000001 +irene garcia 19.666666666666668 +irene garcia 21.22666666666667 +irene ichabod 20.956666666666667 +irene ichabod 24.488333333333333 +irene johnson 25.34 +irene laertes 15.85 +irene laertes 21.573333333333334 +irene laertes 22.041999999999998 +irene miller 34.994285714285716 +irene nixon 22.52 +irene nixon 32.485 +irene nixon 33.165 +irene ovid 17.73 +irene ovid 22.96 +irene ovid 30.92 +irene polk 5.35 +irene polk 25.535 +irene polk 33.76 +irene polk 35.05 +irene polk 45.14 +irene quirinius 38.36 +irene quirinius 41.864999999999995 +irene quirinius 42.0 +irene robinson 30.86 +irene steinbeck 15.08 +irene thompson 28.419999999999998 +irene underhill 27.977999999999998 +irene underhill 28.438 +irene van buren 26.93625 +irene van buren 27.797999999999995 +irene xylophone 29.10454545454546 +jessica brown 38.325 +jessica carson 16.038 +jessica carson 29.668333333333337 +jessica carson 33.06 +jessica davidson 18.926 +jessica davidson 26.2975 +jessica davidson 27.611428571428572 +jessica davidson 29.86 +jessica ellison 26.873333333333335 +jessica ellison 27.123333333333335 +jessica falkner 21.75142857142858 +jessica garcia 16.939090909090908 +jessica garcia 26.48 +jessica ichabod 28.971666666666664 +jessica johnson 21.601428571428574 +jessica johnson 24.42 +jessica miller 26.90571428571429 +jessica nixon 19.15 +jessica nixon 27.025000000000002 +jessica ovid 30.72285714285714 +jessica ovid 30.895 +jessica polk 27.912857142857145 +jessica quirinius 17.05 +jessica quirinius 21.529999999999998 +jessica quirinius 25.16 +jessica quirinius 26.347999999999995 +jessica robinson 24.322857142857142 +jessica thompson 28.658000000000005 +jessica thompson 30.873636363636365 +jessica underhill 14.6725 +jessica underhill 25.831666666666667 +jessica underhill 31.345000000000002 +jessica van buren 19.575 +jessica white 18.35 +jessica white 19.175 +jessica white 20.812 +jessica white 26.0 +jessica white 29.307142857142857 +jessica xylophone 22.26 +jessica young 27.9525 +jessica young 37.61333333333334 +jessica zipper 7.03 +jessica zipper 15.794999999999998 +jessica zipper 19.95 +katie allen 27.283846153846152 +katie brown 24.156666666666666 +katie davidson 13.498000000000001 +katie ellison 19.2 +katie ellison 24.888571428571428 +katie falkner 28.959999999999997 +katie garcia 28.287142857142857 +katie garcia 36.196666666666665 +katie hernandez 25.14428571428572 +katie ichabod 19.363333333333333 +katie ichabod 20.458571428571428 +katie ichabod 28.924999999999997 +katie king 21.64125 +katie king 21.855 +katie king 22.895 +katie miller 16.263333333333335 +katie miller 30.274285714285718 +katie nixon 25.022499999999997 +katie ovid 24.055000000000003 +katie polk 21.296666666666667 +katie polk 32.03 +katie robinson 36.26 +katie van buren 28.332 +katie van buren 31.408000000000005 +katie white 23.48 +katie white 26.236666666666665 +katie xylophone 32.415 +katie young 18.209999999999997 +katie young 22.88125 +katie young 28.39888888888889 +katie zipper 10.285 +katie zipper 27.495 +luke allen 9.42 +luke allen 21.374615384615385 +luke allen 25.32 +luke allen 27.174999999999997 +luke allen 35.434 +luke brown 25.08 +luke davidson 28.205 +luke davidson 28.790000000000003 +luke ellison 7.8 +luke ellison 16.04 +luke ellison 23.426666666666666 +luke falkner 18.0 +luke falkner 22.19 +luke garcia 29.619999999999997 +luke garcia 32.722 +luke ichabod 21.150000000000002 +luke ichabod 32.78142857142857 +luke johnson 21.58666666666667 +luke johnson 23.03 +luke johnson 23.054 +luke laertes 20.264 +luke laertes 33.72 +luke laertes 39.8 +luke laertes 41.36 +luke laertes 42.254999999999995 +luke miller 20.054444444444446 +luke ovid 19.819999999999997 +luke ovid 30.832857142857147 +luke polk 24.348750000000003 +luke polk 26.57625 +luke quirinius 38.07 +luke robinson 30.119999999999994 +luke robinson 30.31375 +luke thompson 29.026874999999997 +luke underhill 21.735714285714288 +luke underhill 22.175 +luke underhill 26.785714285714285 +luke van buren 17.072222222222223 +luke white 29.063333333333333 +luke xylophone 28.994 +luke zipper 33.995 +mike allen 32.78 +mike brown 27.592222222222222 +mike carson 28.8675 +mike carson 29.88 +mike carson 32.07142857142857 +mike davidson 21.240000000000002 +mike davidson 46.31 +mike ellison 20.5275 +mike ellison 21.99 +mike ellison 24.36 +mike ellison 24.511111111111113 +mike ellison 27.703333333333337 +mike falkner 40.335 +mike garcia 24.3525 +mike garcia 24.582 +mike garcia 35.12 +mike hernandez 8.783333333333333 +mike hernandez 19.40666666666667 +mike ichabod 29.120000000000005 +mike king 14.256666666666668 +mike king 17.889999999999997 +mike king 20.493333333333336 +mike king 23.86 +mike king 26.081 +mike king 30.974 +mike miller 29.275 +mike nixon 17.306 +mike nixon 25.572 +mike polk 18.96 +mike polk 23.75142857142857 +mike polk 33.42 +mike quirinius 19.37375 +mike steinbeck 14.155 +mike steinbeck 19.305833333333332 +mike steinbeck 20.721249999999998 +mike steinbeck 31.75 +mike van buren 15.520000000000001 +mike van buren 25.828333333333333 +mike white 19.13111111111111 +mike white 22.4025 +mike white 24.7725 +mike white 35.235 +mike young 1.5 +mike young 24.679 +mike young 34.02833333333333 +mike zipper 17.97 +mike zipper 26.247333333333337 +mike zipper 44.169999999999995 +nick allen 23.744999999999997 +nick allen 36.93 +nick brown 27.669999999999998 +nick davidson 31.97285714285714 +nick ellison 23.061666666666667 +nick ellison 27.676666666666666 +nick falkner 22.555714285714284 +nick falkner 27.46 +nick garcia 17.465 +nick garcia 18.854 +nick garcia 33.60333333333333 +nick ichabod 19.231428571428573 +nick ichabod 27.645000000000003 +nick ichabod 35.836666666666666 +nick johnson 5.58 +nick johnson 25.274 +nick laertes 26.57857142857143 +nick miller 22.208333333333332 +nick nixon 16.107499999999998 +nick ovid 31.350000000000005 +nick polk 35.70333333333334 +nick quirinius 20.753333333333334 +nick quirinius 30.573333333333334 +nick robinson 21.48 +nick robinson 23.185 +nick steinbeck 19.56555555555556 +nick thompson 31.474999999999998 +nick underhill 38.24 +nick van buren 20.77375 +nick xylophone 30.909999999999997 +nick young 10.725000000000001 +nick young 24.95 +nick zipper 16.185000000000002 +nick zipper 34.72 +oscar allen 24.645 +oscar brown 39.55 +oscar carson 21.893333333333334 +oscar carson 22.868 +oscar carson 27.4875 +oscar carson 28.09428571428571 +oscar carson 30.373333333333335 +oscar davidson 9.046666666666667 +oscar ellison 24.185000000000002 +oscar ellison 30.1675 +oscar falkner 19.295 +oscar garcia 22.495833333333334 +oscar hernandez 16.6825 +oscar hernandez 25.736 +oscar ichabod 17.64 +oscar ichabod 21.11 +oscar ichabod 23.508000000000003 +oscar ichabod 30.392222222222227 +oscar johnson 19.9375 +oscar johnson 21.114444444444445 +oscar king 24.590000000000003 +oscar king 26.675 +oscar king 39.6 +oscar laertes 14.975 +oscar laertes 15.525 +oscar laertes 22.6 +oscar laertes 41.6 +oscar nixon 25.4025 +oscar ovid 24.854285714285712 +oscar ovid 25.309 +oscar ovid 29.63 +oscar polk 21.235999999999997 +oscar polk 21.27 +oscar quirinius 24.200000000000003 +oscar quirinius 24.391428571428573 +oscar quirinius 27.83285714285714 +oscar quirinius 27.853333333333328 +oscar robinson 12.3625 +oscar robinson 12.545 +oscar robinson 20.234 +oscar robinson 28.071666666666673 +oscar steinbeck 31.101111111111113 +oscar thompson 19.4875 +oscar thompson 19.975714285714286 +oscar thompson 21.1425 +oscar thompson 21.166363636363638 +oscar underhill 27.644 +oscar van buren 25.843333333333334 +oscar van buren 29.073333333333334 +oscar van buren 29.682727272727274 +oscar white 19.0775 +oscar white 23.483333333333334 +oscar white 24.705000000000002 +oscar white 28.0075 +oscar xylophone 30.020000000000003 +oscar xylophone 30.46833333333333 +oscar xylophone 33.64 +oscar zipper 21.69 +oscar zipper 23.478 +oscar zipper 31.36 +priscilla brown 14.222 +priscilla brown 27.044999999999998 +priscilla brown 31.14769230769231 +priscilla carson 14.33 +priscilla carson 18.951428571428572 +priscilla carson 27.084999999999997 +priscilla ichabod 28.160999999999994 +priscilla ichabod 49.46 +priscilla johnson 8.365 +priscilla johnson 18.176666666666666 +priscilla johnson 25.02666666666667 +priscilla johnson 26.918333333333337 +priscilla johnson 30.695999999999998 +priscilla king 19.747142857142855 +priscilla nixon 29.035555555555554 +priscilla nixon 30.27333333333333 +priscilla ovid 13.591999999999999 +priscilla ovid 35.879999999999995 +priscilla polk 23.12 +priscilla quirinius 21.826666666666668 +priscilla thompson 20.44 +priscilla underhill 28.23 +priscilla underhill 34.33200000000001 +priscilla van buren 18.122857142857143 +priscilla van buren 20.16 +priscilla van buren 26.447999999999997 +priscilla white 26.37769230769231 +priscilla xylophone 13.95 +priscilla xylophone 20.596666666666668 +priscilla xylophone 27.22 +priscilla young 29.19 +priscilla young 46.28 +priscilla zipper 11.64 +priscilla zipper 31.159999999999997 +quinn allen 26.347272727272728 +quinn allen 26.85833333333333 +quinn brown 26.822857142857146 +quinn brown 30.406000000000006 +quinn brown 41.53 +quinn davidson 17.375714285714288 +quinn davidson 20.22666666666667 +quinn davidson 25.6375 +quinn davidson 30.173333333333332 +quinn ellison 23.052 +quinn ellison 40.565 +quinn garcia 20.544 +quinn garcia 24.104999999999997 +quinn garcia 25.174 +quinn garcia 28.446000000000005 +quinn ichabod 15.12 +quinn king 12.73 +quinn king 15.12125 +quinn laertes 17.29 +quinn laertes 28.221666666666668 +quinn laertes 32.96 +quinn nixon 26.034000000000002 +quinn ovid 28.71 +quinn quirinius 8.61 +quinn robinson 16.852 +quinn steinbeck 30.093333333333334 +quinn steinbeck 49.21 +quinn thompson 7.365 +quinn thompson 33.43125 +quinn underhill 24.045 +quinn underhill 27.905454545454543 +quinn underhill 31.21 +quinn van buren 27.807692307692314 +quinn young 30.56 +quinn zipper 18.31 +quinn zipper 21.380000000000003 +rachel allen 32.501666666666665 +rachel allen 46.57 +rachel brown 23.08 +rachel brown 23.880000000000003 +rachel brown 24.43 +rachel brown 34.11 +rachel brown 35.345 +rachel carson 27.468125 +rachel carson 37.446666666666665 +rachel davidson 22.75 +rachel ellison 22.848333333333333 +rachel falkner 18.78125 +rachel falkner 28.876250000000002 +rachel falkner 29.577777777777776 +rachel falkner 31.831249999999997 +rachel johnson 31.108000000000004 +rachel king 17.4175 +rachel king 30.873749999999998 +rachel laertes 17.470000000000002 +rachel laertes 33.51 +rachel ovid 3.03 +rachel ovid 15.38 +rachel polk 18.564285714285713 +rachel quirinius 31.692500000000003 +rachel robinson 0.6 +rachel robinson 23.953333333333333 +rachel robinson 37.645 +rachel thompson 11.96 +rachel thompson 29.484 +rachel thompson 38.43 +rachel underhill 27.55333333333333 +rachel white 23.511428571428574 +rachel white 33.7 +rachel young 24.85166666666667 +rachel zipper 22.85 +rachel zipper 37.382 +sarah carson 10.38 +sarah carson 22.639 +sarah carson 44.92 +sarah ellison 16.36 +sarah falkner 29.34875 +sarah falkner 29.64125 +sarah garcia 11.296666666666667 +sarah garcia 20.723333333333333 +sarah garcia 24.115 +sarah ichabod 26.948333333333327 +sarah ichabod 33.80428571428571 +sarah johnson 18.3925 +sarah johnson 23.087500000000002 +sarah johnson 26.57857142857143 +sarah johnson 37.01 +sarah king 9.556666666666667 +sarah king 25.6125 +sarah miller 19.14875 +sarah ovid 29.205 +sarah robinson 11.326666666666668 +sarah robinson 35.809999999999995 +sarah steinbeck 23.26 +sarah white 21.75111111111111 +sarah white 26.850000000000005 +sarah xylophone 33.40571428571429 +sarah young 30.66 +sarah zipper 29.521666666666672 +tom brown 16.38 +tom brown 23.645 +tom carson 23.630000000000003 +tom carson 31.935 +tom carson 41.83 +tom davidson 30.404285714285717 +tom ellison 27.056 +tom ellison 27.401999999999997 +tom ellison 29.812 +tom falkner 15.901999999999997 +tom falkner 25.49857142857143 +tom hernandez 11.418000000000001 +tom hernandez 30.705000000000002 +tom ichabod 14.83 +tom johnson 30.748571428571434 +tom johnson 37.086666666666666 +tom king 17.923333333333332 +tom laertes 19.201666666666668 +tom laertes 22.276666666666667 +tom miller 17.9925 +tom miller 19.791666666666668 +tom miller 19.9225 +tom nixon 25.70625 +tom ovid 29.66 +tom polk 27.0975 +tom polk 28.646666666666672 +tom quirinius 37.68333333333333 +tom quirinius 38.28 +tom robinson 18.07 +tom robinson 19.094 +tom robinson 27.34125 +tom robinson 31.135714285714283 +tom steinbeck 32.70333333333333 +tom van buren 20.723333333333333 +tom van buren 24.8525 +tom van buren 31.631666666666664 +tom white 25.646000000000004 +tom young 3.12 +tom young 19.588333333333335 +tom zipper 23.317272727272726 +ulysses brown 16.196666666666665 +ulysses carson 16.3475 +ulysses carson 22.448181818181823 +ulysses carson 28.258 +ulysses carson 32.10833333333333 +ulysses davidson 37.775 +ulysses ellison 30.517000000000003 +ulysses garcia 32.92 +ulysses hernandez 13.877500000000001 +ulysses hernandez 20.856666666666666 +ulysses hernandez 21.32625 +ulysses ichabod 3.29 +ulysses ichabod 24.629999999999995 +ulysses johnson 32.208333333333336 +ulysses king 25.29111111111111 +ulysses laertes 14.936666666666667 +ulysses laertes 25.89 +ulysses laertes 26.63 +ulysses miller 2.36 +ulysses miller 26.403333333333336 +ulysses nixon 34.4575 +ulysses ovid 23.810000000000002 +ulysses polk 22.4075 +ulysses polk 26.778000000000002 +ulysses polk 38.73166666666667 +ulysses polk 47.68 +ulysses quirinius 33.07833333333333 +ulysses robinson 17.386666666666667 +ulysses steinbeck 22.2675 +ulysses steinbeck 24.904000000000003 +ulysses thompson 22.687142857142856 +ulysses underhill 6.66 +ulysses underhill 22.539 +ulysses underhill 24.853333333333335 +ulysses underhill 27.314 +ulysses underhill 29.424999999999997 +ulysses underhill 32.905 +ulysses underhill 41.653333333333336 +ulysses van buren 21.868181818181817 +ulysses white 15.296666666666667 +ulysses white 28.343333333333334 +ulysses xylophone 24.718 +ulysses xylophone 30.205 +ulysses xylophone 35.61 +ulysses young 21.56 +ulysses young 32.28125 +ulysses young 37.275 +victor allen 23.548000000000002 +victor allen 24.759999999999998 +victor brown 22.10181818181818 +victor brown 23.73 +victor brown 25.427272727272726 +victor brown 26.218571428571433 +victor davidson 20.55 +victor davidson 22.21666666666667 +victor davidson 29.778 +victor ellison 13.0775 +victor ellison 33.666 +victor hernandez 10.896 +victor hernandez 18.922 +victor hernandez 24.908888888888892 +victor hernandez 27.426666666666666 +victor hernandez 35.6675 +victor johnson 20.02 +victor johnson 27.070000000000004 +victor johnson 29.0775 +victor king 18.066666666666666 +victor king 21.488 +victor laertes 26.77777777777778 +victor laertes 28.095000000000002 +victor miller 5.3100000000000005 +victor nixon 21.395714285714288 +victor nixon 28.33 +victor ovid 35.225 +victor polk 21.990000000000002 +victor quirinius 24.62833333333333 +victor quirinius 29.742500000000003 +victor robinson 14.575 +victor robinson 25.92 +victor steinbeck 26.136666666666667 +victor steinbeck 26.485 +victor steinbeck 34.745999999999995 +victor thompson 18.735 +victor van buren 27.758333333333336 +victor van buren 37.38333333333333 +victor white 24.607999999999997 +victor white 30.66 +victor xylophone 2.775 +victor xylophone 8.356666666666667 +victor xylophone 24.259999999999998 +victor xylophone 25.636666666666667 +victor xylophone 31.610000000000003 +victor young 22.264444444444443 +victor zipper 39.84 +wendy allen 3.4 +wendy allen 24.695000000000004 +wendy allen 29.912 +wendy brown 28.22 +wendy brown 36.74 +wendy ellison 17.549999999999997 +wendy ellison 22.720000000000002 +wendy falkner 13.765 +wendy falkner 24.424444444444443 +wendy falkner 27.86733333333333 +wendy garcia 12.3 +wendy garcia 22.396666666666665 +wendy garcia 26.8325 +wendy garcia 28.596666666666664 +wendy hernandez 21.111428571428572 +wendy ichabod 4.44 +wendy king 23.654285714285713 +wendy king 29.325714285714287 +wendy king 34.21666666666667 +wendy laertes 31.160714285714285 +wendy laertes 31.46666666666667 +wendy laertes 39.22 +wendy miller 12.73 +wendy miller 30.343333333333334 +wendy nixon 19.92714285714286 +wendy nixon 29.675714285714285 +wendy ovid 21.193749999999998 +wendy ovid 28.49846153846154 +wendy polk 20.94 +wendy polk 22.999999999999996 +wendy quirinius 21.05 +wendy quirinius 26.8425 +wendy robinson 8.39 +wendy robinson 24.05 +wendy robinson 26.974285714285713 +wendy steinbeck 26.765 +wendy thompson 24.14 +wendy thompson 28.995384615384616 +wendy underhill 23.118333333333336 +wendy underhill 25.581666666666667 +wendy underhill 32.985 +wendy van buren 25.151666666666667 +wendy van buren 27.077142857142857 +wendy white 24.4025 +wendy xylophone 22.85181818181818 +wendy xylophone 26.96 +wendy young 4.83 +wendy young 21.325 +xavier allen 19.133333333333333 +xavier allen 26.11466666666667 +xavier allen 34.58 +xavier brown 2.63 +xavier brown 24.764285714285712 +xavier brown 30.166666666666668 +xavier carson 29.006666666666664 +xavier carson 32.106 +xavier davidson 14.094999999999999 +xavier davidson 15.906666666666666 +xavier davidson 27.353333333333335 +xavier ellison 22.174166666666668 +xavier ellison 35.01 +xavier garcia 30.357500000000005 +xavier hernandez 19.87 +xavier hernandez 20.805 +xavier hernandez 33.497499999999995 +xavier ichabod 12.34 +xavier ichabod 26.166249999999998 +xavier johnson 20.33222222222222 +xavier johnson 22.503333333333334 +xavier king 1.3 +xavier king 31.348571428571425 +xavier laertes 7.420000000000001 +xavier ovid 25.576 +xavier polk 11.094285714285714 +xavier polk 19.93 +xavier polk 23.63125 +xavier polk 30.194 +xavier quirinius 13.776666666666666 +xavier quirinius 22.27 +xavier quirinius 24.977692307692312 +xavier quirinius 34.95 +xavier thompson 16.47 +xavier underhill 1.31 +xavier white 19.331666666666667 +xavier white 34.68 +xavier xylophone 21.09625 +xavier zipper 14.89 +yuri allen 18.490000000000002 +yuri allen 22.689999999999998 +yuri brown 15.502857142857142 +yuri brown 22.934285714285714 +yuri carson 27.139999999999997 +yuri carson 35.27 +yuri ellison 10.52 +yuri ellison 25.2025 +yuri falkner 24.633076923076924 +yuri falkner 28.52 +yuri garcia 25.545 +yuri hernandez 16.35 +yuri johnson 19.9525 +yuri johnson 27.636000000000003 +yuri johnson 39.92 +yuri king 15.450000000000001 +yuri laertes 0.41000000000000003 +yuri laertes 33.15 +yuri nixon 27.795 +yuri nixon 39.145 +yuri polk 0.8 +yuri polk 9.705 +yuri polk 25.513333333333332 +yuri quirinius 16.29 +yuri quirinius 19.254999999999995 +yuri quirinius 37.878 +yuri steinbeck 27.6275 +yuri steinbeck 48.89 +yuri thompson 23.330000000000002 +yuri underhill 20.504444444444445 +yuri underhill 21.66 +yuri white 31.205 +yuri xylophone 18.790000000000003 +zach allen 13.06 +zach brown 19.985 +zach brown 26.52333333333333 +zach brown 34.66 +zach brown 34.972857142857144 +zach brown 37.45399999999999 +zach carson 26.195999999999998 +zach ellison 17.55 +zach falkner 3.42 +zach falkner 16.18 +zach garcia 20.062 +zach garcia 25.935 +zach garcia 28.974285714285717 +zach garcia 35.449999999999996 +zach ichabod 10.59 +zach ichabod 31.691999999999997 +zach king 6.81 +zach king 20.817 +zach king 32.542500000000004 +zach miller 13.23 +zach miller 26.30666666666667 +zach miller 26.73 +zach ovid 21.122500000000002 +zach ovid 26.983999999999998 +zach ovid 33.15 +zach ovid 40.59 +zach quirinius 13.38 +zach robinson 20.451999999999998 +zach steinbeck 20.358333333333334 +zach steinbeck 29.65 +zach thompson 16.45 +zach thompson 21.430000000000003 +zach underhill 31.438333333333333 +zach white 23.111428571428572 +zach xylophone 21.221428571428568 +zach xylophone 23.156666666666666 +zach young 24.72666666666667 +zach zipper 19.878888888888884 +zach zipper 34.84571428571429 +zach zipper 35.36 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf2-0-96659fde37d7a38ea15b367b47f59ce2 b/sql/hive/src/test/resources/golden/windowing_udaf2-0-96659fde37d7a38ea15b367b47f59ce2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/windowing_udaf2-1-b4bdee4908b1cb8e240c549ae5cfe4c0 b/sql/hive/src/test/resources/golden/windowing_udaf2-1-b4bdee4908b1cb8e240c549ae5cfe4c0 new file mode 100644 index 000000000000..17c31c0f0459 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_udaf2-1-b4bdee4908b1cb8e240c549ae5cfe4c0 @@ -0,0 +1 @@ +130091 130091 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-0-f498cccf82480be03022d2a36f87651e b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-0-f498cccf82480be03022d2a36f87651e new file mode 100644 index 000000000000..31b1f85a5eb5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-0-f498cccf82480be03022d2a36f87651e @@ -0,0 +1,1049 @@ + 4294967354 + 4294967416 + 4294967457 +alice allen 4294967487 +alice allen 4294967525 +alice allen 4294967531 +alice brown 4294967355 +alice carson 4294967370 +alice davidson 4294967517 +alice falkner 4294967316 +alice garcia 4294967369 +alice hernandez 4294967299 +alice hernandez 4294967314 +alice johnson 4294967424 +alice king 4294967387 +alice king 4294967516 +alice king 4294967546 +alice laertes 4294967519 +alice laertes 8589934835 +alice miller 4294967324 +alice nixon 4294967410 +alice nixon 4294967413 +alice nixon 4294967484 +alice ovid 8589934726 +alice polk 4294967366 +alice quirinius 4294967505 +alice quirinius 4294967549 +alice robinson 4294967445 +alice robinson 4294967502 +alice steinbeck 4294967364 +alice steinbeck 4294967474 +alice steinbeck 4294967549 +alice underhill 4294967441 +alice van buren 4294967428 +alice xylophone 4294967363 +alice xylophone 4294967519 +alice xylophone 8589934832 +alice zipper 4294967380 +alice zipper 4294967520 +alice zipper 8589935026 +bob brown 4294967422 +bob brown 4294967427 +bob brown 4294967431 +bob carson 4294967408 +bob davidson 4294967354 +bob davidson 4294967435 +bob davidson 4294967504 +bob ellison 4294967344 +bob ellison 4294967362 +bob ellison 4294967436 +bob ellison 4294967530 +bob falkner 8589934966 +bob garcia 4294967369 +bob garcia 4294967435 +bob garcia 4294967439 +bob garcia 8589934707 +bob garcia 8589934867 +bob hernandez 4294967500 +bob ichabod 4294967424 +bob king 4294967297 +bob king 4294967539 +bob king 8589934870 +bob laertes 4294967380 +bob laertes 4294967472 +bob miller 4294967349 +bob ovid 4294967395 +bob ovid 4294967400 +bob ovid 4294967401 +bob ovid 4294967512 +bob polk 4294967337 +bob quirinius 4294967346 +bob steinbeck 4294967342 +bob van buren 4294967422 +bob white 4294967362 +bob white 4294967493 +bob xylophone 4294967407 +bob xylophone 4294967465 +bob young 4294967413 +bob zipper 4294967299 +bob zipper 8589934723 +bob zipper 8589934840 +calvin allen 12884902208 +calvin brown 4294967411 +calvin brown 4294967437 +calvin brown 4294967530 +calvin carson 8589934876 +calvin davidson 4294967468 +calvin davidson 8589934837 +calvin ellison 4294967480 +calvin falkner 4294967300 +calvin falkner 4294967305 +calvin falkner 4294967345 +calvin falkner 8589934749 +calvin falkner 8589934840 +calvin falkner 8589934978 +calvin garcia 8589934927 +calvin hernandez 12884902173 +calvin johnson 4294967546 +calvin laertes 4294967431 +calvin laertes 4294967499 +calvin nixon 4294967300 +calvin nixon 4294967412 +calvin nixon 4294967488 +calvin ovid 4294967329 +calvin ovid 4294967349 +calvin ovid 8589934723 +calvin ovid 8589934835 +calvin polk 8589934962 +calvin quirinius 4294967521 +calvin quirinius 4294967532 +calvin robinson 4294967326 +calvin steinbeck 4294967474 +calvin steinbeck 4294967505 +calvin steinbeck 8589934722 +calvin thompson 4294967513 +calvin thompson 8589934700 +calvin underhill 4294967478 +calvin van buren 4294967300 +calvin van buren 4294967508 +calvin white 4294967304 +calvin white 8589934924 +calvin xylophone 4294967376 +calvin xylophone 8589934599 +calvin xylophone 8589934866 +calvin young 4294967342 +calvin young 8589934817 +calvin zipper 12884902359 +calvin zipper 17179869649 +david allen 4294967371 +david allen 4294967381 +david brown 8589934762 +david brown 12884902420 +david davidson 4294967522 +david davidson 8589934819 +david davidson 12884902188 +david davidson 12884902327 +david ellison 4294967463 +david ellison 8589934777 +david ellison 12884902263 +david hernandez 4294967324 +david ichabod 4294967487 +david ichabod 12884902220 +david laertes 12884902107 +david nixon 4294967381 +david ovid 4294967396 +david ovid 4294967443 +david quirinius 4294967457 +david quirinius 4294967530 +david quirinius 12884902194 +david robinson 4294967465 +david robinson 17179869575 +david thompson 4294967361 +david underhill 4294967384 +david underhill 8589934942 +david underhill 12884902357 +david van buren 4294967309 +david van buren 8589934901 +david white 4294967428 +david xylophone 4294967479 +david xylophone 4294967480 +david xylophone 8589934856 +david young 4294967296 +david young 4294967305 +ethan allen 4294967351 +ethan brown 4294967320 +ethan brown 4294967331 +ethan brown 4294967403 +ethan brown 4294967420 +ethan brown 8589934797 +ethan brown 8589934805 +ethan carson 4294967352 +ethan ellison 4294967514 +ethan ellison 8589934887 +ethan falkner 4294967318 +ethan falkner 4294967461 +ethan garcia 4294967310 +ethan hernandez 4294967349 +ethan johnson 8589934738 +ethan king 8589934731 +ethan laertes 4294967422 +ethan laertes 4294967531 +ethan laertes 8589934767 +ethan laertes 8589934806 +ethan laertes 8589934830 +ethan laertes 8589934995 +ethan laertes 12884902063 +ethan miller 4294967352 +ethan nixon 8589935019 +ethan ovid 8589934909 +ethan polk 4294967329 +ethan polk 4294967382 +ethan polk 4294967479 +ethan polk 8589935021 +ethan quirinius 4294967348 +ethan quirinius 4294967501 +ethan quirinius 8589934695 +ethan robinson 4294967353 +ethan robinson 8589935019 +ethan underhill 8589934897 +ethan van buren 4294967511 +ethan white 4294967427 +ethan white 8589934975 +ethan xylophone 8589934956 +ethan zipper 4294967462 +ethan zipper 12884902348 +fred davidson 8589934724 +fred davidson 8589934850 +fred davidson 12884902468 +fred ellison 4294967393 +fred ellison 8589934797 +fred ellison 8589934978 +fred falkner 4294967547 +fred falkner 12884902170 +fred falkner 17179869760 +fred hernandez 8589934833 +fred ichabod 8589934853 +fred ichabod 12884902455 +fred johnson 8589934904 +fred king 8589934651 +fred king 8589934951 +fred laertes 8589934883 +fred miller 12884902228 +fred nixon 4294967297 +fred nixon 4294967375 +fred nixon 4294967514 +fred nixon 12884902182 +fred polk 4294967332 +fred polk 4294967458 +fred polk 4294967507 +fred polk 8589934944 +fred quirinius 8589934894 +fred quirinius 12884902335 +fred robinson 8589934904 +fred steinbeck 4294967329 +fred steinbeck 4294967411 +fred steinbeck 4294967472 +fred underhill 4294967387 +fred van buren 8589934830 +fred van buren 12884902319 +fred van buren 12884902382 +fred van buren 17179869836 +fred white 8589934763 +fred young 4294967485 +fred young 8589934832 +fred zipper 12884902371 +gabriella allen 4294967405 +gabriella allen 12884902509 +gabriella brown 4294967403 +gabriella brown 4294967543 +gabriella carson 8589934950 +gabriella davidson 4294967507 +gabriella ellison 4294967393 +gabriella ellison 12884902284 +gabriella falkner 4294967378 +gabriella falkner 4294967523 +gabriella falkner 12884902338 +gabriella garcia 4294967419 +gabriella hernandez 4294967462 +gabriella hernandez 4294967481 +gabriella ichabod 4294967337 +gabriella ichabod 8589934740 +gabriella ichabod 8589934797 +gabriella ichabod 8589934818 +gabriella ichabod 17179869508 +gabriella king 4294967393 +gabriella king 8589934906 +gabriella laertes 4294967410 +gabriella miller 8589934768 +gabriella ovid 4294967522 +gabriella ovid 8589934895 +gabriella polk 4294967302 +gabriella polk 8589934868 +gabriella steinbeck 4294967435 +gabriella steinbeck 4294967500 +gabriella thompson 4294967412 +gabriella thompson 8589934814 +gabriella thompson 12884902318 +gabriella van buren 4294967470 +gabriella van buren 8589934783 +gabriella white 4294967335 +gabriella young 4294967431 +gabriella young 8589934980 +gabriella zipper 4294967510 +gabriella zipper 8589934792 +holly allen 12884901926 +holly brown 8589934722 +holly brown 8589934857 +holly falkner 8589934849 +holly hernandez 8589934749 +holly hernandez 8589934805 +holly hernandez 8589935056 +holly hernandez 12884902485 +holly ichabod 4294967329 +holly ichabod 8589934754 +holly ichabod 8589934981 +holly johnson 4294967535 +holly johnson 12884902194 +holly johnson 17179869874 +holly king 8589934785 +holly king 8589934939 +holly laertes 12884902333 +holly miller 8589934823 +holly nixon 4294967383 +holly nixon 8589934744 +holly polk 4294967434 +holly polk 8589934782 +holly robinson 12884902369 +holly thompson 4294967339 +holly thompson 12884902395 +holly thompson 17179869547 +holly underhill 8589934913 +holly underhill 8589934924 +holly underhill 12884902376 +holly underhill 12884902412 +holly van buren 4294967539 +holly white 17179869548 +holly white 17179869900 +holly xylophone 8589934846 +holly young 4294967500 +holly young 8589934932 +holly zipper 4294967509 +holly zipper 17179869531 +irene allen 12884902413 +irene brown 4294967428 +irene brown 8589934934 +irene brown 12884902207 +irene carson 8589934797 +irene ellison 8589934732 +irene ellison 8589934773 +irene falkner 4294967404 +irene falkner 4294967548 +irene garcia 4294967323 +irene garcia 8589934887 +irene garcia 12884902479 +irene ichabod 4294967509 +irene ichabod 8589934860 +irene johnson 8589934990 +irene laertes 4294967481 +irene laertes 12884902196 +irene laertes 17179869632 +irene miller 4294967387 +irene nixon 4294967538 +irene nixon 12884902129 +irene nixon 12884902324 +irene ovid 8589934764 +irene ovid 8589934886 +irene ovid 8589934903 +irene polk 4294967465 +irene polk 4294967521 +irene polk 8589934672 +irene polk 8589934842 +irene polk 17179869877 +irene quirinius 8589934875 +irene quirinius 12884902269 +irene quirinius 17179869628 +irene robinson 8589934676 +irene steinbeck 4294967549 +irene thompson 4294967479 +irene underhill 8589934694 +irene underhill 12884902077 +irene van buren 8589934932 +irene van buren 12884902202 +irene xylophone 8589934901 +jessica brown 8589934867 +jessica carson 4294967508 +jessica carson 8589934740 +jessica carson 17179869819 +jessica davidson 4294967384 +jessica davidson 8589934864 +jessica davidson 12884902256 +jessica davidson 12884902321 +jessica ellison 4294967316 +jessica ellison 12884902128 +jessica falkner 8589934980 +jessica garcia 4294967540 +jessica garcia 21474837337 +jessica ichabod 8589934816 +jessica johnson 8589935006 +jessica johnson 12884902222 +jessica miller 8589934898 +jessica nixon 8589934742 +jessica nixon 12884902240 +jessica ovid 8589934830 +jessica ovid 12884902307 +jessica polk 21474837163 +jessica quirinius 8589934701 +jessica quirinius 8589934872 +jessica quirinius 12884902159 +jessica quirinius 12884902276 +jessica robinson 4294967542 +jessica thompson 8589934698 +jessica thompson 12884902232 +jessica underhill 8589934810 +jessica underhill 8589934878 +jessica underhill 17179869479 +jessica van buren 8589934726 +jessica white 12884902155 +jessica white 12884902281 +jessica white 12884902296 +jessica white 12884902314 +jessica white 17179869676 +jessica xylophone 17179869697 +jessica young 17179869859 +jessica young 17179869861 +jessica zipper 4294967372 +jessica zipper 8589934727 +jessica zipper 17179869778 +katie allen 8589934791 +katie brown 17179869660 +katie davidson 12884902181 +katie ellison 12884902184 +katie ellison 12884902355 +katie falkner 8589934911 +katie garcia 8589934683 +katie garcia 12884902046 +katie hernandez 8589934812 +katie ichabod 8589934795 +katie ichabod 8589934862 +katie ichabod 8589934869 +katie king 4294967339 +katie king 4294967421 +katie king 8589934826 +katie miller 8589934829 +katie miller 12884902267 +katie nixon 21474837149 +katie ovid 4294967519 +katie polk 8589934726 +katie polk 12884902291 +katie robinson 17179869645 +katie van buren 8589934722 +katie van buren 17179869441 +katie white 4294967306 +katie white 8589934885 +katie xylophone 12884902193 +katie young 8589934819 +katie young 8589935024 +katie young 12884902058 +katie zipper 4294967354 +katie zipper 12884902310 +luke allen 8589934864 +luke allen 8589934931 +luke allen 8589935059 +luke allen 12884902257 +luke allen 12884902322 +luke brown 8589934779 +luke davidson 4294967354 +luke davidson 12884902360 +luke ellison 12884902183 +luke ellison 21474836998 +luke ellison 21474837060 +luke falkner 8589934772 +luke falkner 17179869561 +luke garcia 4294967304 +luke garcia 21474837157 +luke ichabod 12884902150 +luke ichabod 12884902366 +luke johnson 4294967527 +luke johnson 8589934812 +luke johnson 12884902161 +luke laertes 8589935027 +luke laertes 12884902031 +luke laertes 12884902184 +luke laertes 12884902213 +luke laertes 12884902378 +luke miller 8589934826 +luke ovid 4294967492 +luke ovid 8589934913 +luke polk 8589934837 +luke polk 12884902340 +luke quirinius 8589934855 +luke robinson 4294967307 +luke robinson 17179869711 +luke thompson 4294967521 +luke underhill 8589934829 +luke underhill 12884902299 +luke underhill 21474837138 +luke van buren 8589934852 +luke white 12884902418 +luke xylophone 8589934804 +luke zipper 4294967353 +mike allen 17179869750 +mike brown 17179869735 +mike carson 4294967477 +mike carson 8589934803 +mike carson 17179869855 +mike davidson 12884902377 +mike davidson 17179869841 +mike ellison 8589934833 +mike ellison 12884902165 +mike ellison 12884902513 +mike ellison 17179869587 +mike ellison 17179869824 +mike falkner 4294967301 +mike garcia 4294967398 +mike garcia 8589934800 +mike garcia 12884902292 +mike hernandez 8589934824 +mike hernandez 12884902281 +mike ichabod 4294967494 +mike king 4294967347 +mike king 4294967400 +mike king 12884902363 +mike king 12884902475 +mike king 17179869528 +mike king 17179869592 +mike miller 17179869705 +mike nixon 12884902293 +mike nixon 17179869708 +mike polk 17179869752 +mike polk 21474837097 +mike polk 21474837344 +mike quirinius 12884902240 +mike steinbeck 8589934653 +mike steinbeck 12884902273 +mike steinbeck 12884902301 +mike steinbeck 17179869903 +mike van buren 8589934942 +mike van buren 12884902402 +mike white 12884902485 +mike white 17179869676 +mike white 21474836928 +mike white 25769804626 +mike young 8589934704 +mike young 8589934878 +mike young 17179869685 +mike zipper 4294967501 +mike zipper 17179869582 +mike zipper 25769804400 +nick allen 8589934664 +nick allen 8589934860 +nick brown 21474836962 +nick davidson 4294967357 +nick ellison 12884902066 +nick ellison 17179869779 +nick falkner 8589935020 +nick falkner 12884902433 +nick garcia 8589934885 +nick garcia 17179869635 +nick garcia 17179869681 +nick ichabod 12884902193 +nick ichabod 12884902223 +nick ichabod 12884902252 +nick johnson 17179869591 +nick johnson 17179869702 +nick laertes 8589934919 +nick miller 12884902419 +nick nixon 8589934910 +nick ovid 12884902267 +nick polk 17179869712 +nick quirinius 4294967296 +nick quirinius 12884902183 +nick robinson 17179869506 +nick robinson 17179869731 +nick steinbeck 4294967355 +nick thompson 8589934922 +nick underhill 25769804624 +nick van buren 8589934635 +nick xylophone 12884902279 +nick young 12884902399 +nick young 21474837140 +nick zipper 12884902300 +nick zipper 17179869849 +oscar allen 17179869779 +oscar brown 12884902062 +oscar carson 12884902232 +oscar carson 17179869663 +oscar carson 17179869779 +oscar carson 21474837066 +oscar carson 21474837089 +oscar davidson 17179869895 +oscar ellison 4294967304 +oscar ellison 8589934740 +oscar falkner 4294967526 +oscar garcia 21474837156 +oscar hernandez 4294967343 +oscar hernandez 8589935049 +oscar ichabod 8589934837 +oscar ichabod 21474836952 +oscar ichabod 21474837021 +oscar ichabod 25769804491 +oscar johnson 12884902182 +oscar johnson 30064772044 +oscar king 12884902159 +oscar king 17179869738 +oscar king 17179869834 +oscar laertes 4294967550 +oscar laertes 8589934727 +oscar laertes 12884902043 +oscar laertes 12884902478 +oscar nixon 17179869458 +oscar ovid 12884902128 +oscar ovid 12884902240 +oscar ovid 25769804460 +oscar polk 21474836829 +oscar polk 21474837063 +oscar quirinius 8589934728 +oscar quirinius 17179869698 +oscar quirinius 21474837051 +oscar quirinius 25769804521 +oscar robinson 8589934656 +oscar robinson 12884902249 +oscar robinson 21474837105 +oscar robinson 25769804694 +oscar steinbeck 4294967548 +oscar thompson 8589934776 +oscar thompson 12884902164 +oscar thompson 12884902317 +oscar thompson 17179869884 +oscar underhill 8589934895 +oscar van buren 4294967500 +oscar van buren 8589934984 +oscar van buren 21474837205 +oscar white 4294967454 +oscar white 8589934826 +oscar white 21474836931 +oscar white 21474837305 +oscar xylophone 12884902193 +oscar xylophone 12884902307 +oscar xylophone 17179869593 +oscar zipper 8589934865 +oscar zipper 8589934874 +oscar zipper 8589934911 +priscilla brown 8589934848 +priscilla brown 8589935013 +priscilla brown 17179869801 +priscilla carson 12884902145 +priscilla carson 21474836880 +priscilla carson 30064772126 +priscilla ichabod 4294967547 +priscilla ichabod 17179869756 +priscilla johnson 4294967468 +priscilla johnson 8589934667 +priscilla johnson 17179869667 +priscilla johnson 17179869787 +priscilla johnson 25769804279 +priscilla king 12884902153 +priscilla nixon 12884902188 +priscilla nixon 25769804766 +priscilla ovid 12884902234 +priscilla ovid 30064772049 +priscilla polk 17179869480 +priscilla quirinius 12884902171 +priscilla thompson 25769804637 +priscilla underhill 4294967333 +priscilla underhill 17179869740 +priscilla van buren 12884902324 +priscilla van buren 21474837167 +priscilla van buren 21474837343 +priscilla white 4294967419 +priscilla xylophone 8589934792 +priscilla xylophone 12884902245 +priscilla xylophone 12884902287 +priscilla young 21474836992 +priscilla young 34359739656 +priscilla zipper 12884902296 +priscilla zipper 12884902537 +quinn allen 4294967542 +quinn allen 17179869552 +quinn brown 12884902251 +quinn brown 17179869401 +quinn brown 17179869626 +quinn davidson 8589934992 +quinn davidson 17179869690 +quinn davidson 25769804455 +quinn davidson 30064771771 +quinn ellison 12884902376 +quinn ellison 34359739559 +quinn garcia 8589934828 +quinn garcia 12884902387 +quinn garcia 12884902460 +quinn garcia 21474837066 +quinn ichabod 30064772171 +quinn king 4294967458 +quinn king 4294967538 +quinn laertes 8589935080 +quinn laertes 17179869711 +quinn laertes 21474837142 +quinn nixon 17179869672 +quinn ovid 17179869695 +quinn quirinius 21474836827 +quinn robinson 12884902445 +quinn steinbeck 17179869739 +quinn steinbeck 21474836905 +quinn thompson 17179869645 +quinn thompson 25769804317 +quinn underhill 8589934815 +quinn underhill 12884902185 +quinn underhill 30064771762 +quinn van buren 4294967362 +quinn young 8589934731 +quinn zipper 12884902453 +quinn zipper 17179869841 +rachel allen 8589934882 +rachel allen 12884902208 +rachel brown 8589934768 +rachel brown 12884902075 +rachel brown 17179869910 +rachel brown 17179869911 +rachel brown 21474837280 +rachel carson 8589934728 +rachel carson 17179869970 +rachel davidson 30064771666 +rachel ellison 4294967423 +rachel falkner 4294967348 +rachel falkner 12884902482 +rachel falkner 21474837331 +rachel falkner 25769804739 +rachel johnson 38654707197 +rachel king 12884902157 +rachel king 30064771759 +rachel laertes 17179869678 +rachel laertes 25769804379 +rachel ovid 12884902055 +rachel ovid 17179869857 +rachel polk 12884902391 +rachel quirinius 17179869456 +rachel robinson 17179869499 +rachel robinson 17179869703 +rachel robinson 25769804290 +rachel thompson 17179869910 +rachel thompson 21474836989 +rachel thompson 21474837392 +rachel underhill 8589934862 +rachel white 17179869585 +rachel white 21474837039 +rachel young 17179869708 +rachel zipper 4294967434 +rachel zipper 21474837228 +sarah carson 4294967319 +sarah carson 17179869688 +sarah carson 30064772084 +sarah ellison 4294967542 +sarah falkner 17179869797 +sarah falkner 21474837349 +sarah garcia 8589934733 +sarah garcia 8589934858 +sarah garcia 17179869599 +sarah ichabod 12884902196 +sarah ichabod 12884902401 +sarah johnson 12884902455 +sarah johnson 21474836981 +sarah johnson 21474837145 +sarah johnson 25769804480 +sarah king 12884902453 +sarah king 21474837191 +sarah miller 8589934958 +sarah ovid 21474837184 +sarah robinson 21474837237 +sarah robinson 21474837389 +sarah steinbeck 21474837313 +sarah white 17179869905 +sarah white 25769804341 +sarah xylophone 12884902207 +sarah young 21474837319 +sarah zipper 25769804616 +tom brown 8589934894 +tom brown 21474837024 +tom carson 4294967388 +tom carson 12884902278 +tom carson 21474836983 +tom davidson 8589934895 +tom ellison 12884902192 +tom ellison 17179869965 +tom ellison 25769804262 +tom falkner 12884902272 +tom falkner 17179869815 +tom hernandez 4294967296 +tom hernandez 12884902109 +tom ichabod 17179869628 +tom johnson 25769804829 +tom johnson 30064771891 +tom king 12884902390 +tom laertes 12884902181 +tom laertes 12884902236 +tom miller 12884901992 +tom miller 17179869647 +tom miller 21474837107 +tom nixon 17179869677 +tom ovid 12884902279 +tom polk 8589934748 +tom polk 8589934892 +tom quirinius 12884902174 +tom quirinius 21474836986 +tom robinson 8589934753 +tom robinson 12884902203 +tom robinson 12884902358 +tom robinson 21474836952 +tom steinbeck 8589934912 +tom van buren 8589934823 +tom van buren 12884902122 +tom van buren 25769804641 +tom white 21474837076 +tom young 4294967535 +tom young 21474837038 +tom zipper 30064772355 +ulysses brown 8589934991 +ulysses carson 8589934789 +ulysses carson 21474837258 +ulysses carson 25769804457 +ulysses carson 34359739082 +ulysses davidson 12884902216 +ulysses ellison 17179869551 +ulysses garcia 12884902382 +ulysses hernandez 12884902210 +ulysses hernandez 12884902276 +ulysses hernandez 17179869748 +ulysses ichabod 4294967353 +ulysses ichabod 12884902217 +ulysses johnson 21474837122 +ulysses king 8589934995 +ulysses laertes 8589934801 +ulysses laertes 21474837354 +ulysses laertes 25769804499 +ulysses miller 21474837284 +ulysses miller 30064771926 +ulysses nixon 17179869288 +ulysses ovid 17179869754 +ulysses polk 8589934855 +ulysses polk 8589934862 +ulysses polk 12884902420 +ulysses polk 17179869479 +ulysses quirinius 17179869659 +ulysses robinson 4294967531 +ulysses steinbeck 8589935027 +ulysses steinbeck 21474837100 +ulysses thompson 12884902194 +ulysses underhill 8589934760 +ulysses underhill 8589934799 +ulysses underhill 12884902240 +ulysses underhill 17179869759 +ulysses underhill 17179869760 +ulysses underhill 17179869939 +ulysses underhill 21474837264 +ulysses van buren 8589934938 +ulysses white 25769804453 +ulysses white 30064772086 +ulysses xylophone 8589935029 +ulysses xylophone 12884902249 +ulysses xylophone 25769804765 +ulysses young 4294967427 +ulysses young 17179869391 +ulysses young 30064771844 +victor allen 8589934793 +victor allen 12884902264 +victor brown 4294967455 +victor brown 17179869657 +victor brown 21474837426 +victor brown 30064771922 +victor davidson 17179869715 +victor davidson 17179869872 +victor davidson 25769804287 +victor ellison 17179869611 +victor ellison 17179869709 +victor hernandez 8589934847 +victor hernandez 12884902463 +victor hernandez 17179869647 +victor hernandez 17179869720 +victor hernandez 25769804310 +victor johnson 17179869652 +victor johnson 21474837148 +victor johnson 25769804771 +victor king 8589934917 +victor king 25769804714 +victor laertes 12884902188 +victor laertes 21474837186 +victor miller 21474837170 +victor nixon 8589934778 +victor nixon 12884902261 +victor ovid 12884902350 +victor polk 17179869376 +victor quirinius 21474837074 +victor quirinius 21474837279 +victor robinson 21474836948 +victor robinson 21474837097 +victor steinbeck 12884902162 +victor steinbeck 17179869721 +victor steinbeck 21474836916 +victor thompson 25769804395 +victor van buren 21474837010 +victor van buren 25769804601 +victor white 8589934816 +victor white 30064771798 +victor xylophone 17179869560 +victor xylophone 25769804719 +victor xylophone 25769804760 +victor xylophone 34359739093 +victor xylophone 34359739095 +victor young 21474837052 +victor zipper 12884902345 +wendy allen 21474837127 +wendy allen 25769804525 +wendy allen 25769804732 +wendy brown 12884902342 +wendy brown 21474836889 +wendy ellison 12884902392 +wendy ellison 21474836763 +wendy falkner 8589934926 +wendy falkner 17179869470 +wendy falkner 25769804816 +wendy garcia 17179869439 +wendy garcia 17179869732 +wendy garcia 30064771654 +wendy garcia 30064771704 +wendy hernandez 17179869752 +wendy ichabod 17179869547 +wendy king 17179869612 +wendy king 21474837301 +wendy king 30064772042 +wendy laertes 8589934872 +wendy laertes 12884902469 +wendy laertes 21474837084 +wendy miller 17179869661 +wendy miller 17179869682 +wendy nixon 12884902521 +wendy nixon 21474836846 +wendy ovid 21474837025 +wendy ovid 38654706512 +wendy polk 8589934960 +wendy polk 21474837144 +wendy quirinius 12884902263 +wendy quirinius 17179869652 +wendy robinson 21474837104 +wendy robinson 25769804321 +wendy robinson 25769804728 +wendy steinbeck 12884902299 +wendy thompson 17179869494 +wendy thompson 21474837072 +wendy underhill 17179869898 +wendy underhill 21474837064 +wendy underhill 25769804845 +wendy van buren 25769804447 +wendy van buren 25769804679 +wendy white 17179869866 +wendy xylophone 17179869596 +wendy xylophone 25769804554 +wendy young 4294967313 +wendy young 25769804562 +xavier allen 12884902364 +xavier allen 17179869960 +xavier allen 21474836864 +xavier brown 8589934824 +xavier brown 17179869646 +xavier brown 25769804653 +xavier carson 17179869770 +xavier carson 21474837445 +xavier davidson 30064772118 +xavier davidson 34359739403 +xavier davidson 38654706539 +xavier ellison 34359739490 +xavier ellison 34359739559 +xavier garcia 21474837142 +xavier hernandez 21474837012 +xavier hernandez 25769804421 +xavier hernandez 38654707021 +xavier ichabod 12884902315 +xavier ichabod 17179869567 +xavier johnson 8589934922 +xavier johnson 38654707066 +xavier king 12884902272 +xavier king 21474836962 +xavier laertes 17179869795 +xavier ovid 17179869597 +xavier polk 12884902254 +xavier polk 17179869581 +xavier polk 17179869743 +xavier polk 34359739344 +xavier quirinius 12884902240 +xavier quirinius 21474836996 +xavier quirinius 25769804437 +xavier quirinius 25769804456 +xavier thompson 17179869822 +xavier underhill 8589934813 +xavier white 12884902262 +xavier white 12884902366 +xavier xylophone 17179869722 +xavier zipper 12884902377 +yuri allen 8589935035 +yuri allen 12884902279 +yuri brown 8589934912 +yuri brown 12884902319 +yuri carson 21474837146 +yuri carson 25769804245 +yuri ellison 25769804504 +yuri ellison 25769804568 +yuri falkner 25769804699 +yuri falkner 42949674720 +yuri garcia 4294967362 +yuri hernandez 21474837117 +yuri johnson 21474837002 +yuri johnson 21474837165 +yuri johnson 25769804545 +yuri king 30064772090 +yuri laertes 30064772076 +yuri laertes 34359739328 +yuri nixon 12884902232 +yuri nixon 12884902265 +yuri polk 12884902362 +yuri polk 21474837245 +yuri polk 25769804539 +yuri quirinius 12884902198 +yuri quirinius 17179869606 +yuri quirinius 30064771819 +yuri steinbeck 4294967535 +yuri steinbeck 8589934657 +yuri thompson 12884902467 +yuri underhill 17179869566 +yuri underhill 17179869715 +yuri white 34359739045 +yuri xylophone 12884902412 +zach allen 17179869908 +zach brown 21474836879 +zach brown 21474836891 +zach brown 21474837040 +zach brown 21474837073 +zach brown 30064771852 +zach carson 21474837185 +zach ellison 8589934898 +zach falkner 17179869807 +zach falkner 25769804634 +zach garcia 17179869536 +zach garcia 21474837142 +zach garcia 30064772246 +zach garcia 34359739192 +zach ichabod 17179869613 +zach ichabod 17179869838 +zach king 17179869700 +zach king 21474837427 +zach king 34359739578 +zach miller 4294967391 +zach miller 12884902310 +zach miller 17179869709 +zach ovid 17179869731 +zach ovid 21474837032 +zach ovid 21474837127 +zach ovid 30064771625 +zach quirinius 34359739151 +zach robinson 21474836938 +zach steinbeck 17179869667 +zach steinbeck 25769804623 +zach thompson 12884902354 +zach thompson 17179869659 +zach underhill 12884902149 +zach white 25769804490 +zach xylophone 12884902198 +zach xylophone 21474837163 +zach young 17179869687 +zach zipper 17179869708 +zach zipper 17179869834 +zach zipper 21474837369 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-1-6378faf36ffd3f61e61cee6c0cb70e6 b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-1-6378faf36ffd3f61e61cee6c0cb70e6 new file mode 100644 index 000000000000..1436509e4ec1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-1-6378faf36ffd3f61e61cee6c0cb70e6 @@ -0,0 +1,1049 @@ + 9.220000267028809 + 43.72999954223633 + 89.52999877929688 +alice allen 2.7899999618530273 +alice allen 21.450000762939453 +alice allen 73.62999725341797 +alice brown 71.30999755859375 +alice carson 39.029998779296875 +alice davidson 70.3499984741211 +alice falkner 90.25 +alice garcia 48.45000076293945 +alice hernandez 88.16999816894531 +alice hernandez 90.55999755859375 +alice johnson 4.46999979019165 +alice king 19.139999389648438 +alice king 23.170000076293945 +alice king 52.22999954223633 +alice laertes 68.94999694824219 +alice laertes 69.52999877929688 +alice miller 68.95999908447266 +alice nixon 40.0 +alice nixon 48.150001525878906 +alice nixon 79.83000183105469 +alice ovid 9.039999961853027 +alice polk 62.900001525878906 +alice quirinius 37.13999938964844 +alice quirinius 62.29999923706055 +alice robinson 5.079999923706055 +alice robinson 56.099998474121094 +alice steinbeck 38.619998931884766 +alice steinbeck 55.5099983215332 +alice steinbeck 92.37000274658203 +alice underhill 98.18000030517578 +alice van buren 38.939998626708984 +alice xylophone 33.58000183105469 +alice xylophone 43.15999984741211 +alice xylophone 78.20999908447266 +alice zipper 26.43000030517578 +alice zipper 42.47999954223633 +alice zipper 89.93000030517578 +bob brown 8.069999694824219 +bob brown 70.93000030517578 +bob brown 93.08999633789062 +bob carson 50.09000015258789 +bob davidson 1.2899999618530273 +bob davidson 53.93000030517578 +bob davidson 74.72000122070312 +bob ellison 41.34000015258789 +bob ellison 65.0199966430664 +bob ellison 75.02999877929688 +bob ellison 80.30000305175781 +bob falkner 16.989999771118164 +bob garcia 4.460000038146973 +bob garcia 5.400000095367432 +bob garcia 45.59000015258789 +bob garcia 80.30000305175781 +bob garcia 87.56999969482422 +bob hernandez 22.68000030517578 +bob ichabod 82.55999755859375 +bob king 8.789999961853027 +bob king 12.539999961853027 +bob king 39.0099983215332 +bob laertes 0.7900000214576721 +bob laertes 10.670000076293945 +bob miller 61.91999816894531 +bob ovid 46.86000061035156 +bob ovid 62.849998474121094 +bob ovid 88.77999877929688 +bob ovid 97.08999633789062 +bob polk 7.980000019073486 +bob quirinius 46.099998474121094 +bob steinbeck 9.699999809265137 +bob van buren 33.66999816894531 +bob white 45.34000015258789 +bob white 45.349998474121094 +bob xylophone 19.690000534057617 +bob xylophone 107.93000221252441 +bob young 35.16999816894531 +bob zipper 1.25 +bob zipper 3.819999933242798 +bob zipper 34.349998474121094 +calvin allen 63.119998931884766 +calvin brown 28.110000610351562 +calvin brown 85.9000015258789 +calvin brown 90.19999694824219 +calvin carson 80.2300033569336 +calvin davidson 31.770000457763672 +calvin davidson 85.51000213623047 +calvin ellison 26.489999771118164 +calvin falkner 2.9700000286102295 +calvin falkner 56.040000915527344 +calvin falkner 56.33000183105469 +calvin falkner 80.5999984741211 +calvin falkner 93.61000061035156 +calvin falkner 94.30999755859375 +calvin garcia 41.849998474121094 +calvin hernandez 33.869998931884766 +calvin johnson 66.61000061035156 +calvin laertes 23.1299991607666 +calvin laertes 50.310001373291016 +calvin nixon 9.8100004196167 +calvin nixon 41.20000076293945 +calvin nixon 69.73999786376953 +calvin ovid 69.95999908447266 +calvin ovid 71.26000213623047 +calvin ovid 79.12000274658203 +calvin ovid 84.72000122070312 +calvin polk 65.72000122070312 +calvin quirinius 29.540000915527344 +calvin quirinius 53.02000045776367 +calvin robinson 40.439998626708984 +calvin steinbeck 15.220000267028809 +calvin steinbeck 22.850000381469727 +calvin steinbeck 93.30000305175781 +calvin thompson 8.90999984741211 +calvin thompson 93.7300033569336 +calvin underhill 59.70000076293945 +calvin van buren 34.209999084472656 +calvin van buren 64.0 +calvin white 50.279998779296875 +calvin white 90.69000244140625 +calvin xylophone 21.700000762939453 +calvin xylophone 25.420000076293945 +calvin xylophone 56.810001373291016 +calvin young 24.489999771118164 +calvin young 39.810001373291016 +calvin zipper 9.1899995803833 +calvin zipper 95.37999725341797 +david allen 51.25 +david allen 64.87000274658203 +david brown 3.2100000381469727 +david brown 93.63999938964844 +david davidson 1.0800000429153442 +david davidson 62.720001220703125 +david davidson 74.1500015258789 +david davidson 95.80999755859375 +david ellison 47.689998626708984 +david ellison 85.2300033569336 +david ellison 94.1500015258789 +david hernandez 99.91000366210938 +david ichabod 5.28000020980835 +david ichabod 82.55000305175781 +david laertes 76.70999908447266 +david nixon 50.31999969482422 +david ovid 25.110000610351562 +david ovid 61.70000076293945 +david quirinius 20.639999389648438 +david quirinius 29.239999771118164 +david quirinius 79.97000122070312 +david robinson 25.280000686645508 +david robinson 161.1199951171875 +david thompson 80.89999771118164 +david underhill 8.319999694824219 +david underhill 88.7699966430664 +david underhill 97.55999755859375 +david van buren 83.56999969482422 +david van buren 115.43999862670898 +david white 45.189998626708984 +david xylophone 8.069999694824219 +david xylophone 54.34000015258789 +david xylophone 72.9800033569336 +david young 10.25 +david young 35.650001525878906 +ethan allen 32.75 +ethan brown 7.110000133514404 +ethan brown 10.09000015258789 +ethan brown 15.630000114440918 +ethan brown 61.86000061035156 +ethan brown 73.18000030517578 +ethan brown 82.30000305175781 +ethan carson 76.33000183105469 +ethan ellison 0.2800000011920929 +ethan ellison 81.47000122070312 +ethan falkner 50.02000045776367 +ethan falkner 59.43000030517578 +ethan garcia 43.189998626708984 +ethan hernandez 49.779998779296875 +ethan johnson 90.05000305175781 +ethan king 4.349999904632568 +ethan laertes 15.449999809265137 +ethan laertes 54.75 +ethan laertes 59.209999084472656 +ethan laertes 70.38999938964844 +ethan laertes 80.70999908447266 +ethan laertes 95.06999969482422 +ethan laertes 96.29000091552734 +ethan miller 25.3700008392334 +ethan nixon 37.779998779296875 +ethan ovid 57.290000915527344 +ethan polk 2.3499999046325684 +ethan polk 21.31999969482422 +ethan polk 23.440000534057617 +ethan polk 122.71999740600586 +ethan quirinius 3.859999895095825 +ethan quirinius 51.84000015258789 +ethan quirinius 97.23999786376953 +ethan robinson 67.94000244140625 +ethan robinson 78.62000274658203 +ethan underhill 55.630001068115234 +ethan van buren 36.70000076293945 +ethan white 60.849998474121094 +ethan white 63.41999816894531 +ethan xylophone 57.11000061035156 +ethan zipper 2.9200000762939453 +ethan zipper 97.51000213623047 +fred davidson 18.860000610351562 +fred davidson 37.2400016784668 +fred davidson 78.30999755859375 +fred ellison 31.179998874664307 +fred ellison 48.59000015258789 +fred ellison 96.77999877929688 +fred falkner 10.289999961853027 +fred falkner 72.04000091552734 +fred falkner 85.0 +fred hernandez 55.9900016784668 +fred ichabod 47.359999656677246 +fred ichabod 81.31999969482422 +fred johnson 96.08999633789062 +fred king 48.369998931884766 +fred king 72.13999843597412 +fred laertes 57.63999938964844 +fred miller 46.970001220703125 +fred nixon 28.690000534057617 +fred nixon 38.04999923706055 +fred nixon 70.5199966430664 +fred nixon 93.02999877929688 +fred polk 23.959999084472656 +fred polk 39.18000030517578 +fred polk 47.31999969482422 +fred polk 90.12000274658203 +fred quirinius 15.300000190734863 +fred quirinius 29.399999618530273 +fred robinson 89.02999877929688 +fred steinbeck 32.22999954223633 +fred steinbeck 41.310001373291016 +fred steinbeck 91.05000305175781 +fred underhill 90.7699966430664 +fred van buren 1.0199999809265137 +fred van buren 21.940000534057617 +fred van buren 52.869998931884766 +fred van buren 83.58000183105469 +fred white 37.79999923706055 +fred young 46.79999923706055 +fred young 97.70999908447266 +fred zipper 29.020000457763672 +gabriella allen 46.27000045776367 +gabriella allen 64.22000122070312 +gabriella brown 15.260000228881836 +gabriella brown 84.83000183105469 +gabriella carson 42.7599983215332 +gabriella davidson 6.550000190734863 +gabriella ellison 48.08000183105469 +gabriella ellison 71.54000091552734 +gabriella falkner 10.170000076293945 +gabriella falkner 51.720001220703125 +gabriella falkner 87.61000061035156 +gabriella garcia 43.0099983215332 +gabriella hernandez 76.91999816894531 +gabriella hernandez 92.9800033569336 +gabriella ichabod 10.729999542236328 +gabriella ichabod 26.639999389648438 +gabriella ichabod 66.36000061035156 +gabriella ichabod 71.12999725341797 +gabriella ichabod 90.3499984741211 +gabriella king 20.670000076293945 +gabriella king 80.45999908447266 +gabriella laertes 65.37999725341797 +gabriella miller 50.83000183105469 +gabriella ovid 77.7400016784668 +gabriella ovid 92.4000015258789 +gabriella polk 35.68000030517578 +gabriella polk 88.05000305175781 +gabriella steinbeck 46.45000076293945 +gabriella steinbeck 78.63999938964844 +gabriella thompson 73.31999969482422 +gabriella thompson 88.36000061035156 +gabriella thompson 94.25 +gabriella van buren 69.80000305175781 +gabriella van buren 70.05999755859375 +gabriella white 55.18000030517578 +gabriella young 9.25 +gabriella young 59.709999084472656 +gabriella zipper 36.2599983215332 +gabriella zipper 91.62999725341797 +holly allen 44.56999969482422 +holly brown 77.80999755859375 +holly brown 78.7300033569336 +holly falkner 80.73999786376953 +holly hernandez 20.81999969482422 +holly hernandez 21.190000534057617 +holly hernandez 24.790000915527344 +holly hernandez 30.25 +holly ichabod 83.2699966430664 +holly ichabod 84.69000244140625 +holly ichabod 90.51000213623047 +holly johnson 36.95000076293945 +holly johnson 64.36000061035156 +holly johnson 65.62000274658203 +holly king 42.310001373291016 +holly king 55.38999938964844 +holly laertes 52.5 +holly miller 50.400001525878906 +holly nixon 53.779998779296875 +holly nixon 88.0199966430664 +holly polk 3.619999885559082 +holly polk 98.30999755859375 +holly robinson 69.31999969482422 +holly thompson 0.07999999821186066 +holly thompson 86.69000244140625 +holly thompson 145.93999481201172 +holly underhill 42.54999923706055 +holly underhill 50.40999984741211 +holly underhill 79.95999908447266 +holly underhill 96.68000030517578 +holly van buren 68.80999755859375 +holly white 7.960000038146973 +holly white 32.91999816894531 +holly xylophone 93.11000061035156 +holly young 60.220001220703125 +holly young 66.16999816894531 +holly zipper 99.12999725341797 +holly zipper 99.29000091552734 +irene allen 38.849998474121094 +irene brown 4.789999961853027 +irene brown 53.939998626708984 +irene brown 87.66999816894531 +irene carson 94.54000091552734 +irene ellison 45.2400016784668 +irene ellison 50.08000183105469 +irene falkner 22.079999923706055 +irene falkner 99.91999816894531 +irene garcia 15.369999885559082 +irene garcia 58.43000030517578 +irene garcia 86.93000030517578 +irene ichabod 41.439998626708984 +irene ichabod 99.62000274658203 +irene johnson 5.880000114440918 +irene laertes 9.569999694824219 +irene laertes 42.66999816894531 +irene laertes 44.43000030517578 +irene miller 65.44000244140625 +irene nixon 15.100000381469727 +irene nixon 29.780000686645508 +irene nixon 42.560001373291016 +irene ovid 5.239999771118164 +irene ovid 35.130001068115234 +irene ovid 79.75 +irene polk 0.9800000190734863 +irene polk 24.020000457763672 +irene polk 42.2400016784668 +irene polk 47.08000183105469 +irene polk 95.83999633789062 +irene quirinius 12.899999618530273 +irene quirinius 58.86000061035156 +irene quirinius 70.0 +irene robinson 94.2699966430664 +irene steinbeck 94.33000183105469 +irene thompson 78.30000305175781 +irene underhill 28.309999465942383 +irene underhill 57.349998474121094 +irene van buren 54.439998626708984 +irene van buren 54.9900016784668 +irene xylophone 74.19000244140625 +jessica brown 51.290000915527344 +jessica carson 25.549999237060547 +jessica carson 31.860000610351562 +jessica carson 62.20000076293945 +jessica davidson 33.54999923706055 +jessica davidson 49.77000045776367 +jessica davidson 95.33999633789062 +jessica davidson 99.20999908447266 +jessica ellison 11.180000305175781 +jessica ellison 22.780000686645508 +jessica falkner 99.6500015258789 +jessica garcia 5.539999961853027 +jessica garcia 87.92999941110611 +jessica ichabod 59.15999984741211 +jessica johnson 9.5600004196167 +jessica johnson 40.79999923706055 +jessica miller 151.0199966430664 +jessica nixon 77.0999984741211 +jessica nixon 90.06999969482422 +jessica ovid 71.68000030517578 +jessica ovid 119.9000015258789 +jessica polk 49.68000030517578 +jessica quirinius 22.940000534057617 +jessica quirinius 32.470001220703125 +jessica quirinius 35.619998931884766 +jessica quirinius 46.869998931884766 +jessica robinson 112.36000442504883 +jessica thompson 38.33000183105469 +jessica thompson 89.55000305175781 +jessica underhill 26.079999923706055 +jessica underhill 45.41999816894531 +jessica underhill 46.209999084472656 +jessica van buren 9.739999771118164 +jessica white 11.550000190734863 +jessica white 36.58000183105469 +jessica white 73.93000030517578 +jessica white 74.30000305175781 +jessica white 96.62000274658203 +jessica xylophone 53.060001373291016 +jessica young 11.1899995803833 +jessica young 43.369998931884766 +jessica zipper 6.630000114440918 +jessica zipper 12.020000457763672 +jessica zipper 92.43999862670898 +katie allen 64.66999816894531 +katie brown 27.719999313354492 +katie davidson 170.84000396728516 +katie ellison 3.609999895095825 +katie ellison 80.97000122070312 +katie falkner 18.5 +katie garcia 24.729999542236328 +katie garcia 84.4000015258789 +katie hernandez 38.61999988555908 +katie ichabod 30.709999084472656 +katie ichabod 39.97999954223633 +katie ichabod 43.16999816894531 +katie king 39.34000015258789 +katie king 39.83000183105469 +katie king 97.80999755859375 +katie miller 31.399999618530273 +katie miller 74.77999877929688 +katie nixon 121.3700008392334 +katie ovid 50.65999984741211 +katie polk 11.680000305175781 +katie polk 40.2400016784668 +katie robinson 13.890000343322754 +katie van buren 17.739999771118164 +katie van buren 52.529998779296875 +katie white 1.309999942779541 +katie white 34.72999954223633 +katie xylophone 14.130000114440918 +katie young 31.010000228881836 +katie young 72.51000213623047 +katie young 97.56999969482422 +katie zipper 18.93000030517578 +katie zipper 58.75 +luke allen 15.180000305175781 +luke allen 50.959999084472656 +luke allen 66.61000061035156 +luke allen 89.55000305175781 +luke allen 99.38999938964844 +luke brown 51.790000915527344 +luke davidson 7.050000190734863 +luke davidson 28.950000762939453 +luke ellison 1.8700000047683716 +luke ellison 16.25 +luke ellison 87.83000183105469 +luke falkner 32.25 +luke falkner 39.60000038146973 +luke garcia 13.350000381469727 +luke garcia 30.3700008392334 +luke ichabod 8.449999809265137 +luke ichabod 97.87000274658203 +luke johnson 11.149999618530273 +luke johnson 14.4399995803833 +luke johnson 31.670000076293945 +luke laertes 0.5199999809265137 +luke laertes 4.800000190734863 +luke laertes 11.819999694824219 +luke laertes 16.690000534057617 +luke laertes 45.9900016784668 +luke miller 97.6500015258789 +luke ovid 38.04999923706055 +luke ovid 159.68000030517578 +luke polk 46.880001068115234 +luke polk 95.27999877929688 +luke quirinius 40.41999816894531 +luke robinson 55.099998474121094 +luke robinson 65.69999694824219 +luke thompson 94.37999725341797 +luke underhill 59.68000030517578 +luke underhill 95.52999877929688 +luke underhill 96.94000244140625 +luke van buren 148.62999725341797 +luke white 67.12000274658203 +luke xylophone 48.279998779296875 +luke zipper 24.829999923706055 +mike allen 48.53999900817871 +mike brown 48.22999954223633 +mike carson 20.06999969482422 +mike carson 47.56999969482422 +mike carson 81.66000366210938 +mike davidson 27.309999465942383 +mike davidson 54.83000183105469 +mike ellison 28.559999465942383 +mike ellison 37.099998474121094 +mike ellison 62.13999938964844 +mike ellison 79.37999725341797 +mike ellison 85.73999786376953 +mike falkner 16.479999542236328 +mike garcia 70.8499984741211 +mike garcia 75.83000183105469 +mike garcia 79.20999908447266 +mike hernandez 37.900001525878906 +mike hernandez 59.45000076293945 +mike ichabod 64.7699966430664 +mike king 38.790000915527344 +mike king 62.7400016784668 +mike king 78.26000213623047 +mike king 84.2300033569336 +mike king 85.0999984741211 +mike king 94.68000030517578 +mike miller 3.9600000381469727 +mike nixon 60.119998931884766 +mike nixon 92.95999908447266 +mike polk 12.449999809265137 +mike polk 27.06999969482422 +mike polk 99.68000030517578 +mike quirinius 89.37999725341797 +mike steinbeck 5.849999904632568 +mike steinbeck 85.13999938964844 +mike steinbeck 93.07000207901001 +mike steinbeck 97.45999908447266 +mike van buren 80.83999633789062 +mike van buren 114.56999969482422 +mike white 9.569999694824219 +mike white 28.889999389648438 +mike white 32.0099983215332 +mike white 91.87999725341797 +mike young 7.820000171661377 +mike young 74.58999633789062 +mike young 83.54000091552734 +mike zipper 26.729999542236328 +mike zipper 83.91999816894531 +mike zipper 97.38999938964844 +nick allen 21.830000400543213 +nick allen 35.08000183105469 +nick brown 42.5099983215332 +nick davidson 49.439998626708984 +nick ellison 9.680000305175781 +nick ellison 89.01000213623047 +nick falkner 10.130000114440918 +nick falkner 88.47000122070312 +nick garcia 13.9399995803833 +nick garcia 26.389999389648438 +nick garcia 46.43000030517578 +nick ichabod 23.450000762939453 +nick ichabod 47.59000015258789 +nick ichabod 74.41999816894531 +nick johnson 3.9700000286102295 +nick johnson 94.08000183105469 +nick laertes 96.25 +nick miller 82.97000122070312 +nick nixon 96.37999725341797 +nick ovid 87.98999786376953 +nick polk 59.27000141143799 +nick quirinius 67.44999694824219 +nick quirinius 81.16999816894531 +nick robinson 57.66999816894531 +nick robinson 60.709999084472656 +nick steinbeck 97.83000183105469 +nick thompson 11.90999984741211 +nick underhill 20.809999465942383 +nick van buren 51.290000915527344 +nick xylophone 103.45999908447266 +nick young 0.27000001072883606 +nick young 24.799999237060547 +nick zipper 56.619998931884766 +nick zipper 119.0199966430664 +oscar allen 18.6299991607666 +oscar brown 13.100000381469727 +oscar carson 6.869999885559082 +oscar carson 55.20000076293945 +oscar carson 78.9800033569336 +oscar carson 87.4800033569336 +oscar carson 98.51000213623047 +oscar davidson 64.45999908447266 +oscar ellison 57.88999938964844 +oscar ellison 107.7100019454956 +oscar falkner 98.4800033569336 +oscar garcia 67.4800033569336 +oscar hernandez 95.4800033569336 +oscar hernandez 125.92999649047852 +oscar ichabod 3.3299999237060547 +oscar ichabod 33.52000045776367 +oscar ichabod 71.80000305175781 +oscar ichabod 76.69000244140625 +oscar johnson 16.09000015258789 +oscar johnson 139.69000244140625 +oscar king 19.059999465942383 +oscar king 25.8799991607666 +oscar king 59.5 +oscar laertes 5.510000228881836 +oscar laertes 8.420000076293945 +oscar laertes 9.260000228881836 +oscar laertes 27.1200008392334 +oscar nixon 41.619998931884766 +oscar ovid 37.13999938964844 +oscar ovid 82.23999786376953 +oscar ovid 91.52999877929688 +oscar polk 30.610000610351562 +oscar polk 63.900001525878906 +oscar quirinius 41.45000076293945 +oscar quirinius 65.43000030517578 +oscar quirinius 113.35000228881836 +oscar quirinius 139.10000610351562 +oscar robinson 11.34000015258789 +oscar robinson 42.849998474121094 +oscar robinson 74.52999877929688 +oscar robinson 131.31999969482422 +oscar steinbeck 29.59000015258789 +oscar thompson 31.90999984741211 +oscar thompson 41.34000015258789 +oscar thompson 60.529998779296875 +oscar thompson 70.88999938964844 +oscar underhill 87.4000015258789 +oscar van buren 2.180000066757202 +oscar van buren 61.880001068115234 +oscar van buren 91.77999877929688 +oscar white 19.0 +oscar white 28.450000762939453 +oscar white 51.849998474121094 +oscar white 59.83000183105469 +oscar xylophone 21.799999237060547 +oscar xylophone 57.119998931884766 +oscar xylophone 57.22999954223633 +oscar zipper 13.989999771118164 +oscar zipper 32.88999938964844 +oscar zipper 39.81999969482422 +priscilla brown 70.23999786376953 +priscilla brown 80.5199966430664 +priscilla brown 104.63999938964844 +priscilla carson 7.960000038146973 +priscilla carson 79.80999946594238 +priscilla carson 85.43000316619873 +priscilla ichabod 80.04000091552734 +priscilla ichabod 92.61000061035156 +priscilla johnson 61.939998626708984 +priscilla johnson 67.9800033569336 +priscilla johnson 68.32999992370605 +priscilla johnson 91.4800033569336 +priscilla johnson 92.48000144958496 +priscilla king 43.91999816894531 +priscilla nixon 95.80999755859375 +priscilla nixon 107.69000244140625 +priscilla ovid 52.72999954223633 +priscilla ovid 125.73999643325806 +priscilla polk 15.149999618530273 +priscilla quirinius 9.710000038146973 +priscilla thompson 9.800000190734863 +priscilla underhill 35.720001220703125 +priscilla underhill 68.22000122070312 +priscilla van buren 68.88999938964844 +priscilla van buren 91.61000061035156 +priscilla van buren 170.5500030517578 +priscilla white 78.27999877929688 +priscilla xylophone 0.15000000596046448 +priscilla xylophone 21.489999771118164 +priscilla xylophone 59.61000061035156 +priscilla young 0.4300000071525574 +priscilla young 4.320000171661377 +priscilla zipper 18.6299991607666 +priscilla zipper 25.670000076293945 +quinn allen 54.72999954223633 +quinn allen 83.33000183105469 +quinn brown 24.280000686645508 +quinn brown 52.439998626708984 +quinn brown 80.58000183105469 +quinn davidson 61.57999849319458 +quinn davidson 67.18000030517578 +quinn davidson 83.4000015258789 +quinn davidson 95.11000061035156 +quinn ellison 19.280000686645508 +quinn ellison 30.649999618530273 +quinn garcia 40.97999954223633 +quinn garcia 59.9900016784668 +quinn garcia 74.0199966430664 +quinn garcia 172.8499984741211 +quinn ichabod 36.790000915527344 +quinn king 74.62000274658203 +quinn king 86.2300033569336 +quinn laertes 4.710000038146973 +quinn laertes 41.290000915527344 +quinn laertes 76.5199966430664 +quinn nixon 86.64000129699707 +quinn ovid 52.500000953674316 +quinn quirinius 32.18000030517578 +quinn robinson 38.64999866485596 +quinn steinbeck 8.449999809265137 +quinn steinbeck 66.51000213623047 +quinn thompson 74.9399995803833 +quinn thompson 76.27999877929688 +quinn underhill 17.15999984741211 +quinn underhill 79.4800033569336 +quinn underhill 140.92000198364258 +quinn van buren 82.5199966430664 +quinn young 45.060001373291016 +quinn zipper 22.25 +quinn zipper 58.0 +rachel allen 15.8100004196167 +rachel allen 74.44999694824219 +rachel brown 2.9600000381469727 +rachel brown 30.809999465942383 +rachel brown 33.36000061035156 +rachel brown 34.40999984741211 +rachel brown 52.16999816894531 +rachel carson 37.599998474121094 +rachel carson 98.95999908447266 +rachel davidson 4.920000076293945 +rachel ellison 10.600000381469727 +rachel falkner 46.150001525878906 +rachel falkner 80.91999816894531 +rachel falkner 88.80000305175781 +rachel falkner 99.23999786376953 +rachel johnson 62.22999954223633 +rachel king 36.220001220703125 +rachel king 59.45000076293945 +rachel laertes 44.220001220703125 +rachel laertes 45.45000076293945 +rachel ovid 0.6000000238418579 +rachel ovid 1.0800000429153442 +rachel polk 89.27999877929688 +rachel quirinius 12.4399995803833 +rachel robinson 4.570000171661377 +rachel robinson 30.360000610351562 +rachel robinson 64.94999694824219 +rachel thompson 0.5600000023841858 +rachel thompson 4.170000076293945 +rachel thompson 58.52000045776367 +rachel underhill 48.45000076293945 +rachel white 43.709999084472656 +rachel white 94.72000122070312 +rachel young 43.130001068115234 +rachel zipper 7.059999942779541 +rachel zipper 72.18000030517578 +sarah carson 1.909999966621399 +sarah carson 14.210000038146973 +sarah carson 78.88999938964844 +sarah ellison 16.989999771118164 +sarah falkner 90.27999877929688 +sarah falkner 99.36000061035156 +sarah garcia 41.290000915527344 +sarah garcia 58.010000228881836 +sarah garcia 153.8800048828125 +sarah ichabod 81.31999969482422 +sarah ichabod 97.26000213623047 +sarah johnson 16.239999771118164 +sarah johnson 45.099998474121094 +sarah johnson 73.87999725341797 +sarah johnson 77.66000366210938 +sarah king 41.869998931884766 +sarah king 48.25 +sarah miller 41.709999084472656 +sarah ovid 60.02000045776367 +sarah robinson 33.83000183105469 +sarah robinson 66.88999938964844 +sarah steinbeck 40.16999816894531 +sarah white 37.849998474121094 +sarah white 89.80999755859375 +sarah xylophone 68.31999969482422 +sarah young 45.560001373291016 +sarah zipper 83.08000183105469 +tom brown 8.609999656677246 +tom brown 12.319999694824219 +tom carson 5.440000057220459 +tom carson 16.079999923706055 +tom carson 18.889999389648438 +tom davidson 170.0 +tom ellison 76.73999786376953 +tom ellison 98.2300033569336 +tom ellison 155.99999618530273 +tom falkner 60.130001068115234 +tom falkner 88.22000122070312 +tom hernandez 41.36000061035156 +tom hernandez 81.63999938964844 +tom ichabod 103.29000282287598 +tom johnson 14.920000076293945 +tom johnson 43.56999969482422 +tom king 15.75 +tom laertes 43.310001373291016 +tom laertes 64.6500015258789 +tom miller 21.229999542236328 +tom miller 68.25 +tom miller 139.04000091552734 +tom nixon 153.83999633789062 +tom ovid 8.670000076293945 +tom polk 38.29999923706055 +tom polk 54.43000030517578 +tom quirinius 10.1899995803833 +tom quirinius 75.31999969482422 +tom robinson 90.69000244140625 +tom robinson 98.72000122070312 +tom robinson 99.1500015258789 +tom robinson 123.5199966430664 +tom steinbeck 26.489999771118164 +tom van buren 3.2799999713897705 +tom van buren 40.779998779296875 +tom van buren 63.5099983215332 +tom white 40.040000915527344 +tom young 22.850000381469727 +tom young 84.30999755859375 +tom zipper 122.78000259399414 +ulysses brown 72.79000091552734 +ulysses carson 77.41999816894531 +ulysses carson 79.54000091552734 +ulysses carson 146.7100067138672 +ulysses carson 220.18000030517578 +ulysses davidson 55.16999816894531 +ulysses ellison 96.7300033569336 +ulysses garcia 89.80000305175781 +ulysses hernandez 35.16999816894531 +ulysses hernandez 54.470001220703125 +ulysses hernandez 68.25 +ulysses ichabod 19.1299991607666 +ulysses ichabod 98.56999969482422 +ulysses johnson 102.5999984741211 +ulysses king 74.19000244140625 +ulysses laertes 1.9199999570846558 +ulysses laertes 24.860000610351562 +ulysses laertes 50.1899995803833 +ulysses miller 2.9600000381469727 +ulysses miller 76.27999877929688 +ulysses nixon 80.95999908447266 +ulysses ovid 29.360000610351562 +ulysses polk 8.710000038146973 +ulysses polk 60.060001373291016 +ulysses polk 65.0199966430664 +ulysses polk 97.10000038146973 +ulysses quirinius 112.56999969482422 +ulysses robinson 104.85999870300293 +ulysses steinbeck 32.40999984741211 +ulysses steinbeck 74.0 +ulysses thompson 198.83000564575195 +ulysses underhill 14.119999885559082 +ulysses underhill 22.360000610351562 +ulysses underhill 35.88999938964844 +ulysses underhill 57.369998931884766 +ulysses underhill 81.58000183105469 +ulysses underhill 88.4800033569336 +ulysses underhill 99.66999816894531 +ulysses van buren 95.52999877929688 +ulysses white 59.54999923706055 +ulysses white 170.0800018310547 +ulysses xylophone 39.689998626708984 +ulysses xylophone 54.099998474121094 +ulysses xylophone 57.3100004196167 +ulysses young 14.930000305175781 +ulysses young 32.52000045776367 +ulysses young 114.55999946594238 +victor allen 44.27000045776367 +victor allen 89.5 +victor brown 59.34000015258789 +victor brown 77.88999938964844 +victor brown 90.37999725341797 +victor brown 91.97000122070312 +victor davidson 60.2599983215332 +victor davidson 66.5999984741211 +victor davidson 98.54999923706055 +victor ellison 17.8700008392334 +victor ellison 68.8499984741211 +victor hernandez 19.030000686645508 +victor hernandez 59.619998931884766 +victor hernandez 69.87999725341797 +victor hernandez 71.3499984741211 +victor hernandez 74.5199966430664 +victor johnson 18.200000762939453 +victor johnson 42.89000141620636 +victor johnson 72.55999755859375 +victor king 47.880001068115234 +victor king 66.66999816894531 +victor laertes 62.91999816894531 +victor laertes 67.58999633789062 +victor miller 22.1200008392334 +victor nixon 34.029998779296875 +victor nixon 68.5 +victor ovid 125.84000015258789 +victor polk 17.210000038146973 +victor quirinius 50.70000076293945 +victor quirinius 134.4000015258789 +victor robinson 51.560001373291016 +victor robinson 58.66999816894531 +victor steinbeck 12.460000038146973 +victor steinbeck 46.09000015258789 +victor steinbeck 52.720001220703125 +victor thompson 58.65999984741211 +victor van buren 34.970001220703125 +victor van buren 41.68000030517578 +victor white 5.670000076293945 +victor white 135.02999687194824 +victor xylophone 10.09000015258789 +victor xylophone 11.220000267028809 +victor xylophone 28.5 +victor xylophone 62.38999938964844 +victor xylophone 76.0999984741211 +victor young 88.55000305175781 +victor zipper 26.289999961853027 +wendy allen 56.06999969482422 +wendy allen 93.96999740600586 +wendy allen 220.7900003194809 +wendy brown 27.8700008392334 +wendy brown 50.2599983215332 +wendy ellison 94.66000366210938 +wendy ellison 124.93999481201172 +wendy falkner 22.010000228881836 +wendy falkner 97.68000030517578 +wendy falkner 141.36000061035156 +wendy garcia 30.6200008392334 +wendy garcia 57.25 +wendy garcia 82.1500015258789 +wendy garcia 133.3400001525879 +wendy hernandez 48.11000061035156 +wendy ichabod 13.149999618530273 +wendy king 45.189998626708984 +wendy king 63.33000183105469 +wendy king 183.75 +wendy laertes 46.619998931884766 +wendy laertes 70.37999725341797 +wendy laertes 79.98999786376953 +wendy miller 1.2699999809265137 +wendy miller 12.420000076293945 +wendy nixon 45.91999816894531 +wendy nixon 60.2599983215332 +wendy ovid 86.62999725341797 +wendy ovid 95.33000183105469 +wendy polk 32.369998931884766 +wendy polk 42.04000073671341 +wendy quirinius 12.15999984741211 +wendy quirinius 14.300000190734863 +wendy robinson 26.469999313354492 +wendy robinson 71.06999969482422 +wendy robinson 117.02000045776367 +wendy steinbeck 120.67000007629395 +wendy thompson 67.34000015258789 +wendy thompson 85.76000213623047 +wendy underhill 68.04000091552734 +wendy underhill 79.19000244140625 +wendy underhill 89.77999877929688 +wendy van buren 57.459999084472656 +wendy van buren 92.81999969482422 +wendy white 73.68000030517578 +wendy xylophone 76.69999694824219 +wendy xylophone 90.60000038146973 +wendy young 8.449999809265137 +wendy young 33.7599983215332 +xavier allen 45.68000030517578 +xavier allen 83.93000030517578 +xavier allen 98.22000122070312 +xavier brown 7.789999961853027 +xavier brown 90.7300033569336 +xavier brown 96.2300033569336 +xavier carson 20.790000915527344 +xavier carson 94.68000030517578 +xavier davidson 15.920000076293945 +xavier davidson 82.41000366210938 +xavier davidson 106.5199966430664 +xavier ellison 12.850000381469727 +xavier ellison 77.97000122070312 +xavier garcia 70.04000091552734 +xavier hernandez 6.670000076293945 +xavier hernandez 38.56999969482422 +xavier hernandez 67.26000213623047 +xavier ichabod 4.71999979019165 +xavier ichabod 71.19000244140625 +xavier johnson 27.299999237060547 +xavier johnson 203.65999794006348 +xavier king 8.569999694824219 +xavier king 87.22000122070312 +xavier laertes 15.899999618530273 +xavier ovid 112.91000366210938 +xavier polk 13.869999885559082 +xavier polk 61.209999084472656 +xavier polk 72.62000274658203 +xavier polk 76.93000030517578 +xavier quirinius 62.52000045776367 +xavier quirinius 83.01000022888184 +xavier quirinius 89.55000305175781 +xavier quirinius 97.14999961853027 +xavier thompson 9.930000305175781 +xavier underhill 47.27000045776367 +xavier white 59.20000171661377 +xavier white 75.29000091552734 +xavier xylophone 79.41999816894531 +xavier zipper 8.449999809265137 +yuri allen 52.849998474121094 +yuri allen 94.98999977111816 +yuri brown 75.19000244140625 +yuri brown 84.02999877929688 +yuri carson 6.289999961853027 +yuri carson 91.16000366210938 +yuri ellison 1.1200000047683716 +yuri ellison 98.82999801635742 +yuri falkner 39.6299991607666 +yuri falkner 86.0 +yuri garcia 27.65999984741211 +yuri hernandez 2.069999933242798 +yuri johnson 0.12999999523162842 +yuri johnson 39.900001525878906 +yuri johnson 48.220001220703125 +yuri king 69.59000015258789 +yuri laertes 37.59000015258789 +yuri laertes 61.95000076293945 +yuri nixon 2.200000047683716 +yuri nixon 82.81000328063965 +yuri polk 26.760000228881836 +yuri polk 28.790000915527344 +yuri polk 105.11999702453613 +yuri quirinius 10.260000228881836 +yuri quirinius 54.310001373291016 +yuri quirinius 57.93000030517578 +yuri steinbeck 17.790000915527344 +yuri steinbeck 75.87999725341797 +yuri thompson 14.920000076293945 +yuri underhill 23.770000457763672 +yuri underhill 83.87000274658203 +yuri white 34.58000183105469 +yuri xylophone 20.3799991607666 +zach allen 65.43000030517578 +zach brown 48.0099983215332 +zach brown 49.119998931884766 +zach brown 57.08000183105469 +zach brown 67.37999725341797 +zach brown 100.46000289916992 +zach carson 95.86999893188477 +zach ellison 6.840000152587891 +zach falkner 9.130000114440918 +zach falkner 91.41999816894531 +zach garcia 32.20000076293945 +zach garcia 84.37999725341797 +zach garcia 106.86999893188477 +zach garcia 167.62000274658203 +zach ichabod 64.25 +zach ichabod 106.69000244140625 +zach king 46.18000030517578 +zach king 70.51000213623047 +zach king 86.93000030517578 +zach miller 2.5999999046325684 +zach miller 21.280000686645508 +zach miller 53.27000045776367 +zach ovid 0.10000000149011612 +zach ovid 23.06999969482422 +zach ovid 92.55000305175781 +zach ovid 94.33999633789062 +zach quirinius 39.209999084472656 +zach robinson 122.81000137329102 +zach steinbeck 85.48999786376953 +zach steinbeck 90.05000305175781 +zach thompson 71.5 +zach thompson 91.63999938964844 +zach underhill 86.22000122070312 +zach white 70.52999877929688 +zach xylophone 43.84999942779541 +zach xylophone 71.01000213623047 +zach young 71.31999969482422 +zach zipper 52.60000133514404 +zach zipper 85.87000274658203 +zach zipper 94.43000030517578 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-2-5f0eab306ea3c22b11ace9b542a7ee56 b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-2-5f0eab306ea3c22b11ace9b542a7ee56 new file mode 100644 index 000000000000..e55bede9242e --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-2-5f0eab306ea3c22b11ace9b542a7ee56 @@ -0,0 +1,1049 @@ + 257.04000091552734 + 261.16000175476074 + 284.2699966430664 +alice allen 73.62999725341797 +alice allen 195.0099983215332 +alice allen 196.729998588562 +alice brown 245.52000427246094 +alice carson 424.63000106811523 +alice davidson 319.00999450683594 +alice falkner 90.25 +alice garcia 174.36999893188477 +alice hernandez 185.6699981689453 +alice hernandez 380.1699981689453 +alice johnson 367.2900071144104 +alice king 58.78999900817871 +alice king 294.1199951171875 +alice king 371.23999404907227 +alice laertes 68.94999694824219 +alice laertes 258.3999938964844 +alice miller 154.19000244140625 +alice nixon 209.91000366210938 +alice nixon 246.36000442504883 +alice nixon 260.19000244140625 +alice ovid 49.8199987411499 +alice polk 148.63999938964844 +alice quirinius 239.81999588012695 +alice quirinius 301.4100036621094 +alice robinson 140.47999572753906 +alice robinson 266.4900016784668 +alice steinbeck 169.76000213623047 +alice steinbeck 186.70999908447266 +alice steinbeck 446.8099937438965 +alice underhill 98.18000030517578 +alice van buren 112.42000198364258 +alice xylophone 78.20999908447266 +alice xylophone 91.22000122070312 +alice xylophone 413.1199951171875 +alice zipper 89.93000030517578 +alice zipper 279.54000091552734 +alice zipper 293.25000381469727 +bob brown 188.89999389648438 +bob brown 228.80999946594238 +bob brown 247.37999725341797 +bob carson 207.67000198364258 +bob davidson 53.93000030517578 +bob davidson 113.83999919891357 +bob davidson 259.0899963378906 +bob ellison 65.0199966430664 +bob ellison 80.30000305175781 +bob ellison 243.86000061035156 +bob ellison 245.02999877929688 +bob falkner 208.82000160217285 +bob garcia 33.410000801086426 +bob garcia 87.56999969482422 +bob garcia 120.17999649047852 +bob garcia 148.65999841690063 +bob garcia 178.87000274658203 +bob hernandez 337.23999977111816 +bob ichabod 82.55999755859375 +bob king 114.11000156402588 +bob king 134.81999588012695 +bob king 152.7699956893921 +bob laertes 42.89999961853027 +bob laertes 393.99999433755875 +bob miller 146.1500015258789 +bob ovid 62.849998474121094 +bob ovid 88.77999877929688 +bob ovid 97.08999633789062 +bob ovid 102.93000030517578 +bob polk 261.4599976539612 +bob quirinius 298.7199897766113 +bob steinbeck 103.01999950408936 +bob van buren 174.89999771118164 +bob white 194.25 +bob white 347.7799949645996 +bob xylophone 19.690000534057617 +bob xylophone 191.52999687194824 +bob young 78.17999649047852 +bob zipper 132.86000061035156 +bob zipper 139.6900042295456 +bob zipper 295.59000039100647 +calvin allen 255.68000411987305 +calvin brown 85.9000015258789 +calvin brown 238.02000427246094 +calvin brown 275.8699951171875 +calvin carson 80.2300033569336 +calvin davidson 31.770000457763672 +calvin davidson 181.76000213623047 +calvin ellison 188.0300006866455 +calvin falkner 93.61000061035156 +calvin falkner 94.30999755859375 +calvin falkner 125.91999816894531 +calvin falkner 137.1699981689453 +calvin falkner 140.99999594688416 +calvin falkner 168.81999969482422 +calvin garcia 307.439998626709 +calvin hernandez 303.4599952697754 +calvin johnson 152.8300018310547 +calvin laertes 150.69999885559082 +calvin laertes 216.81000137329102 +calvin nixon 131.57999801635742 +calvin nixon 143.3699951171875 +calvin nixon 196.34000301361084 +calvin ovid 69.95999908447266 +calvin ovid 176.13999938964844 +calvin ovid 176.3800048828125 +calvin ovid 248.65999603271484 +calvin polk 147.04000091552734 +calvin quirinius 226.66999435424805 +calvin quirinius 266.7100009918213 +calvin robinson 289.7900047302246 +calvin steinbeck 92.05000305175781 +calvin steinbeck 118.15000057220459 +calvin steinbeck 333.6000003814697 +calvin thompson 93.7300033569336 +calvin thompson 249.56000137329102 +calvin underhill 208.3400001525879 +calvin van buren 136.51000213623047 +calvin van buren 347.0999946594238 +calvin white 90.69000244140625 +calvin white 112.15999984741211 +calvin xylophone 25.420000076293945 +calvin xylophone 237.71999740600586 +calvin xylophone 315.2099952697754 +calvin young 222.96000289916992 +calvin young 243.3199977874756 +calvin zipper 95.37999725341797 +calvin zipper 531.3600015640259 +david allen 202.43000030517578 +david allen 302.4399948120117 +david brown 93.63999938964844 +david brown 258.05999851226807 +david davidson 74.1500015258789 +david davidson 95.80999755859375 +david davidson 106.50000202655792 +david davidson 149.94000244140625 +david ellison 85.2300033569336 +david ellison 94.1500015258789 +david ellison 208.3900032043457 +david hernandez 99.91000366210938 +david ichabod 82.55000305175781 +david ichabod 320.47999143600464 +david laertes 250.1699981689453 +david nixon 174.58999633789062 +david ovid 198.21000289916992 +david ovid 230.47999954223633 +david quirinius 29.239999771118164 +david quirinius 79.97000122070312 +david quirinius 180.92999649047852 +david robinson 147.65999603271484 +david robinson 168.7100009918213 +david thompson 41.88999938964844 +david underhill 97.55999755859375 +david underhill 277.5999984741211 +david underhill 369.4600009918213 +david van buren 83.56999969482422 +david van buren 289.189998626709 +david white 124.6099967956543 +david xylophone 135.70000076293945 +david xylophone 237.06000137329102 +david xylophone 338.20999908447266 +david young 172.49000549316406 +david young 184.9800033569336 +ethan allen 240.42000198364258 +ethan brown 61.86000061035156 +ethan brown 73.18000030517578 +ethan brown 105.29000043869019 +ethan brown 177.8300018310547 +ethan brown 185.98999691009521 +ethan brown 284.729998588562 +ethan carson 265.22999572753906 +ethan ellison 166.5 +ethan ellison 244.99000671505928 +ethan falkner 59.43000030517578 +ethan falkner 196.17000198364258 +ethan garcia 271.5999946594238 +ethan hernandez 264.50999450683594 +ethan johnson 90.05000305175781 +ethan king 36.49000024795532 +ethan laertes 95.06999969482422 +ethan laertes 96.29000091552734 +ethan laertes 189.66000270843506 +ethan laertes 192.79999923706055 +ethan laertes 249.04000091552734 +ethan laertes 249.76000213623047 +ethan laertes 369.9599952697754 +ethan miller 314.5599994659424 +ethan nixon 493.03000259399414 +ethan ovid 57.290000915527344 +ethan polk 2.3499999046325684 +ethan polk 59.869998931884766 +ethan polk 219.6599998474121 +ethan polk 263.8600025177002 +ethan quirinius 97.23999786376953 +ethan quirinius 111.70999908447266 +ethan quirinius 317.69000363349915 +ethan robinson 78.62000274658203 +ethan robinson 149.5800018310547 +ethan underhill 231.25000381469727 +ethan van buren 152.60000228881836 +ethan white 155.81999969482422 +ethan white 235.55999755859375 +ethan xylophone 414.61000061035156 +ethan zipper 97.51000213623047 +ethan zipper 297.20999908447266 +fred davidson 78.30999755859375 +fred davidson 105.8499984741211 +fred davidson 220.56000137329102 +fred ellison 56.489999771118164 +fred ellison 96.77999877929688 +fred ellison 199.52000045776367 +fred falkner 66.77999973297119 +fred falkner 85.0 +fred falkner 169.91000366210938 +fred hernandez 117.85000228881836 +fred ichabod 81.31999969482422 +fred ichabod 202.45000457763672 +fred johnson 96.08999633789062 +fred king 140.24999618530273 +fred king 343.82000064849854 +fred laertes 57.63999938964844 +fred miller 176.18000030517578 +fred nixon 28.690000534057617 +fred nixon 187.40999603271484 +fred nixon 246.77000045776367 +fred nixon 338.34999084472656 +fred polk 90.12000274658203 +fred polk 323.1899948120117 +fred polk 357.19000244140625 +fred polk 496.16999435424805 +fred quirinius 218.82999801635742 +fred quirinius 224.12000179290771 +fred robinson 89.02999877929688 +fred steinbeck 32.22999954223633 +fred steinbeck 91.05000305175781 +fred steinbeck 231.92000198364258 +fred underhill 183.31999969482422 +fred van buren 83.58000183105469 +fred van buren 318.38000106811523 +fred van buren 346.1400008201599 +fred van buren 391.9999942779541 +fred white 187.38000106811523 +fred young 97.70999908447266 +fred young 141.22999954223633 +fred zipper 163.89999771118164 +gabriella allen 274.8599967956543 +gabriella allen 283.5 +gabriella brown 163.89999961853027 +gabriella brown 465.0 +gabriella carson 147.78999710083008 +gabriella davidson 263.64000415802 +gabriella ellison 71.54000091552734 +gabriella ellison 188.55999755859375 +gabriella falkner 51.720001220703125 +gabriella falkner 87.61000061035156 +gabriella falkner 162.21999835968018 +gabriella garcia 43.0099983215332 +gabriella hernandez 190.5500030517578 +gabriella hernandez 267.4700012207031 +gabriella ichabod 71.12999725341797 +gabriella ichabod 90.3499984741211 +gabriella ichabod 148.6999969482422 +gabriella ichabod 175.70000076293945 +gabriella ichabod 285.72999572753906 +gabriella king 166.75000190734863 +gabriella king 177.6999969482422 +gabriella laertes 65.37999725341797 +gabriella miller 148.4800033569336 +gabriella ovid 92.4000015258789 +gabriella ovid 137.82999801635742 +gabriella polk 244.07000350952148 +gabriella polk 282.00000762939453 +gabriella steinbeck 272.2799987792969 +gabriella steinbeck 461.060001373291 +gabriella thompson 88.36000061035156 +gabriella thompson 94.25 +gabriella thompson 158.80999755859375 +gabriella van buren 146.0800018310547 +gabriella van buren 151.63999938964844 +gabriella white 138.72000122070312 +gabriella young 30.739999771118164 +gabriella young 146.62999725341797 +gabriella zipper 91.62999725341797 +gabriella zipper 357.5099983215332 +holly allen 44.56999969482422 +holly brown 173.64999389648438 +holly brown 174.2100067138672 +holly falkner 166.22999572753906 +holly hernandez 180.0800018310547 +holly hernandez 248.71000480651855 +holly hernandez 336.87000274658203 +holly hernandez 523.2800025939941 +holly ichabod 179.37000274658203 +holly ichabod 180.18000030517578 +holly ichabod 184.66000366210938 +holly johnson 64.36000061035156 +holly johnson 145.61000061035156 +holly johnson 157.12999725341797 +holly king 275.2299995422363 +holly king 288.52000427246094 +holly laertes 246.21000289916992 +holly miller 290.21999740600586 +holly nixon 177.39999389648438 +holly nixon 228.58999633789062 +holly polk 98.30999755859375 +holly polk 307.0799951553345 +holly robinson 219.27999877929688 +holly thompson 75.41999816894531 +holly thompson 86.69000244140625 +holly thompson 523.360002592206 +holly underhill 96.68000030517578 +holly underhill 163.54000091552734 +holly underhill 187.47000122070312 +holly underhill 328.0099983215332 +holly van buren 161.7699966430664 +holly white 122.98999786376953 +holly white 335.93999576568604 +holly xylophone 191.34000396728516 +holly young 60.220001220703125 +holly young 297.20999908447266 +holly zipper 99.12999725341797 +holly zipper 99.29000091552734 +irene allen 234.6400032043457 +irene brown 4.789999961853027 +irene brown 176.4499969482422 +irene brown 338.2099952697754 +irene carson 292.0 +irene ellison 201.06000137329102 +irene ellison 230.79000091552734 +irene falkner 99.91999816894531 +irene falkner 210.11000061035156 +irene garcia 40.78999996185303 +irene garcia 86.93000030517578 +irene garcia 183.02000045776367 +irene ichabod 99.62000274658203 +irene ichabod 281.96999740600586 +irene johnson 243.59999752044678 +irene laertes 112.54000091552734 +irene laertes 227.45000076293945 +irene laertes 246.53000259399414 +irene miller 395.9100036621094 +irene nixon 29.780000686645508 +irene nixon 199.45999908447266 +irene nixon 261.46000480651855 +irene ovid 158.97000122070312 +irene ovid 339.94000244140625 +irene ovid 362.82000732421875 +irene polk 95.83999633789062 +irene polk 183.43000411987305 +irene polk 258.7100033760071 +irene polk 284.6300048828125 +irene polk 507.2400016784668 +irene quirinius 157.5800018310547 +irene quirinius 250.61000061035156 +irene quirinius 431.6499996185303 +irene robinson 191.72999572753906 +irene steinbeck 94.33000183105469 +irene thompson 256.0 +irene underhill 135.55999755859375 +irene underhill 327.0299892425537 +irene van buren 54.439998626708984 +irene van buren 193.71000289916992 +irene xylophone 168.5 +jessica brown 422.5299949645996 +jessica carson 103.66000366210938 +jessica carson 144.92000198364258 +jessica carson 259.1099967956543 +jessica davidson 95.33999633789062 +jessica davidson 99.20999908447266 +jessica davidson 137.17000198364258 +jessica davidson 227.79999923706055 +jessica ellison 207.35000228881836 +jessica ellison 237.4300022125244 +jessica falkner 99.6500015258789 +jessica garcia 174.70999908447266 +jessica garcia 185.62000179290771 +jessica ichabod 124.59000015258789 +jessica johnson 272.0500030517578 +jessica johnson 294.2899990081787 +jessica miller 77.83999633789062 +jessica nixon 77.0999984741211 +jessica nixon 90.06999969482422 +jessica ovid 71.68000030517578 +jessica ovid 309.44000244140625 +jessica polk 472.2099952697754 +jessica quirinius 35.619998931884766 +jessica quirinius 192.7000026702881 +jessica quirinius 208.6500015258789 +jessica quirinius 370.0599937438965 +jessica robinson 254.5300064086914 +jessica thompson 115.9000015258789 +jessica thompson 180.60000610351562 +jessica underhill 199.10999870300293 +jessica underhill 234.29000091552734 +jessica underhill 257.09000396728516 +jessica van buren 9.739999771118164 +jessica white 96.62000274658203 +jessica white 166.54000091552734 +jessica white 240.52999877929688 +jessica white 432.17999362945557 +jessica white 497.6400032043457 +jessica xylophone 385.4799995422363 +jessica young 47.410000801086426 +jessica young 240.6500015258789 +jessica zipper 323.4199962615967 +jessica zipper 344.8399953842163 +jessica zipper 480.06999588012695 +katie allen 312.9700012207031 +katie brown 573.4599933624268 +katie davidson 96.91000366210938 +katie ellison 163.52999877929688 +katie ellison 384.4699947834015 +katie falkner 125.57000160217285 +katie garcia 84.4000015258789 +katie garcia 160.28999710083008 +katie hernandez 257.9600028991699 +katie ichabod 187.63999557495117 +katie ichabod 274.97999572753906 +katie ichabod 362.9200019836426 +katie king 97.80999755859375 +katie king 169.56999969482422 +katie king 314.1999969482422 +katie miller 31.399999618530273 +katie miller 228.40999603271484 +katie nixon 23.190000534057617 +katie ovid 207.1200065612793 +katie polk 143.2599983215332 +katie polk 247.02000045776367 +katie robinson 83.84999942779541 +katie van buren 297.0300064086914 +katie van buren 464.54999351501465 +katie white 344.1700019836426 +katie white 465.8599934577942 +katie xylophone 175.89999675750732 +katie young 31.010000228881836 +katie young 72.51000213623047 +katie young 97.56999969482422 +katie zipper 101.9000015258789 +katie zipper 314.75 +luke allen 89.55000305175781 +luke allen 133.4800033569336 +luke allen 210.8800048828125 +luke allen 392.0300064086914 +luke allen 420.6299934387207 +luke brown 129.20999908447266 +luke davidson 28.950000762939453 +luke davidson 106.41000080108643 +luke ellison 42.09000027179718 +luke ellison 136.52000427246094 +luke ellison 187.51000213623047 +luke falkner 172.8799991607666 +luke falkner 216.0199966430664 +luke garcia 50.94000053405762 +luke garcia 345.1200008392334 +luke ichabod 67.90000057220459 +luke ichabod 97.87000274658203 +luke johnson 59.00999927520752 +luke johnson 105.32000160217285 +luke johnson 187.2899990081787 +luke laertes 105.42000198364258 +luke laertes 147.14999723434448 +luke laertes 158.86000061035156 +luke laertes 167.01999855041504 +luke laertes 281.19999504089355 +luke miller 97.6500015258789 +luke ovid 186.53000259399414 +luke ovid 340.1300048828125 +luke polk 95.27999877929688 +luke polk 277.6700019836426 +luke quirinius 115.83999633789062 +luke robinson 137.33999633789062 +luke robinson 145.23999786376953 +luke thompson 94.37999725341797 +luke underhill 96.94000244140625 +luke underhill 194.73999786376953 +luke underhill 372.6899948120117 +luke van buren 193.93999862670898 +luke white 67.12000274658203 +luke xylophone 102.37999725341797 +luke zipper 223.54000282287598 +mike allen 79.60999870300293 +mike brown 202.81999588012695 +mike carson 81.66000366210938 +mike carson 105.02999877929688 +mike carson 405.4499931335449 +mike davidson 137.74999809265137 +mike davidson 393.17999267578125 +mike ellison 79.37999725341797 +mike ellison 85.73999786376953 +mike ellison 127.15999603271484 +mike ellison 228.07999992370605 +mike ellison 263.8899955749512 +mike falkner 254.50000381469727 +mike garcia 70.8499984741211 +mike garcia 173.63999938964844 +mike garcia 177.5199966430664 +mike hernandez 59.45000076293945 +mike hernandez 327.6900062561035 +mike ichabod 64.7699966430664 +mike king 78.26000213623047 +mike king 84.2300033569336 +mike king 94.68000030517578 +mike king 133.5900001525879 +mike king 134.87999725341797 +mike king 173.45999908447266 +mike miller 57.890000343322754 +mike nixon 92.95999908447266 +mike nixon 203.68999862670898 +mike polk 32.140000343322754 +mike polk 99.68000030517578 +mike polk 306.61000061035156 +mike quirinius 89.37999725341797 +mike steinbeck 85.13999938964844 +mike steinbeck 97.45999908447266 +mike steinbeck 153.86000204086304 +mike steinbeck 221.21999502182007 +mike van buren 80.83999633789062 +mike van buren 174.21000289916992 +mike white 91.87999725341797 +mike white 341.80999755859375 +mike white 341.86000061035156 +mike white 389.20000076293945 +mike young 74.58999633789062 +mike young 83.54000091552734 +mike young 112.19000101089478 +mike zipper 86.98999786376953 +mike zipper 97.38999938964844 +mike zipper 174.61000061035156 +nick allen 173.32000207901 +nick allen 257.7300033569336 +nick brown 192.45000076293945 +nick davidson 258.9799919128418 +nick ellison 183.34000396728516 +nick ellison 193.02000427246094 +nick falkner 10.130000114440918 +nick falkner 182.72000122070312 +nick garcia 142.65999603271484 +nick garcia 183.7699966430664 +nick garcia 277.8299951553345 +nick ichabod 110.43999862670898 +nick ichabod 112.54999923706055 +nick ichabod 241.68999481201172 +nick johnson 192.56000518798828 +nick johnson 325.9499976634979 +nick laertes 96.25 +nick miller 82.97000122070312 +nick nixon 96.37999725341797 +nick ovid 184.3699951171875 +nick polk 199.57000064849854 +nick quirinius 174.80999755859375 +nick quirinius 243.8300018310547 +nick robinson 129.65999603271484 +nick robinson 216.54999923706055 +nick steinbeck 97.83000183105469 +nick thompson 205.4900016784668 +nick underhill 166.42000007629395 +nick van buren 222.6500015258789 +nick xylophone 75.3499984741211 +nick young 332.23999786376953 +nick young 346.41000083088875 +nick zipper 222.9199981689453 +nick zipper 529.7199974060059 +oscar allen 246.42999839782715 +oscar brown 274.01000022888184 +oscar carson 78.9800033569336 +oscar carson 87.4800033569336 +oscar carson 98.51000213623047 +oscar carson 203.86000442504883 +oscar carson 321.82000064849854 +oscar davidson 361.6699981689453 +oscar ellison 146.44000244140625 +oscar ellison 234.32000160217285 +oscar falkner 98.4800033569336 +oscar garcia 231.04000091552734 +oscar hernandez 85.48999786376953 +oscar hernandez 95.4800033569336 +oscar ichabod 71.80000305175781 +oscar ichabod 123.78000068664551 +oscar ichabod 173.31000518798828 +oscar ichabod 251.22000122070312 +oscar johnson 146.27000427246094 +oscar johnson 260.1600036621094 +oscar king 124.2699966430664 +oscar king 249.5399990081787 +oscar king 284.8599910736084 +oscar laertes 15.640000343322754 +oscar laertes 254.8499984741211 +oscar laertes 261.41000175476074 +oscar laertes 261.8400020599365 +oscar nixon 41.619998931884766 +oscar ovid 82.23999786376953 +oscar ovid 187.76000213623047 +oscar ovid 260.6100044250488 +oscar polk 63.900001525878906 +oscar polk 252.71000289916992 +oscar quirinius 73.4800033569336 +oscar quirinius 165.3800048828125 +oscar quirinius 244.2699966430664 +oscar quirinius 248.75 +oscar robinson 93.31999969482422 +oscar robinson 163.55999755859375 +oscar robinson 191.8300018310547 +oscar robinson 315.1999912261963 +oscar steinbeck 376.6899948120117 +oscar thompson 131.1400032043457 +oscar thompson 148.01000213623047 +oscar thompson 325.42000579833984 +oscar thompson 545.7399940490723 +oscar underhill 87.4000015258789 +oscar van buren 61.880001068115234 +oscar van buren 188.8699951171875 +oscar van buren 209.53000235557556 +oscar white 129.73999786376953 +oscar white 148.9800033569336 +oscar white 275.1500015258789 +oscar white 303.8599910736084 +oscar xylophone 115.22999954223633 +oscar xylophone 319.75000381469727 +oscar xylophone 475.3300018310547 +oscar zipper 109.53999710083008 +oscar zipper 214.40999603271484 +oscar zipper 214.6500015258789 +priscilla brown 77.56999969482422 +priscilla brown 165.5199966430664 +priscilla brown 408.4499969482422 +priscilla carson 168.8300018310547 +priscilla carson 195.7900047302246 +priscilla carson 207.5300006866455 +priscilla ichabod 92.61000061035156 +priscilla ichabod 206.16000366210938 +priscilla johnson 89.1500015258789 +priscilla johnson 156.4600067138672 +priscilla johnson 158.88000106811523 +priscilla johnson 190.61000061035156 +priscilla johnson 211.01000022888184 +priscilla king 371.9299964904785 +priscilla nixon 95.80999755859375 +priscilla nixon 278.87999725341797 +priscilla ovid 96.27000284194946 +priscilla ovid 198.3400001525879 +priscilla polk 252.5800018310547 +priscilla quirinius 131.8499994277954 +priscilla thompson 230.36000156402588 +priscilla underhill 143.56999969482422 +priscilla underhill 354.37000274658203 +priscilla van buren 82.72000122070312 +priscilla van buren 145.61000061035156 +priscilla van buren 183.72000122070312 +priscilla white 78.27999877929688 +priscilla xylophone 21.489999771118164 +priscilla xylophone 159.26000213623047 +priscilla xylophone 406.1000007688999 +priscilla young 163.2900013923645 +priscilla young 260.59000366926193 +priscilla zipper 311.399995803833 +priscilla zipper 327.97999572753906 +quinn allen 257.94000244140625 +quinn allen 365.2299995422363 +quinn brown 80.58000183105469 +quinn brown 80.81999969482422 +quinn brown 198.71000289916992 +quinn davidson 83.4000015258789 +quinn davidson 95.11000061035156 +quinn davidson 154.79000091552734 +quinn davidson 227.13999938964844 +quinn ellison 237.17000007629395 +quinn ellison 361.14000129699707 +quinn garcia 92.33000183105469 +quinn garcia 148.63999938964844 +quinn garcia 226.78999710083008 +quinn garcia 246.70000076293945 +quinn ichabod 89.63999938964844 +quinn king 74.62000274658203 +quinn king 86.2300033569336 +quinn laertes 112.36000061035156 +quinn laertes 243.6900019645691 +quinn laertes 265.51000213623047 +quinn nixon 149.3300018310547 +quinn ovid 393.2099943161011 +quinn quirinius 266.8200035095215 +quinn robinson 247.6400032043457 +quinn steinbeck 144.81999969482422 +quinn steinbeck 213.65999507904053 +quinn thompson 156.51000213623047 +quinn thompson 274.1599931716919 +quinn underhill 248.3000030517578 +quinn underhill 252.61999130249023 +quinn underhill 321.9799976348877 +quinn van buren 82.5199966430664 +quinn young 90.97999954223633 +quinn zipper 58.0 +quinn zipper 249.38999938964844 +rachel allen 15.8100004196167 +rachel allen 151.80999755859375 +rachel brown 193.5800018310547 +rachel brown 312.0800018310547 +rachel brown 347.7999954223633 +rachel brown 423.98999214172363 +rachel brown 437.64999771118164 +rachel carson 98.95999908447266 +rachel carson 385.3799934387207 +rachel davidson 396.38999938964844 +rachel ellison 299.12000465393066 +rachel falkner 88.80000305175781 +rachel falkner 99.23999786376953 +rachel falkner 172.54999542236328 +rachel falkner 233.55999755859375 +rachel johnson 197.92999649047852 +rachel king 36.220001220703125 +rachel king 219.8400001525879 +rachel laertes 97.17000198364258 +rachel laertes 109.5999984741211 +rachel ovid 80.20999872684479 +rachel ovid 260.18999683856964 +rachel polk 89.27999877929688 +rachel quirinius 205.1400022506714 +rachel robinson 254.1300015449524 +rachel robinson 286.0400047302246 +rachel robinson 332.4199981689453 +rachel thompson 137.73000198602676 +rachel thompson 213.31000137329102 +rachel thompson 380.85999488830566 +rachel underhill 175.6099967956543 +rachel white 94.72000122070312 +rachel white 196.18000030517578 +rachel young 230.6400032043457 +rachel zipper 148.9000015258789 +rachel zipper 238.98000192642212 +sarah carson 175.62000274658203 +sarah carson 307.70000088214874 +sarah carson 386.8999948501587 +sarah ellison 161.80999946594238 +sarah falkner 99.36000061035156 +sarah falkner 281.62000274658203 +sarah garcia 73.6500015258789 +sarah garcia 153.73000144958496 +sarah garcia 312.8899955749512 +sarah ichabod 81.31999969482422 +sarah ichabod 97.26000213623047 +sarah johnson 140.37999725341797 +sarah johnson 177.57000732421875 +sarah johnson 248.4499969482422 +sarah johnson 309.1800022125244 +sarah king 216.75 +sarah king 268.5399932861328 +sarah miller 222.31000518798828 +sarah ovid 146.25000381469727 +sarah robinson 143.43000030517578 +sarah robinson 310.75 +sarah steinbeck 208.72000122070312 +sarah white 140.22999572753906 +sarah white 181.86000061035156 +sarah xylophone 68.31999969482422 +sarah young 185.80999755859375 +sarah zipper 168.22000122070312 +tom brown 181.1000051498413 +tom brown 404.3500061035156 +tom carson 142.60999822616577 +tom carson 299.57999992370605 +tom carson 592.3499927520752 +tom davidson 180.61000061035156 +tom ellison 98.2300033569336 +tom ellison 154.58999633789062 +tom ellison 173.02999877929688 +tom falkner 88.22000122070312 +tom falkner 139.11000442504883 +tom hernandez 81.63999938964844 +tom hernandez 263.67000579833984 +tom ichabod 214.0699977874756 +tom johnson 405.95000076293945 +tom johnson 438.9099922180176 +tom king 218.18000030517578 +tom laertes 244.37000274658203 +tom laertes 473.0999984741211 +tom miller 68.25 +tom miller 85.59000015258789 +tom miller 127.56999969482422 +tom nixon 85.02999877929688 +tom ovid 217.32000160217285 +tom polk 188.87000274658203 +tom polk 206.52000045776367 +tom quirinius 120.27000427246094 +tom quirinius 232.63000202178955 +tom robinson 90.69000244140625 +tom robinson 98.72000122070312 +tom robinson 99.1500015258789 +tom robinson 209.5399932861328 +tom steinbeck 277.7100009918213 +tom van buren 40.779998779296875 +tom van buren 217.70000076293945 +tom van buren 375.2099964618683 +tom white 223.4700050354004 +tom young 174.36000061035156 +tom young 304.8199977874756 +tom zipper 213.7900047302246 +ulysses brown 247.1500015258789 +ulysses carson 77.41999816894531 +ulysses carson 79.54000091552734 +ulysses carson 150.93000030517578 +ulysses carson 162.24000549316406 +ulysses davidson 414.7100009918213 +ulysses ellison 96.7300033569336 +ulysses garcia 89.80000305175781 +ulysses hernandez 106.29999542236328 +ulysses hernandez 134.44000244140625 +ulysses hernandez 160.22000122070312 +ulysses ichabod 98.56999969482422 +ulysses ichabod 309.34999656677246 +ulysses johnson 152.47000122070312 +ulysses king 244.7100067138672 +ulysses laertes 138.4400042295456 +ulysses laertes 173.55999755859375 +ulysses laertes 256.91999912261963 +ulysses miller 76.27999877929688 +ulysses miller 417.67000102996826 +ulysses nixon 174.56999969482422 +ulysses ovid 130.13000106811523 +ulysses polk 123.9399995803833 +ulysses polk 149.95999908447266 +ulysses polk 205.2400016784668 +ulysses polk 237.5699920654297 +ulysses quirinius 330.4700012207031 +ulysses robinson 79.48999786376953 +ulysses steinbeck 144.8300018310547 +ulysses steinbeck 155.66000366210938 +ulysses thompson 159.92000579833984 +ulysses underhill 81.58000183105469 +ulysses underhill 88.4800033569336 +ulysses underhill 99.66999816894531 +ulysses underhill 135.55999755859375 +ulysses underhill 189.1099977493286 +ulysses underhill 289.6800003051758 +ulysses underhill 385.60000228881836 +ulysses van buren 95.52999877929688 +ulysses white 188.8300018310547 +ulysses white 305.79000091552734 +ulysses xylophone 54.099998474121094 +ulysses xylophone 205.2099952697754 +ulysses xylophone 251.94000148773193 +ulysses young 100.77000045776367 +ulysses young 275.8300018310547 +ulysses young 522.1700019836426 +victor allen 220.1699981689453 +victor allen 222.10000228881836 +victor brown 77.88999938964844 +victor brown 90.37999725341797 +victor brown 91.97000122070312 +victor brown 455.25000381469727 +victor davidson 149.06000137329102 +victor davidson 291.48000717163086 +victor davidson 321.25 +victor ellison 314.37000274658203 +victor ellison 442.50000190734863 +victor hernandez 69.87999725341797 +victor hernandez 99.85000038146973 +victor hernandez 143.02000045776367 +victor hernandez 160.38999938964844 +victor hernandez 391.2999954223633 +victor johnson 145.18000030517578 +victor johnson 190.27000045776367 +victor johnson 308.1900006532669 +victor king 108.10000228881836 +victor king 310.5 +victor laertes 145.42999267578125 +victor laertes 214.72999572753906 +victor miller 173.76000022888184 +victor nixon 68.5 +victor nixon 269.5899963378906 +victor ovid 151.39999771118164 +victor polk 175.8799991607666 +victor quirinius 65.55000305175781 +victor quirinius 168.5500030517578 +victor robinson 177.9100022315979 +victor robinson 204.09999084472656 +victor steinbeck 52.720001220703125 +victor steinbeck 220.45999908447266 +victor steinbeck 309.4900064468384 +victor thompson 58.65999984741211 +victor van buren 206.77999877929688 +victor van buren 222.44000244140625 +victor white 156.36999893188477 +victor white 167.2699966430664 +victor xylophone 158.36999893188477 +victor xylophone 161.54000091552734 +victor xylophone 234.76000308990479 +victor xylophone 267.82999420166016 +victor xylophone 314.95000076293945 +victor young 88.55000305175781 +victor zipper 192.92999649047852 +wendy allen 56.06999969482422 +wendy allen 66.16000306606293 +wendy allen 267.3199996948242 +wendy brown 453.53000259399414 +wendy brown 525.5100040435791 +wendy ellison 193.95000457763672 +wendy ellison 260.9099998474121 +wendy falkner 77.36000061035156 +wendy falkner 97.68000030517578 +wendy falkner 128.30999565124512 +wendy garcia 4.409999847412109 +wendy garcia 76.72000122070312 +wendy garcia 189.42999839782715 +wendy garcia 265.5900001525879 +wendy hernandez 48.11000061035156 +wendy ichabod 104.3700008392334 +wendy king 156.89999771118164 +wendy king 183.31999969482422 +wendy king 403.27000427246094 +wendy laertes 79.98999786376953 +wendy laertes 165.0999984741211 +wendy laertes 365.0 +wendy miller 72.9500002861023 +wendy miller 313.8300037384033 +wendy nixon 45.91999816894531 +wendy nixon 60.2599983215332 +wendy ovid 95.33000183105469 +wendy ovid 180.36000061035156 +wendy polk 386.7400016784668 +wendy polk 443.3400018811226 +wendy quirinius 152.04999828338623 +wendy quirinius 240.23999977111816 +wendy robinson 71.06999969482422 +wendy robinson 249.35000610351562 +wendy robinson 391.4699993133545 +wendy steinbeck 92.11000061035156 +wendy thompson 136.35000228881836 +wendy thompson 183.1500015258789 +wendy underhill 318.6500015258789 +wendy underhill 320.75000190734863 +wendy underhill 328.2300033569336 +wendy van buren 57.459999084472656 +wendy van buren 92.81999969482422 +wendy white 171.36000061035156 +wendy xylophone 153.62999725341797 +wendy xylophone 223.94999885559082 +wendy young 40.22000026702881 +wendy young 513.8299942016602 +xavier allen 102.97000122070312 +xavier allen 168.3300018310547 +xavier allen 197.45999908447266 +xavier brown 55.20000076293945 +xavier brown 90.7300033569336 +xavier brown 96.2300033569336 +xavier carson 193.63999938964844 +xavier carson 265.1600036621094 +xavier davidson 63.349998474121094 +xavier davidson 264.27000427246094 +xavier davidson 288.1999988555908 +xavier ellison 138.42000198364258 +xavier ellison 262.6300048828125 +xavier garcia 148.66000366210938 +xavier hernandez 122.13999938964844 +xavier hernandez 164.97000122070312 +xavier hernandez 306.25 +xavier ichabod 211.84000635147095 +xavier ichabod 244.50000762939453 +xavier johnson 56.53999900817871 +xavier johnson 89.0999984741211 +xavier king 87.22000122070312 +xavier king 151.22999572753906 +xavier laertes 183.65999794006348 +xavier ovid 398.2100067138672 +xavier polk 72.62000274658203 +xavier polk 76.93000030517578 +xavier polk 261.5100030899048 +xavier polk 318.01000213623047 +xavier quirinius 22.1200008392334 +xavier quirinius 89.55000305175781 +xavier quirinius 246.2400016784668 +xavier quirinius 402.2100009918213 +xavier thompson 283.9400005340576 +xavier underhill 120.45000076293945 +xavier white 138.02999591827393 +xavier white 172.06999969482422 +xavier xylophone 79.41999816894531 +xavier zipper 373.67999935150146 +yuri allen 52.849998474121094 +yuri allen 417.3700008392334 +yuri brown 170.52000427246094 +yuri brown 180.70999908447266 +yuri carson 188.99000549316406 +yuri carson 537.6500015258789 +yuri ellison 86.91999816894531 +yuri ellison 376.32999646663666 +yuri falkner 152.99000358581543 +yuri falkner 181.06999969482422 +yuri garcia 274.6800003051758 +yuri hernandez 153.46999764442444 +yuri johnson 197.28000259399414 +yuri johnson 236.0800018310547 +yuri johnson 258.1899985074997 +yuri king 551.9899978637695 +yuri laertes 37.59000015258789 +yuri laertes 253.4799976348877 +yuri nixon 95.54999732971191 +yuri nixon 248.9700005054474 +yuri polk 82.33999633789062 +yuri polk 275.3200035095215 +yuri polk 305.6399974822998 +yuri quirinius 112.97000122070312 +yuri quirinius 148.27999877929688 +yuri quirinius 449.1699924468994 +yuri steinbeck 292.94000244140625 +yuri steinbeck 357.5 +yuri thompson 428.03999519348145 +yuri underhill 83.87000274658203 +yuri underhill 350.7999897003174 +yuri white 132.09000396728516 +yuri xylophone 107.07000160217285 +zach allen 65.43000030517578 +zach brown 135.6999969482422 +zach brown 247.04999542236328 +zach brown 256.8000030517578 +zach brown 362.38000106811523 +zach brown 418.75 +zach carson 291.7700004577637 +zach ellison 135.149995803833 +zach falkner 91.41999816894531 +zach falkner 196.41999912261963 +zach garcia 84.37999725341797 +zach garcia 160.70000457763672 +zach garcia 167.7599983215332 +zach garcia 205.36999893188477 +zach ichabod 116.2699966430664 +zach ichabod 151.18000030517578 +zach king 127.63000106811523 +zach king 182.2699966430664 +zach king 269.0999984741211 +zach miller 199.71000289916992 +zach miller 220.73999977111816 +zach miller 264.0600047111511 +zach ovid 92.55000305175781 +zach ovid 94.33999633789062 +zach ovid 105.94999847561121 +zach ovid 136.04000091552734 +zach quirinius 103.11000061035156 +zach robinson 76.72000122070312 +zach steinbeck 85.48999786376953 +zach steinbeck 182.87000274658203 +zach thompson 116.93999862670898 +zach thompson 319.9499969482422 +zach underhill 86.22000122070312 +zach white 70.52999877929688 +zach xylophone 227.52000427246094 +zach xylophone 286.45000076293945 +zach young 313.00999450683594 +zach zipper 85.87000274658203 +zach zipper 94.43000030517578 +zach zipper 139.38999938964844 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-3-6f104992e0050576085064815de43194 b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-3-6f104992e0050576085064815de43194 new file mode 100644 index 000000000000..ae2a1e9dd7d3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-3-6f104992e0050576085064815de43194 @@ -0,0 +1,1049 @@ + 43.52666695912679 + 48.735000133514404 + 57.96666590372721 +alice allen 39.21833221117655 +alice allen 42.813999557495116 +alice allen 49.1824996471405 +alice brown 59.68166707456112 +alice carson 46.703333189090095 +alice davidson 59.51833279927572 +alice falkner 64.48333247502644 +alice garcia 55.114999771118164 +alice hernandez 49.61333228151003 +alice hernandez 69.70166667302449 +alice johnson 49.25166805585226 +alice king 20.052499681711197 +alice king 49.019999186197914 +alice king 56.0733331044515 +alice laertes 35.24999898672104 +alice laertes 68.85333251953125 +alice miller 55.542000198364256 +alice nixon 44.010000785191856 +alice nixon 50.900000762939456 +alice nixon 64.78333409627278 +alice ovid 24.90999937057495 +alice polk 47.426666259765625 +alice quirinius 52.80166610081991 +alice quirinius 52.94833393891653 +alice robinson 44.41500027974447 +alice robinson 55.04249954223633 +alice steinbeck 37.735000928243004 +alice steinbeck 41.02800045013428 +alice steinbeck 53.83499916394552 +alice underhill 52.64500021934509 +alice van buren 48.27666727701823 +alice xylophone 28.047500252723694 +alice xylophone 37.82199954986572 +alice xylophone 59.75166606903076 +alice zipper 48.875000635782875 +alice zipper 51.365000108877815 +alice zipper 89.93000030517578 +bob brown 38.134999910990395 +bob brown 56.20666662851969 +bob brown 77.51166598002116 +bob carson 52.77200050354004 +bob davidson 28.945000171661377 +bob davidson 37.946666399637856 +bob davidson 53.90333271026611 +bob ellison 51.383999824523926 +bob ellison 57.30333344141642 +bob ellison 58.53666559855143 +bob ellison 66.7400016784668 +bob falkner 39.053333600362144 +bob garcia 16.705000400543213 +bob garcia 37.16499960422516 +bob garcia 51.17999919255575 +bob garcia 52.37666575113932 +bob garcia 67.39166768391927 +bob hernandez 53.261999893188474 +bob ichabod 43.96999979019165 +bob king 38.03666718800863 +bob king 38.19249892234802 +bob king 67.40999794006348 +bob laertes 21.449999809265137 +bob laertes 50.37166612346967 +bob miller 41.906000471115114 +bob ovid 27.836666425069172 +bob ovid 39.3833335240682 +bob ovid 60.90749979019165 +bob ovid 63.069999313354494 +bob polk 41.88333296775818 +bob quirinius 54.504998207092285 +bob steinbeck 34.33999983469645 +bob van buren 58.29999923706055 +bob white 30.22333288192749 +bob white 41.44000005722046 +bob xylophone 12.163333415985107 +bob xylophone 47.88249921798706 +bob young 27.38499927520752 +bob zipper 27.93800084590912 +bob zipper 59.11800007820129 +bob zipper 66.43000030517578 +calvin allen 71.51000118255615 +calvin brown 50.44166612625122 +calvin brown 53.625000953674316 +calvin brown 67.48333485921223 +calvin carson 62.17750120162964 +calvin davidson 14.03000009059906 +calvin davidson 43.640000343322754 +calvin ellison 52.52750015258789 +calvin falkner 24.016666332880657 +calvin falkner 46.87333329518636 +calvin falkner 47.53666607538859 +calvin falkner 56.47499918937683 +calvin falkner 57.336666107177734 +calvin falkner 72.25 +calvin garcia 46.484999338785805 +calvin hernandez 43.90199909210205 +calvin johnson 76.41500091552734 +calvin laertes 39.09249973297119 +calvin laertes 47.323333422342934 +calvin nixon 30.113332668940227 +calvin nixon 47.7533327738444 +calvin nixon 49.08500075340271 +calvin ovid 41.924999713897705 +calvin ovid 49.10499978065491 +calvin ovid 62.26499938964844 +calvin ovid 62.27999989191691 +calvin polk 52.95333353678385 +calvin quirinius 53.34200019836426 +calvin quirinius 54.831998634338376 +calvin robinson 60.470001220703125 +calvin steinbeck 35.644999980926514 +calvin steinbeck 53.75800037384033 +calvin steinbeck 56.11000084877014 +calvin thompson 42.355000257492065 +calvin thompson 65.91666666666667 +calvin underhill 47.41199951171875 +calvin van buren 31.591666102409363 +calvin van buren 43.18200063705444 +calvin white 45.27500128746033 +calvin white 56.079999923706055 +calvin xylophone 20.394999980926514 +calvin xylophone 40.59999958674113 +calvin xylophone 54.56500005722046 +calvin young 42.246666272481285 +calvin young 64.49000072479248 +calvin zipper 57.49500036239624 +calvin zipper 57.59000039100647 +david allen 40.14999977747599 +david allen 54.545000076293945 +david brown 35.4516666730245 +david brown 62.83999938964844 +david davidson 35.50000067551931 +david davidson 52.75250005722046 +david davidson 74.1500015258789 +david davidson 95.80999755859375 +david ellison 43.43166727821032 +david ellison 52.18750047683716 +david ellison 72.58400039672851 +david hernandez 64.47600173950195 +david ichabod 29.518332719802856 +david ichabod 34.6100010573864 +david laertes 61.127999496459964 +david nixon 53.60249900817871 +david ovid 37.977500796318054 +david ovid 41.58999983469645 +david quirinius 24.96250009536743 +david quirinius 45.23249912261963 +david quirinius 52.282000350952146 +david robinson 42.17750024795532 +david robinson 62.54666519165039 +david thompson 41.88999938964844 +david underhill 48.143333752950035 +david underhill 62.53499941031138 +david underhill 97.55999755859375 +david van buren 34.84749984741211 +david van buren 51.61666679382324 +david white 62.30499839782715 +david xylophone 33.92500019073486 +david xylophone 53.142000222206114 +david xylophone 72.77166684468587 +david young 30.13666756947835 +david young 51.540000915527344 +ethan allen 53.442500591278076 +ethan brown 7.110000133514404 +ethan brown 41.260000228881836 +ethan brown 41.495000084241234 +ethan brown 41.83833312988281 +ethan brown 46.497499227523804 +ethan brown 63.17750072479248 +ethan carson 64.31999969482422 +ethan ellison 46.72599983215332 +ethan ellison 61.24750167876482 +ethan falkner 36.90250104665756 +ethan falkner 52.71000099182129 +ethan garcia 32.771666407585144 +ethan hernandez 56.239999008178714 +ethan johnson 82.38333384195964 +ethan king 8.399999856948853 +ethan laertes 37.93200054168701 +ethan laertes 40.4883329073588 +ethan laertes 53.396667132774986 +ethan laertes 53.81666628519694 +ethan laertes 55.987499713897705 +ethan laertes 66.36999956766765 +ethan laertes 68.83400039672851 +ethan miller 56.519999504089355 +ethan nixon 50.13333353648583 +ethan ovid 37.51333363850912 +ethan polk 2.3499999046325684 +ethan polk 44.147999954223636 +ethan polk 54.91499996185303 +ethan polk 60.340000788370766 +ethan quirinius 40.21750020980835 +ethan quirinius 47.125000298023224 +ethan quirinius 57.69000013669332 +ethan robinson 41.5060001373291 +ethan robinson 55.0640007019043 +ethan underhill 68.01250076293945 +ethan van buren 43.55250072479248 +ethan white 46.40999913215637 +ethan white 53.03200073242188 +ethan xylophone 70.29333432515462 +ethan zipper 46.92999982833862 +ethan zipper 66.04500198364258 +fred davidson 26.487499618902802 +fred davidson 44.99333477020264 +fred davidson 53.93666648864746 +fred ellison 22.2599999109904 +fred ellison 48.047999954223634 +fred ellison 63.423333485921226 +fred falkner 20.035000324249268 +fred falkner 44.243999004364014 +fred falkner 62.130001068115234 +fred hernandez 37.42199997901916 +fred ichabod 46.56999969482422 +fred ichabod 58.275001525878906 +fred johnson 54.63333257039388 +fred king 46.30499919255575 +fred king 61.48500037193298 +fred laertes 26.203333059946697 +fred miller 43.46400032043457 +fred nixon 28.690000534057617 +fred nixon 32.77999955415726 +fred nixon 52.03799936771393 +fred nixon 60.468332290649414 +fred polk 26.81166632970174 +fred polk 42.48999996185303 +fred polk 46.95200061798096 +fred polk 60.465998840332034 +fred quirinius 43.466000366210935 +fred quirinius 45.79799957275391 +fred robinson 62.42833296457926 +fred steinbeck 32.22999954223633 +fred steinbeck 48.73800039291382 +fred steinbeck 65.91750144958496 +fred underhill 72.94250011444092 +fred van buren 41.28250002861023 +fred van buren 41.5283338278532 +fred van buren 50.871665954589844 +fred van buren 81.77000045776367 +fred white 34.897499561309814 +fred young 58.56666692097982 +fred young 70.61499977111816 +fred zipper 33.90500068664551 +gabriella allen 51.041666666666664 +gabriella allen 55.13999938964844 +gabriella brown 54.63333320617676 +gabriella brown 72.33333396911621 +gabriella carson 49.26333236694336 +gabriella davidson 52.72800083160401 +gabriella ellison 54.68999989827474 +gabriella ellison 71.54000091552734 +gabriella falkner 20.834000015258788 +gabriella falkner 46.348333517710365 +gabriella falkner 48.58500099182129 +gabriella garcia 24.78999964396159 +gabriella hernandez 57.61000029246012 +gabriella hernandez 65.11166667938232 +gabriella ichabod 33.78749895095825 +gabriella ichabod 38.9950008392334 +gabriella ichabod 48.15499925613403 +gabriella ichabod 49.739999008178714 +gabriella ichabod 52.789999643961586 +gabriella king 35.58200044631958 +gabriella king 49.63000059723854 +gabriella laertes 47.81000010172526 +gabriella miller 62.17666753133138 +gabriella ovid 45.94333267211914 +gabriella ovid 50.435001373291016 +gabriella polk 42.58000100851059 +gabriella polk 72.87000179290771 +gabriella steinbeck 65.42000102996826 +gabriella steinbeck 72.0499997138977 +gabriella thompson 49.897499561309814 +gabriella thompson 52.84250023961067 +gabriella thompson 57.23249959945679 +gabriella van buren 39.31000053882599 +gabriella van buren 57.920000076293945 +gabriella white 49.85333410898844 +gabriella young 15.369999885559082 +gabriella young 49.04999907811483 +gabriella zipper 42.82499901453654 +gabriella zipper 59.58499972025553 +holly allen 29.50499963760376 +holly brown 55.284998178482056 +holly brown 55.79833386838436 +holly falkner 40.07666663328806 +holly hernandez 46.40500044822693 +holly hernandez 56.14500045776367 +holly hernandez 56.160000483194985 +holly hernandez 58.95166748017073 +holly ichabod 67.19250011444092 +holly ichabod 68.44500064849854 +holly ichabod 90.09000015258789 +holly johnson 42.795000076293945 +holly johnson 55.76999855041504 +holly johnson 66.11333338419597 +holly king 48.788000869750974 +holly king 64.25 +holly laertes 50.40750074386597 +holly miller 45.60666608810425 +holly nixon 58.096665700276695 +holly nixon 69.59249877929688 +holly polk 41.59666601816813 +holly polk 44.64799900054932 +holly robinson 48.405999755859376 +holly thompson 34.605000495910645 +holly thompson 38.36749941110611 +holly thompson 65.99833394338687 +holly underhill 46.52600040435791 +holly underhill 47.92666663726171 +holly underhill 62.41600036621094 +holly underhill 79.95999908447266 +holly van buren 58.63333225250244 +holly white 31.232499361038208 +holly white 61.494998931884766 +holly xylophone 70.32500044504802 +holly young 54.05000114440918 +holly young 57.103333473205566 +holly zipper 50.59500014781952 +holly zipper 67.81000137329102 +irene allen 53.364000701904295 +irene brown 4.789999961853027 +irene brown 51.65399932861328 +irene brown 87.66999816894531 +irene carson 59.2433336575826 +irene ellison 36.446667989095054 +irene ellison 53.85000038146973 +irene falkner 47.429999995231626 +irene falkner 61.20666694641113 +irene garcia 15.369999885559082 +irene garcia 48.28000005086263 +irene garcia 55.643333435058594 +irene ichabod 40.868333299954735 +irene ichabod 64.45666694641113 +irene johnson 44.37999935150147 +irene laertes 27.625 +irene laertes 46.25 +irene laertes 49.17500019073486 +irene miller 75.87500063578288 +irene nixon 29.780000686645508 +irene nixon 30.070000807444256 +irene nixon 36.34333356221517 +irene ovid 27.21500023206075 +irene ovid 55.6175012588501 +irene ovid 60.353999328613284 +irene polk 47.438334147135414 +irene polk 47.77666505177816 +irene polk 51.74200067520142 +irene polk 52.12200088500977 +irene polk 52.184000205993655 +irene quirinius 52.958333015441895 +irene quirinius 55.29000053405762 +irene quirinius 78.79000091552734 +irene robinson 61.398332595825195 +irene steinbeck 64.34000142415364 +irene thompson 41.92250097543001 +irene underhill 27.72499966621399 +irene underhill 50.783331871032715 +irene van buren 46.65999984741211 +irene van buren 49.71000099182129 +irene xylophone 61.220001220703125 +jessica brown 63.449999491373696 +jessica carson 41.78999951481819 +jessica carson 47.52000069618225 +jessica carson 51.83000183105469 +jessica davidson 34.938333332538605 +jessica davidson 45.91000066200892 +jessica davidson 51.082499504089355 +jessica davidson 64.3099988301595 +jessica ellison 32.53000124295553 +jessica ellison 35.265999984741214 +jessica falkner 54.933334032694496 +jessica garcia 32.575000286102295 +jessica garcia 60.584999084472656 +jessica ichabod 46.704999923706055 +jessica johnson 55.73199977874756 +jessica johnson 72.14000066121419 +jessica miller 55.27799835205078 +jessica nixon 58.53999900817871 +jessica nixon 90.06999969482422 +jessica ovid 36.47500014305115 +jessica ovid 63.03000005086263 +jessica polk 69.52999877929688 +jessica quirinius 19.01333288351695 +jessica quirinius 41.630000829696655 +jessica quirinius 42.58750009536743 +jessica quirinius 47.72999986012777 +jessica robinson 64.81000061035157 +jessica thompson 45.83666737874349 +jessica thompson 57.540000915527344 +jessica underhill 45.360000133514404 +jessica underhill 51.40999889373779 +jessica underhill 64.27250099182129 +jessica van buren 9.739999771118164 +jessica white 38.38999891281128 +jessica white 40.75400023460388 +jessica white 61.89800128936768 +jessica white 62.34749984741211 +jessica white 65.7680004119873 +jessica xylophone 50.808332761128746 +jessica young 18.400000254313152 +jessica young 49.04400033950806 +jessica zipper 35.039999643961586 +jessica zipper 52.78999869028727 +jessica zipper 55.51499891281128 +katie allen 56.10800056457519 +katie brown 48.86833222707113 +katie davidson 96.91000366210938 +katie ellison 31.106667200724285 +katie ellison 38.356666127840676 +katie falkner 17.24333318074544 +katie garcia 53.429999033610024 +katie garcia 53.933334002892174 +katie hernandez 72.71666717529297 +katie ichabod 39.97999954223633 +katie ichabod 54.99599914550781 +katie ichabod 59.41999944051107 +katie king 44.88599967956543 +katie king 60.2549991607666 +katie king 67.27499961853027 +katie miller 31.399999618530273 +katie miller 50.57666619618734 +katie nixon 23.190000534057617 +katie ovid 52.96000158786774 +katie polk 51.029998779296875 +katie polk 54.93600006103516 +katie robinson 13.890000343322754 +katie van buren 52.78999908765157 +katie van buren 53.21750092506409 +katie white 59.799998950958255 +katie white 64.70499992370605 +katie xylophone 53.54499959945679 +katie young 31.010000228881836 +katie young 37.55500102043152 +katie young 49.87499872843424 +katie zipper 29.79666694998741 +katie zipper 50.95000076293945 +luke allen 43.757998657226565 +luke allen 53.18000030517578 +luke allen 62.09749984741211 +luke allen 70.2933349609375 +luke allen 89.55000305175781 +luke brown 45.95999972025553 +luke davidson 28.950000762939453 +luke davidson 53.20500040054321 +luke ellison 5.159999907016754 +luke ellison 18.630000392595928 +luke ellison 76.88000106811523 +luke falkner 6.175000190734863 +luke falkner 54.0049991607666 +luke garcia 15.320000424981117 +luke garcia 25.47000026702881 +luke ichabod 33.950000286102295 +luke ichabod 57.17500114440918 +luke johnson 14.4399995803833 +luke johnson 37.23333263397217 +luke johnson 52.660000801086426 +luke laertes 11.819999694824219 +luke laertes 23.5 +luke laertes 43.71999907493591 +luke laertes 45.9900016784668 +luke laertes 57.85499906539917 +luke miller 67.85000038146973 +luke ovid 38.04999923706055 +luke ovid 79.64200134277344 +luke polk 55.106666564941406 +luke polk 70.18999862670898 +luke quirinius 26.016666491826374 +luke robinson 61.256665547688804 +luke robinson 72.61999893188477 +luke thompson 47.203332940737404 +luke underhill 49.244999408721924 +luke underhill 59.32999897003174 +luke underhill 72.18333307902019 +luke van buren 64.646666208903 +luke white 67.12000274658203 +luke xylophone 46.74333190917969 +luke zipper 36.16000008583069 +mike allen 10.709999859333038 +mike brown 61.0674991607666 +mike carson 50.90249824523926 +mike carson 52.30200061798096 +mike carson 52.51499938964844 +mike davidson 29.733333269755047 +mike davidson 47.396666526794434 +mike ellison 37.807498931884766 +mike ellison 44.339999198913574 +mike ellison 45.89199924468994 +mike ellison 47.91249990463257 +mike ellison 55.295000076293945 +mike falkner 16.479999542236328 +mike garcia 38.98999913533529 +mike garcia 54.91333325703939 +mike garcia 57.096666971842446 +mike hernandez 59.45000076293945 +mike hernandez 62.446667989095054 +mike ichabod 54.69666544596354 +mike king 38.790000915527344 +mike king 43.2050017118454 +mike king 46.85666592915853 +mike king 46.945000648498535 +mike king 53.382500648498535 +mike king 94.68000030517578 +mike miller 3.9600000381469727 +mike nixon 53.894999742507935 +mike nixon 92.95999908447266 +mike polk 12.449999809265137 +mike polk 39.040000319480896 +mike polk 71.40500068664551 +mike quirinius 89.37999725341797 +mike steinbeck 34.05499875545502 +mike steinbeck 45.909999465942384 +mike steinbeck 51.286667346954346 +mike steinbeck 59.292500019073486 +mike van buren 43.13999819755554 +mike van buren 49.59000015258789 +mike white 30.53999964396159 +mike white 34.61499961217245 +mike white 48.08999943733215 +mike white 53.967501163482666 +mike young 10.484999895095825 +mike young 47.070000648498535 +mike young 74.58999633789062 +mike zipper 29.029999288419884 +mike zipper 77.96200027465821 +mike zipper 91.57500076293945 +nick allen 29.146666367848713 +nick allen 64.4325008392334 +nick brown 49.429999669392906 +nick davidson 34.36599922180176 +nick ellison 49.345001220703125 +nick ellison 89.01000213623047 +nick falkner 7.820000171661377 +nick falkner 45.025000631809235 +nick garcia 23.9499994913737 +nick garcia 33.49333349863688 +nick garcia 64.33499908447266 +nick ichabod 30.945000171661377 +nick ichabod 56.27499961853027 +nick ichabod 59.24249863624573 +nick johnson 32.3674995303154 +nick johnson 74.30666859944661 +nick laertes 38.38749980926514 +nick miller 82.97000122070312 +nick nixon 70.01333173116048 +nick ovid 56.82999897003174 +nick polk 33.929999669392906 +nick quirinius 58.91199932098389 +nick quirinius 81.16999816894531 +nick robinson 34.53999948501587 +nick robinson 59.80499839782715 +nick steinbeck 57.25333340962728 +nick thompson 51.3725004196167 +nick underhill 55.47333335876465 +nick van buren 74.21666717529297 +nick xylophone 51.82000001271566 +nick young 0.27000001072883606 +nick young 41.02499961853027 +nick zipper 47.72999954223633 +nick zipper 61.917999267578125 +oscar allen 35.21600015163422 +oscar brown 38.03999948501587 +oscar carson 38.09600009918213 +oscar carson 51.34333419799805 +oscar carson 56.9925012588501 +oscar carson 74.00500106811523 +oscar carson 98.51000213623047 +oscar davidson 65.88750076293945 +oscar ellison 50.507500410079956 +oscar ellison 66.57000096638997 +oscar falkner 64.42000198364258 +oscar garcia 66.36333465576172 +oscar hernandez 42.089999516805015 +oscar hernandez 51.21199997067451 +oscar ichabod 25.300000190734863 +oscar ichabod 41.173332850138344 +oscar ichabod 53.44666830698649 +oscar ichabod 71.80000305175781 +oscar johnson 39.9366668065389 +oscar johnson 44.30500118434429 +oscar king 30.59749937057495 +oscar king 46.149999141693115 +oscar king 49.65999984741211 +oscar laertes 5.510000228881836 +oscar laertes 19.099999745686848 +oscar laertes 39.36250039935112 +oscar laertes 45.340000788370766 +oscar nixon 41.619998931884766 +oscar ovid 45.09000015258789 +oscar ovid 82.23999786376953 +oscar ovid 83.47750091552734 +oscar polk 42.54333368937174 +oscar polk 56.8033332824707 +oscar quirinius 52.94500160217285 +oscar quirinius 65.34666570027669 +oscar quirinius 67.00000127156575 +oscar quirinius 79.4800033569336 +oscar robinson 32.16999944051107 +oscar robinson 38.880001068115234 +oscar robinson 63.9433339436849 +oscar robinson 65.80500030517578 +oscar steinbeck 45.73599967956543 +oscar thompson 40.33599853515625 +oscar thompson 47.860000928243004 +oscar thompson 60.529998779296875 +oscar thompson 60.59333292643229 +oscar underhill 43.980000764131546 +oscar van buren 2.180000066757202 +oscar van buren 53.34999942779541 +oscar van buren 61.880001068115234 +oscar white 38.0633331934611 +oscar white 42.584999084472656 +oscar white 55.179999033610024 +oscar white 74.4900016784668 +oscar xylophone 41.3133331934611 +oscar xylophone 55.5049991607666 +oscar xylophone 67.54500007629395 +oscar zipper 15.680000305175781 +oscar zipper 24.019999504089355 +oscar zipper 39.81999969482422 +priscilla brown 42.88666502634684 +priscilla brown 63.242499351501465 +priscilla brown 77.56999969482422 +priscilla carson 44.799999713897705 +priscilla carson 49.663333892822266 +priscilla carson 78.57333374023438 +priscilla ichabod 56.83666737874349 +priscilla ichabod 58.48666636149088 +priscilla johnson 34.416667779286705 +priscilla johnson 52.890000343322754 +priscilla johnson 53.72666883468628 +priscilla johnson 61.939998626708984 +priscilla johnson 89.1500015258789 +priscilla king 34.30750045180321 +priscilla nixon 27.734999656677246 +priscilla nixon 60.90999984741211 +priscilla ovid 48.13500142097473 +priscilla ovid 66.35999870300293 +priscilla polk 15.149999618530273 +priscilla quirinius 18.606666564941406 +priscilla thompson 48.87000131607056 +priscilla underhill 40.05500078201294 +priscilla underhill 49.54999961853027 +priscilla van buren 42.62666734059652 +priscilla van buren 61.54000017642975 +priscilla van buren 72.80500030517578 +priscilla white 43.177499771118164 +priscilla xylophone 21.489999771118164 +priscilla xylophone 40.144999124109745 +priscilla xylophone 59.61000061035156 +priscilla young 31.610000324249267 +priscilla young 53.71000152826309 +priscilla zipper 18.8799991607666 +priscilla zipper 25.670000076293945 +quinn allen 56.77749991416931 +quinn allen 76.47250080108643 +quinn brown 23.536666870117188 +quinn brown 31.829999446868896 +quinn brown 50.388000297546384 +quinn davidson 41.42499923706055 +quinn davidson 45.90250015258789 +quinn davidson 48.16000032424927 +quinn davidson 71.51000022888184 +quinn ellison 50.6766668955485 +quinn ellison 62.32666842142741 +quinn garcia 39.98599967956543 +quinn garcia 43.27000045776367 +quinn garcia 74.31999969482422 +quinn garcia 92.33000183105469 +quinn ichabod 44.81999969482422 +quinn king 50.99666786193848 +quinn king 74.62000274658203 +quinn laertes 5.884999990463257 +quinn laertes 49.1379997253418 +quinn laertes 56.18000030517578 +quinn nixon 74.66500091552734 +quinn ovid 1.2100000381469727 +quinn quirinius 46.15500068664551 +quinn robinson 44.96249985694885 +quinn steinbeck 24.06999921798706 +quinn steinbeck 41.750000953674316 +quinn thompson 23.744999408721924 +quinn thompson 73.64500045776367 +quinn underhill 41.83333269755045 +quinn underhill 49.63749885559082 +quinn underhill 83.85000228881836 +quinn van buren 54.36333338419596 +quinn young 45.489999771118164 +quinn zipper 22.25 +quinn zipper 33.355000019073486 +rachel allen 15.8100004196167 +rachel allen 71.57666524251302 +rachel brown 2.9600000381469727 +rachel brown 33.022499561309814 +rachel brown 42.442498207092285 +rachel brown 59.21999931335449 +rachel brown 64.52666727701823 +rachel carson 61.17999776204427 +rachel carson 69.85333315531413 +rachel davidson 42.292500495910645 +rachel ellison 10.600000381469727 +rachel falkner 23.615000784397125 +rachel falkner 50.46250069141388 +rachel falkner 54.669999837875366 +rachel falkner 72.96999740600586 +rachel johnson 61.76249885559082 +rachel king 22.005000591278076 +rachel king 66.46500015258789 +rachel laertes 39.025001525878906 +rachel laertes 45.45000076293945 +rachel ovid 0.6000000238418579 +rachel ovid 1.0800000429153442 +rachel polk 78.08499908447266 +rachel quirinius 51.12000131607056 +rachel robinson 30.360000610351562 +rachel robinson 37.683333237965904 +rachel robinson 39.199998219807945 +rachel thompson 0.5600000023841858 +rachel thompson 38.81000010172526 +rachel thompson 49.772499561309814 +rachel underhill 48.45000076293945 +rachel white 42.83999991416931 +rachel white 58.609999656677246 +rachel young 43.130001068115234 +rachel zipper 7.059999942779541 +rachel zipper 49.676667173703514 +sarah carson 54.022500067949295 +sarah carson 54.18333212534586 +sarah carson 87.81000137329102 +sarah ellison 16.989999771118164 +sarah falkner 73.06500005722046 +sarah falkner 99.36000061035156 +sarah garcia 38.43250036239624 +sarah garcia 64.97333272298177 +sarah garcia 73.6500015258789 +sarah ichabod 57.46000003814697 +sarah ichabod 81.31999969482422 +sarah johnson 34.98499917984009 +sarah johnson 45.099998474121094 +sarah johnson 74.42249870300293 +sarah johnson 74.72000122070312 +sarah king 41.869998931884766 +sarah king 48.25 +sarah miller 41.53499984741211 +sarah ovid 33.38000011444092 +sarah robinson 33.83000183105469 +sarah robinson 66.88999938964844 +sarah steinbeck 31.023332993189495 +sarah white 45.974998474121094 +sarah white 61.54666709899902 +sarah xylophone 61.60666529337565 +sarah young 45.560001373291016 +sarah zipper 60.69000053405762 +tom brown 40.08000135421753 +tom brown 55.44499969482422 +tom carson 5.440000057220459 +tom carson 26.32499885559082 +tom carson 31.476666768391926 +tom davidson 53.00749921798706 +tom ellison 67.32666714986165 +tom ellison 76.73999786376953 +tom ellison 77.29499816894531 +tom falkner 60.130001068115234 +tom falkner 88.22000122070312 +tom hernandez 41.36000061035156 +tom hernandez 81.63999938964844 +tom ichabod 42.08666547139486 +tom johnson 33.75999959309896 +tom johnson 53.47666549682617 +tom king 40.0 +tom laertes 32.05000114440918 +tom laertes 43.46000099182129 +tom miller 21.229999542236328 +tom miller 43.37666702270508 +tom miller 44.41333325703939 +tom nixon 46.35333251953125 +tom ovid 43.04499912261963 +tom polk 38.29999923706055 +tom polk 54.45000076293945 +tom quirinius 19.82000058889389 +tom quirinius 22.580000400543213 +tom robinson 66.16999816894531 +tom robinson 74.18666712443034 +tom robinson 80.77000045776367 +tom robinson 98.72000122070312 +tom steinbeck 44.999999046325684 +tom van buren 31.103334546089172 +tom van buren 40.779998779296875 +tom van buren 63.5099983215332 +tom white 49.06500053405762 +tom young 54.16999912261963 +tom young 78.54999923706055 +tom zipper 48.666666984558105 +ulysses brown 72.79000091552734 +ulysses carson 40.28500175476074 +ulysses carson 71.55000305175781 +ulysses carson 77.41999816894531 +ulysses carson 79.54000091552734 +ulysses davidson 41.90166711807251 +ulysses ellison 96.7300033569336 +ulysses garcia 89.80000305175781 +ulysses hernandez 21.339999516805012 +ulysses hernandez 54.470001220703125 +ulysses hernandez 80.11000061035156 +ulysses ichabod 19.1299991607666 +ulysses ichabod 98.56999969482422 +ulysses johnson 42.55000019073486 +ulysses king 81.57000223795573 +ulysses laertes 1.9199999570846558 +ulysses laertes 11.890000343322754 +ulysses laertes 53.599998474121094 +ulysses miller 40.43499946594238 +ulysses miller 47.78200054168701 +ulysses nixon 74.60333251953125 +ulysses ovid 30.940000534057617 +ulysses polk 8.710000038146973 +ulysses polk 39.17500019073486 +ulysses polk 65.0199966430664 +ulysses polk 68.41333389282227 +ulysses quirinius 65.93499946594238 +ulysses robinson 79.48999786376953 +ulysses steinbeck 32.40999984741211 +ulysses steinbeck 43.93499994277954 +ulysses thompson 79.96000289916992 +ulysses underhill 17.85000006357829 +ulysses underhill 42.5533332824707 +ulysses underhill 46.60000157356262 +ulysses underhill 47.84000049829483 +ulysses underhill 51.85000133514404 +ulysses underhill 67.77999877929688 +ulysses underhill 99.66999816894531 +ulysses van buren 69.89999961853027 +ulysses white 45.01500183343887 +ulysses white 71.3933334350586 +ulysses xylophone 27.880000591278076 +ulysses xylophone 39.689998626708984 +ulysses xylophone 54.099998474121094 +ulysses young 32.52000045776367 +ulysses young 39.81333382924398 +ulysses young 80.7933349609375 +victor allen 44.27000045776367 +victor allen 57.994998931884766 +victor brown 56.84499931335449 +victor brown 59.34000015258789 +victor brown 90.37999725341797 +victor brown 91.97000122070312 +victor davidson 52.07000160217285 +victor davidson 54.239999771118164 +victor davidson 68.8033332824707 +victor ellison 45.45750088989735 +victor ellison 58.42999839782715 +victor hernandez 23.164999961853027 +victor hernandez 57.98499870300293 +victor hernandez 59.619998931884766 +victor hernandez 71.42499923706055 +victor hernandez 80.19499969482422 +victor johnson 1.5800000429153442 +victor johnson 46.7450008392334 +victor johnson 72.59000015258789 +victor king 37.559998989105225 +victor king 47.880001068115234 +victor laertes 50.0099983215332 +victor laertes 70.13999938964844 +victor miller 22.1200008392334 +victor nixon 34.029998779296875 +victor nixon 37.08500003814697 +victor ovid 18.815000653266907 +victor polk 3.0 +victor quirinius 26.450000405311584 +victor quirinius 33.080001533031464 +victor robinson 4.590000152587891 +victor robinson 41.21999931335449 +victor steinbeck 41.82500123977661 +victor steinbeck 46.09000015258789 +victor steinbeck 52.720001220703125 +victor thompson 45.346666971842446 +victor van buren 34.970001220703125 +victor van buren 46.57333246866862 +victor white 5.670000076293945 +victor white 74.16999816894531 +victor xylophone 11.220000267028809 +victor xylophone 28.954999923706055 +victor xylophone 34.010000228881836 +victor xylophone 43.179999351501465 +victor xylophone 62.38999938964844 +victor young 70.91000175476074 +victor zipper 48.795000076293945 +wendy allen 0.6100000143051147 +wendy allen 52.64999961853027 +wendy allen 56.06999969482422 +wendy brown 51.874999046325684 +wendy brown 66.73999913533528 +wendy ellison 27.014999389648438 +wendy ellison 94.66000366210938 +wendy falkner 14.425000190734863 +wendy falkner 77.36000061035156 +wendy falkner 85.68000030517578 +wendy garcia 4.409999847412109 +wendy garcia 20.390000343322754 +wendy garcia 38.42500060796738 +wendy garcia 57.25 +wendy hernandez 48.11000061035156 +wendy ichabod 13.149999618530273 +wendy king 33.234999656677246 +wendy king 74.97999954223633 +wendy king 87.94000244140625 +wendy laertes 49.01999855041504 +wendy laertes 54.750000635782875 +wendy laertes 79.98999786376953 +wendy miller 1.2699999809265137 +wendy miller 47.41500186920166 +wendy nixon 30.17999916151166 +wendy nixon 45.91999816894531 +wendy ovid 85.26000213623047 +wendy ovid 86.62999725341797 +wendy polk 43.679999669392906 +wendy polk 54.65333424011866 +wendy quirinius 12.15999984741211 +wendy quirinius 14.300000190734863 +wendy robinson 58.815001487731934 +wendy robinson 71.06999969482422 +wendy robinson 71.77999877929688 +wendy steinbeck 77.31500053405762 +wendy thompson 58.09000015258789 +wendy thompson 85.76000213623047 +wendy underhill 41.23333231608073 +wendy underhill 49.33500003814697 +wendy underhill 81.56000137329102 +wendy van buren 57.459999084472656 +wendy van buren 91.43500137329102 +wendy white 73.68000030517578 +wendy xylophone 31.149999618530273 +wendy xylophone 76.81499862670898 +wendy young 8.449999809265137 +wendy young 33.7599983215332 +xavier allen 45.68000030517578 +xavier allen 52.255000591278076 +xavier allen 83.93000030517578 +xavier brown 7.789999961853027 +xavier brown 77.15500259399414 +xavier brown 80.35000228881836 +xavier carson 20.790000915527344 +xavier carson 55.30000019073486 +xavier davidson 15.920000076293945 +xavier davidson 58.46999931335449 +xavier davidson 82.41000366210938 +xavier ellison 16.614999771118164 +xavier ellison 77.97000122070312 +xavier garcia 49.415000915527344 +xavier hernandez 6.670000076293945 +xavier hernandez 23.054999828338623 +xavier hernandez 67.26000213623047 +xavier ichabod 4.71999979019165 +xavier ichabod 71.19000244140625 +xavier johnson 27.299999237060547 +xavier johnson 89.0999984741211 +xavier king 22.729999542236328 +xavier king 87.22000122070312 +xavier laertes 24.050000190734863 +xavier ovid 58.08000183105469 +xavier polk 13.869999885559082 +xavier polk 58.98750066757202 +xavier polk 72.62000274658203 +xavier polk 76.93000030517578 +xavier quirinius 22.1200008392334 +xavier quirinius 58.24666786193848 +xavier quirinius 62.52000045776367 +xavier quirinius 89.55000305175781 +xavier thompson 9.930000305175781 +xavier underhill 47.27000045776367 +xavier white 8.369999885559082 +xavier white 75.29000091552734 +xavier xylophone 79.41999816894531 +xavier zipper 8.449999809265137 +yuri allen 30.6299991607666 +yuri allen 52.849998474121094 +yuri brown 75.19000244140625 +yuri brown 84.02999877929688 +yuri carson 6.289999961853027 +yuri carson 91.16000366210938 +yuri ellison 1.1200000047683716 +yuri ellison 86.91999816894531 +yuri falkner 6.739999771118164 +yuri falkner 80.8650016784668 +yuri garcia 27.65999984741211 +yuri hernandez 2.069999933242798 +yuri johnson 0.12999999523162842 +yuri johnson 39.900001525878906 +yuri johnson 48.220001220703125 +yuri king 22.270000457763672 +yuri laertes 10.15999984741211 +yuri laertes 37.59000015258789 +yuri nixon 2.200000047683716 +yuri nixon 17.3700008392334 +yuri polk 26.760000228881836 +yuri polk 28.790000915527344 +yuri polk 82.33999633789062 +yuri quirinius 10.260000228881836 +yuri quirinius 38.69000053405762 +yuri quirinius 57.93000030517578 +yuri steinbeck 56.064998626708984 +yuri steinbeck 75.87999725341797 +yuri thompson 14.920000076293945 +yuri underhill 23.770000457763672 +yuri underhill 83.87000274658203 +yuri white 34.58000183105469 +yuri xylophone 20.3799991607666 +zach allen 65.43000030517578 +zach brown 48.0099983215332 +zach brown 57.08000183105469 +zach brown 58.24999809265137 +zach brown 67.37999725341797 +zach brown 75.7300033569336 +zach carson 67.78500175476074 +zach ellison 6.840000152587891 +zach falkner 50.274999141693115 +zach falkner 91.41999816894531 +zach garcia 32.20000076293945 +zach garcia 35.79999923706055 +zach garcia 69.97000122070312 +zach garcia 84.37999725341797 +zach ichabod 36.88999938964844 +zach ichabod 64.25 +zach king 31.864999771118164 +zach king 46.18000030517578 +zach king 86.93000030517578 +zach miller 2.5999999046325684 +zach miller 21.280000686645508 +zach miller 53.27000045776367 +zach ovid 0.10000000149011612 +zach ovid 23.06999969482422 +zach ovid 92.55000305175781 +zach ovid 94.33999633789062 +zach quirinius 39.209999084472656 +zach robinson 76.72000122070312 +zach steinbeck 85.48999786376953 +zach steinbeck 90.05000305175781 +zach thompson 53.59000015258789 +zach thompson 71.5 +zach underhill 86.22000122070312 +zach white 70.52999877929688 +zach xylophone 29.40999984741211 +zach xylophone 71.01000213623047 +zach young 71.31999969482422 +zach zipper 7.539999961853027 +zach zipper 85.87000274658203 +zach zipper 94.43000030517578 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-4-cd2e3d2344810cb3ba843d4c01c81d7e b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-4-cd2e3d2344810cb3ba843d4c01c81d7e new file mode 100644 index 000000000000..ee1c26e331a1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-4-cd2e3d2344810cb3ba843d4c01c81d7e @@ -0,0 +1,1049 @@ + 17.601666666666667 + 30.72 + 33.07833333333334 +alice allen 23.081666666666667 +alice allen 23.263333333333332 +alice allen 31.38 +alice brown 11.518333333333333 +alice carson 31.99166666666667 +alice davidson 26.986666666666665 +alice falkner 30.513333333333335 +alice garcia 17.13 +alice hernandez 28.126666666666665 +alice hernandez 28.69666666666667 +alice johnson 28.30333333333333 +alice king 23.451666666666668 +alice king 25.20333333333333 +alice king 30.613999999999997 +alice laertes 23.633333333333336 +alice laertes 30.16428571428571 +alice miller 32.50833333333333 +alice nixon 25.278333333333336 +alice nixon 35.345 +alice nixon 36.458333333333336 +alice ovid 26.04714285714286 +alice polk 23.185 +alice quirinius 20.388333333333335 +alice quirinius 23.064999999999998 +alice robinson 30.296666666666667 +alice robinson 33.15 +alice steinbeck 27.894999999999996 +alice steinbeck 31.47833333333334 +alice steinbeck 36.089999999999996 +alice underhill 22.513333333333335 +alice van buren 36.32000000000001 +alice xylophone 27.355000000000004 +alice xylophone 30.505714285714284 +alice xylophone 30.613999999999997 +alice zipper 20.113333333333333 +alice zipper 28.058333333333337 +alice zipper 29.705000000000002 +bob brown 13.546666666666669 +bob brown 16.50333333333333 +bob brown 27.59 +bob carson 25.781666666666666 +bob davidson 18.073333333333334 +bob davidson 23.573333333333334 +bob davidson 31.894999999999996 +bob ellison 16.493333333333336 +bob ellison 17.889999999999997 +bob ellison 25.84142857142857 +bob ellison 33.07833333333334 +bob falkner 21.783333333333335 +bob garcia 16.492857142857144 +bob garcia 23.088333333333335 +bob garcia 26.42833333333333 +bob garcia 32.709999999999994 +bob garcia 33.91833333333334 +bob hernandez 22.303333333333338 +bob ichabod 23.336666666666662 +bob king 20.591666666666665 +bob king 21.244999999999997 +bob king 33.15333333333333 +bob laertes 21.240000000000002 +bob laertes 28.88 +bob miller 32.158750000000005 +bob ovid 21.83666666666667 +bob ovid 23.678571428571427 +bob ovid 25.12333333333333 +bob ovid 31.46 +bob polk 16.3475 +bob quirinius 28.465714285714284 +bob steinbeck 26.99 +bob van buren 26.127142857142854 +bob white 23.021666666666672 +bob white 23.582857142857144 +bob xylophone 26.18166666666667 +bob xylophone 27.995 +bob young 25.495714285714286 +bob zipper 26.435000000000002 +bob zipper 28.384285714285713 +bob zipper 30.65833333333333 +calvin allen 20.812 +calvin brown 20.808 +calvin brown 28.26 +calvin brown 28.37875 +calvin carson 20.38714285714286 +calvin davidson 22.65142857142857 +calvin davidson 23.585714285714285 +calvin ellison 24.6875 +calvin falkner 16.073333333333334 +calvin falkner 17.054285714285715 +calvin falkner 25.2025 +calvin falkner 28.525714285714283 +calvin falkner 33.382 +calvin falkner 37.29666666666667 +calvin garcia 19.924285714285713 +calvin hernandez 21.759999999999998 +calvin johnson 24.03222222222222 +calvin laertes 24.702857142857145 +calvin laertes 28.788333333333338 +calvin nixon 15.631250000000001 +calvin nixon 25.64428571428571 +calvin nixon 29.084999999999997 +calvin ovid 16.243333333333332 +calvin ovid 25.2025 +calvin ovid 25.935555555555553 +calvin ovid 29.299999999999997 +calvin polk 29.615 +calvin quirinius 19.294999999999998 +calvin quirinius 31.64625 +calvin robinson 27.811428571428575 +calvin steinbeck 17.14 +calvin steinbeck 17.535 +calvin steinbeck 21.551666666666666 +calvin thompson 28.592857142857145 +calvin thompson 33.382 +calvin underhill 20.113333333333333 +calvin van buren 28.384285714285713 +calvin van buren 32.106 +calvin white 26.948333333333334 +calvin white 28.256249999999998 +calvin xylophone 25.33666666666667 +calvin xylophone 27.061428571428568 +calvin xylophone 28.384285714285713 +calvin young 19.331666666666667 +calvin young 24.63 +calvin zipper 16.08125 +calvin zipper 28.80555555555556 +david allen 26.21375 +david allen 26.948333333333334 +david brown 14.222 +david brown 27.1 +david davidson 21.062857142857144 +david davidson 26.12 +david davidson 30.951428571428576 +david davidson 31.96142857142857 +david ellison 22.687142857142856 +david ellison 23.435 +david ellison 25.442999999999998 +david hernandez 28.279999999999998 +david ichabod 16.3475 +david ichabod 23.751428571428573 +david laertes 29.07 +david nixon 25.535714285714285 +david ovid 26.75714285714286 +david ovid 29.912857142857142 +david quirinius 17.179999999999996 +david quirinius 20.808 +david quirinius 22.90625 +david robinson 25.274 +david robinson 25.885 +david thompson 22.19125 +david underhill 20.812 +david underhill 21.546666666666667 +david underhill 28.26 +david van buren 24.472857142857148 +david van buren 32.75125 +david white 22.861428571428576 +david xylophone 19.3325 +david xylophone 26.930000000000003 +david xylophone 30.093333333333334 +david young 21.12375 +david young 25.03857142857143 +ethan allen 27.038333333333338 +ethan brown 16.3475 +ethan brown 19.37 +ethan brown 23.73 +ethan brown 25.57166666666667 +ethan brown 26.168333333333333 +ethan brown 31.893333333333334 +ethan carson 25.655714285714286 +ethan ellison 27.807777777777776 +ethan ellison 32.244285714285716 +ethan falkner 28.287142857142857 +ethan falkner 31.411428571428576 +ethan garcia 24.153750000000002 +ethan hernandez 17.986 +ethan johnson 31.54571428571429 +ethan king 22.62 +ethan laertes 15.045555555555556 +ethan laertes 17.889999999999997 +ethan laertes 24.28 +ethan laertes 25.2025 +ethan laertes 26.56888888888889 +ethan laertes 28.37875 +ethan laertes 34.84571428571429 +ethan miller 28.592857142857145 +ethan nixon 29.86833333333333 +ethan ovid 17.85166666666667 +ethan polk 16.463 +ethan polk 18.184545454545457 +ethan polk 25.737142857142857 +ethan polk 29.895000000000003 +ethan quirinius 21.827777777777776 +ethan quirinius 23.42 +ethan quirinius 35.54333333333333 +ethan robinson 24.03222222222222 +ethan robinson 36.35333333333333 +ethan underhill 21.545555555555556 +ethan van buren 15.21142857142857 +ethan white 29.702857142857145 +ethan white 33.33 +ethan xylophone 29.521666666666665 +ethan zipper 23.994 +ethan zipper 26.765 +fred davidson 27.729999999999997 +fred davidson 29.90625 +fred davidson 30.804999999999996 +fred ellison 16.720000000000002 +fred ellison 22.17125 +fred ellison 33.15833333333334 +fred falkner 17.13142857142857 +fred falkner 26.82 +fred falkner 31.925555555555555 +fred hernandez 28.094285714285718 +fred ichabod 23.352500000000003 +fred ichabod 32.906666666666666 +fred johnson 16.9925 +fred king 20.125 +fred king 30.377142857142854 +fred laertes 26.458571428571428 +fred miller 29.66666666666667 +fred nixon 19.565555555555555 +fred nixon 21.830000000000002 +fred nixon 25.828333333333333 +fred nixon 28.094285714285718 +fred polk 21.744999999999997 +fred polk 22.301666666666666 +fred polk 24.156666666666666 +fred polk 28.217142857142857 +fred quirinius 27.495 +fred quirinius 29.615714285714283 +fred robinson 24.243636363636366 +fred steinbeck 17.91333333333333 +fred steinbeck 21.12375 +fred steinbeck 26.47125 +fred underhill 26.43166666666667 +fred van buren 16.165714285714284 +fred van buren 23.285 +fred van buren 24.875714285714288 +fred van buren 27.878333333333334 +fred white 29.498571428571427 +fred young 17.889999999999997 +fred young 19.565555555555555 +fred zipper 21.581428571428575 +gabriella allen 22.03142857142857 +gabriella allen 26.87375 +gabriella brown 29.675714285714285 +gabriella brown 32.44 +gabriella carson 28.75 +gabriella davidson 27.531250000000004 +gabriella ellison 19.82 +gabriella ellison 27.353333333333335 +gabriella falkner 17.535 +gabriella falkner 19.487500000000004 +gabriella falkner 30.964999999999996 +gabriella garcia 20.544 +gabriella hernandez 20.818333333333335 +gabriella hernandez 28.094285714285718 +gabriella ichabod 10.58 +gabriella ichabod 18.64272727272727 +gabriella ichabod 20.504444444444445 +gabriella ichabod 23.185 +gabriella ichabod 23.35375 +gabriella king 16.18 +gabriella king 27.887500000000003 +gabriella laertes 23.799090909090907 +gabriella miller 15.695714285714283 +gabriella ovid 23.515454545454546 +gabriella ovid 33.33 +gabriella polk 20.38714285714286 +gabriella polk 35.77285714285715 +gabriella steinbeck 16.18 +gabriella steinbeck 32.464999999999996 +gabriella thompson 26.477777777777778 +gabriella thompson 27.29777777777778 +gabriella thompson 30.87666666666667 +gabriella van buren 28.513333333333335 +gabriella van buren 32.41111111111111 +gabriella white 26.765 +gabriella young 24.022499999999997 +gabriella young 29.521666666666665 +gabriella zipper 21.43727272727273 +gabriella zipper 32.106 +holly allen 24.271428571428572 +holly brown 22.959 +holly brown 27.498 +holly falkner 29.66666666666667 +holly hernandez 14.179999999999998 +holly hernandez 22.396666666666665 +holly hernandez 27.434000000000005 +holly hernandez 27.887500000000003 +holly ichabod 27.87375 +holly ichabod 32.525 +holly ichabod 34.042857142857144 +holly johnson 20.808 +holly johnson 25.024285714285718 +holly johnson 30.487142857142857 +holly king 23.185 +holly king 29.008888888888894 +holly laertes 19.41 +holly miller 29.89333333333333 +holly nixon 23.162857142857142 +holly nixon 28.876250000000002 +holly polk 22.7825 +holly polk 26.297499999999996 +holly robinson 24.160000000000004 +holly thompson 19.565555555555555 +holly thompson 27.048999999999996 +holly thompson 29.70555555555556 +holly underhill 17.876250000000002 +holly underhill 27.820000000000004 +holly underhill 30.613999999999997 +holly underhill 30.708 +holly van buren 20.113333333333333 +holly white 23.185 +holly white 29.64272727272727 +holly xylophone 26.400909090909092 +holly young 27.807777777777776 +holly young 31.63 +holly zipper 27.401999999999997 +holly zipper 28.384285714285713 +irene allen 35.345 +irene brown 22.527500000000003 +irene brown 28.384285714285713 +irene brown 32.81875 +irene carson 27.10666666666667 +irene ellison 16.720000000000002 +irene ellison 28.592857142857145 +irene falkner 19.41 +irene falkner 30.564999999999998 +irene garcia 16.9925 +irene garcia 24.03222222222222 +irene garcia 26.297499999999996 +irene ichabod 27.540000000000003 +irene ichabod 29.34875 +irene johnson 25.418181818181814 +irene laertes 22.124285714285712 +irene laertes 22.200000000000003 +irene laertes 24.446666666666665 +irene miller 30.166666666666668 +irene nixon 18.922222222222224 +irene nixon 25.2025 +irene nixon 33.382 +irene ovid 28.256249999999998 +irene ovid 31.63 +irene ovid 32.088750000000005 +irene polk 24.63 +irene polk 25.296363636363637 +irene polk 27.137142857142855 +irene polk 30.65222222222222 +irene polk 33.760000000000005 +irene quirinius 27.044999999999998 +irene quirinius 33.07833333333334 +irene quirinius 41.865 +irene robinson 32.18875 +irene steinbeck 16.463 +irene thompson 25.281666666666666 +irene underhill 24.4025 +irene underhill 28.531 +irene van buren 25.532222222222224 +irene van buren 32.50142857142857 +irene xylophone 26.288181818181815 +jessica brown 28.968181818181822 +jessica carson 19.41 +jessica carson 24.854285714285716 +jessica carson 25.406363636363633 +jessica davidson 22.19625 +jessica davidson 23.888 +jessica davidson 26.297499999999996 +jessica davidson 26.825454545454537 +jessica ellison 22.07777777777778 +jessica ellison 33.33 +jessica falkner 22.637272727272727 +jessica garcia 14.749999999999996 +jessica garcia 29.675714285714285 +jessica ichabod 31.831249999999997 +jessica johnson 21.546666666666667 +jessica johnson 29.986363636363638 +jessica miller 28.735000000000003 +jessica nixon 19.13111111111111 +jessica nixon 26.244999999999997 +jessica ovid 25.274 +jessica ovid 33.181666666666665 +jessica polk 26.79222222222222 +jessica quirinius 20.38714285714286 +jessica quirinius 25.776000000000003 +jessica quirinius 28.26 +jessica quirinius 29.605000000000008 +jessica robinson 24.5625 +jessica thompson 25.736 +jessica thompson 30.87363636363636 +jessica underhill 16.400000000000002 +jessica underhill 25.529090909090915 +jessica underhill 31.63 +jessica van buren 24.446666666666665 +jessica white 20.812 +jessica white 23.26 +jessica white 27.807777777777776 +jessica white 29.031000000000006 +jessica white 30.654545454545453 +jessica xylophone 15.296666666666667 +jessica young 26.718333333333334 +jessica young 27.853749999999998 +jessica zipper 20.3575 +jessica zipper 24.446666666666665 +jessica zipper 29.276363636363637 +katie allen 29.029090909090915 +katie brown 24.156666666666666 +katie davidson 15.383749999999997 +katie ellison 20.978333333333335 +katie ellison 26.96 +katie falkner 24.5625 +katie garcia 27.807777777777776 +katie garcia 28.287142857142857 +katie hernandez 23.667272727272724 +katie ichabod 14.222 +katie ichabod 20.553749999999997 +katie ichabod 31.831249999999997 +katie king 20.05444444444445 +katie king 20.242222222222225 +katie king 23.342857142857145 +katie miller 26.21857142857143 +katie miller 27.675000000000004 +katie nixon 14.476999999999999 +katie ovid 28.37875 +katie polk 20.99 +katie polk 25.090000000000003 +katie robinson 30.65222222222222 +katie van buren 27.133636363636366 +katie van buren 29.675714285714285 +katie white 22.555714285714288 +katie white 24.463749999999997 +katie xylophone 25.74142857142857 +katie young 23.011250000000004 +katie young 26.650000000000002 +katie young 29.301428571428573 +katie zipper 28.26 +katie zipper 29.675714285714285 +luke allen 16.9925 +luke allen 20.595000000000002 +luke allen 27.54181818181818 +luke allen 27.887500000000003 +luke allen 33.07833333333334 +luke brown 29.34875 +luke davidson 27.3575 +luke davidson 31.473333333333333 +luke ellison 14.024444444444443 +luke ellison 22.555714285714288 +luke ellison 28.592857142857145 +luke falkner 21.855 +luke falkner 27.044999999999998 +luke garcia 27.887500000000003 +luke garcia 31.237000000000002 +luke ichabod 28.198571428571427 +luke ichabod 34.345000000000006 +luke johnson 21.239999999999995 +luke johnson 25.462727272727275 +luke johnson 30.188888888888894 +luke laertes 18.344 +luke laertes 20.817 +luke laertes 27.401999999999997 +luke laertes 30.72285714285714 +luke laertes 41.865 +luke miller 22.539000000000005 +luke ovid 16.615454545454543 +luke ovid 26.06625 +luke polk 28.163000000000004 +luke polk 28.840909090909097 +luke quirinius 27.077142857142857 +luke robinson 24.816363636363644 +luke robinson 27.110909090909093 +luke thompson 28.44454545454545 +luke underhill 22.175 +luke underhill 25.518888888888892 +luke underhill 27.34125 +luke van buren 16.54 +luke white 20.544 +luke xylophone 24.5625 +luke zipper 24.764285714285712 +mike allen 23.860000000000003 +mike brown 31.016363636363643 +mike carson 26.066363636363636 +mike carson 28.947142857142858 +mike carson 33.382 +mike davidson 20.544 +mike davidson 21.239999999999995 +mike ellison 18.922 +mike ellison 21.4175 +mike ellison 25.45272727272727 +mike ellison 27.26 +mike ellison 28.39888888888889 +mike falkner 29.397777777777776 +mike garcia 20.544 +mike garcia 24.582 +mike garcia 34.84571428571429 +mike hernandez 10.4925 +mike hernandez 17.7 +mike ichabod 26.772727272727273 +mike king 17.889999999999997 +mike king 19.294999999999998 +mike king 20.004285714285714 +mike king 23.197999999999997 +mike king 23.285 +mike king 27.401999999999997 +mike miller 31.587272727272726 +mike nixon 17.775555555555556 +mike nixon 27.044999999999998 +mike polk 22.175 +mike polk 23.751428571428573 +mike polk 28.095 +mike quirinius 19.13111111111111 +mike steinbeck 14.222 +mike steinbeck 18.100909090909088 +mike steinbeck 18.344 +mike steinbeck 33.760000000000005 +mike van buren 23.42 +mike van buren 25.828333333333333 +mike white 19.13111111111111 +mike white 25.755714285714284 +mike white 29.031000000000006 +mike white 30.516999999999996 +mike young 26.765 +mike young 27.766 +mike young 28.409090909090914 +mike zipper 17.306 +mike zipper 33.23 +mike zipper 41.865 +nick allen 19.331666666666667 +nick allen 32.106 +nick brown 27.578181818181818 +nick davidson 29.100000000000005 +nick ellison 24.764285714285712 +nick ellison 29.521666666666665 +nick falkner 22.555714285714288 +nick falkner 23.15888888888888 +nick garcia 21.546666666666667 +nick garcia 26.25090909090909 +nick garcia 30.166666666666668 +nick ichabod 21.855 +nick ichabod 23.479000000000003 +nick ichabod 29.100000000000005 +nick johnson 25.274 +nick johnson 29.994 +nick laertes 25.820909090909094 +nick miller 19.87888888888889 +nick nixon 17.082 +nick ovid 33.597777777777786 +nick polk 25.736 +nick quirinius 20.707500000000003 +nick quirinius 28.094285714285718 +nick robinson 22.396666666666665 +nick robinson 25.298749999999995 +nick steinbeck 21.192857142857143 +nick thompson 30.72285714285714 +nick underhill 29.345 +nick van buren 25.152727272727272 +nick xylophone 26.948333333333334 +nick young 23.751428571428573 +nick young 24.810000000000002 +nick zipper 24.854285714285716 +nick zipper 27.353333333333335 +oscar allen 18.815 +oscar brown 26.948333333333334 +oscar carson 24.764285714285712 +oscar carson 27.766 +oscar carson 28.094285714285718 +oscar carson 28.31555555555556 +oscar carson 35.22818181818182 +oscar davidson 17.535 +oscar ellison 22.121428571428574 +oscar ellison 28.735000000000003 +oscar falkner 19.294999999999998 +oscar garcia 20.62636363636364 +oscar hernandez 22.539000000000005 +oscar hernandez 23.31888888888889 +oscar ichabod 20.818333333333335 +oscar ichabod 21.546666666666667 +oscar ichabod 26.914545454545454 +oscar ichabod 28.811111111111106 +oscar johnson 22.381818181818183 +oscar johnson 24.266363636363643 +oscar king 15.296666666666667 +oscar king 25.580000000000002 +oscar king 28.37875 +oscar laertes 21.51818181818182 +oscar laertes 23.285 +oscar laertes 24.4025 +oscar laertes 25.345454545454547 +oscar nixon 18.88111111111111 +oscar ovid 24.854285714285716 +oscar ovid 25.274 +oscar ovid 33.29636363636364 +oscar polk 19.331666666666667 +oscar polk 29.34875 +oscar quirinius 22.928 +oscar quirinius 25.66727272727273 +oscar quirinius 25.970909090909092 +oscar quirinius 29.66666666666667 +oscar robinson 20.90666666666667 +oscar robinson 21.855 +oscar robinson 23.42 +oscar robinson 32.90500000000001 +oscar steinbeck 32.02818181818182 +oscar thompson 20.817 +oscar thompson 21.477000000000004 +oscar thompson 21.843636363636367 +oscar thompson 23.559000000000005 +oscar underhill 22.555714285714288 +oscar van buren 27.210000000000008 +oscar van buren 28.592857142857145 +oscar van buren 31.375454545454545 +oscar white 20.818333333333335 +oscar white 21.748 +oscar white 24.582 +oscar white 28.287142857142857 +oscar xylophone 25.845 +oscar xylophone 28.735000000000003 +oscar xylophone 30.72285714285714 +oscar zipper 24.511111111111113 +oscar zipper 25.067777777777778 +oscar zipper 26.21857142857143 +priscilla brown 14.222 +priscilla brown 27.044999999999998 +priscilla brown 30.19909090909091 +priscilla carson 18.07 +priscilla carson 20.70875 +priscilla carson 26.297499999999996 +priscilla ichabod 29.451111111111118 +priscilla ichabod 29.878888888888884 +priscilla johnson 16.9925 +priscilla johnson 22.050000000000004 +priscilla johnson 24.093000000000004 +priscilla johnson 29.200000000000003 +priscilla johnson 29.246 +priscilla king 15.536666666666669 +priscilla nixon 18.9 +priscilla nixon 30.516999999999996 +priscilla ovid 16.005000000000003 +priscilla ovid 29.88111111111111 +priscilla polk 28.018888888888892 +priscilla quirinius 23.064999999999998 +priscilla thompson 27.077142857142857 +priscilla underhill 28.160999999999994 +priscilla underhill 28.56571428571429 +priscilla van buren 20.7 +priscilla van buren 21.830000000000002 +priscilla van buren 24.665 +priscilla white 29.23375 +priscilla xylophone 19.331666666666667 +priscilla xylophone 23.185 +priscilla xylophone 27.34125 +priscilla young 26.32777777777778 +priscilla young 30.613999999999997 +priscilla zipper 13.498 +priscilla zipper 31.972727272727266 +quinn allen 28.786666666666672 +quinn allen 29.471818181818175 +quinn brown 26.314285714285713 +quinn brown 27.38 +quinn brown 28.39888888888889 +quinn davidson 20.808 +quinn davidson 22.71285714285714 +quinn davidson 23.306250000000002 +quinn davidson 27.34125 +quinn ellison 25.002857142857142 +quinn ellison 33.760000000000005 +quinn garcia 20.544 +quinn garcia 27.055999999999994 +quinn garcia 29.183333333333334 +quinn garcia 31.831249999999997 +quinn ichabod 22.101818181818185 +quinn king 17.535 +quinn king 18.035714285714285 +quinn laertes 24.511111111111113 +quinn laertes 28.876250000000002 +quinn laertes 29.202857142857145 +quinn nixon 21.75111111111111 +quinn ovid 29.64125 +quinn quirinius 18.922 +quinn robinson 23.985 +quinn steinbeck 27.077142857142857 +quinn steinbeck 28.160999999999994 +quinn thompson 21.99888888888889 +quinn thompson 34.46857142857143 +quinn underhill 20.113333333333333 +quinn underhill 23.107272727272726 +quinn underhill 26.224285714285717 +quinn van buren 23.612222222222222 +quinn young 24.5625 +quinn zipper 17.889999999999997 +quinn zipper 20.163333333333338 +rachel allen 28.446666666666665 +rachel allen 41.865 +rachel brown 20.92875 +rachel brown 26.21857142857143 +rachel brown 27.905 +rachel brown 30.166666666666668 +rachel brown 35.345 +rachel carson 28.735000000000003 +rachel carson 31.715999999999998 +rachel davidson 29.100000000000005 +rachel ellison 27.055454545454552 +rachel falkner 14.812499999999998 +rachel falkner 28.876250000000002 +rachel falkner 29.308888888888887 +rachel falkner 31.831249999999997 +rachel johnson 31.63 +rachel king 24.511111111111113 +rachel king 30.873749999999998 +rachel laertes 17.306 +rachel laertes 26.765 +rachel ovid 24.042727272727276 +rachel ovid 28.01181818181818 +rachel polk 21.12375 +rachel quirinius 29.831249999999997 +rachel robinson 10.4925 +rachel robinson 22.264444444444447 +rachel robinson 33.43125 +rachel thompson 16.720000000000002 +rachel thompson 26.905714285714286 +rachel thompson 28.876250000000002 +rachel underhill 21.75111111111111 +rachel white 22.175 +rachel white 31.580000000000002 +rachel young 28.150000000000002 +rachel zipper 22.187142857142856 +rachel zipper 33.760000000000005 +sarah carson 21.86818181818182 +sarah carson 22.175 +sarah carson 33.43125 +sarah ellison 17.535 +sarah falkner 29.34875 +sarah falkner 29.64125 +sarah garcia 10.4925 +sarah garcia 20.812 +sarah garcia 28.2175 +sarah ichabod 26.948333333333334 +sarah ichabod 33.62375 +sarah johnson 21.546666666666667 +sarah johnson 24.978000000000005 +sarah johnson 29.608000000000004 +sarah johnson 33.760000000000005 +sarah king 19.41 +sarah king 27.055999999999994 +sarah miller 24.815454545454543 +sarah ovid 28.31625 +sarah robinson 13.498 +sarah robinson 28.256249999999998 +sarah steinbeck 23.26 +sarah white 21.75111111111111 +sarah white 31.63 +sarah xylophone 21.964545454545455 +sarah young 29.335555555555555 +sarah zipper 29.521666666666665 +tom brown 22.873333333333335 +tom brown 30.415555555555557 +tom carson 22.4025 +tom carson 28.39888888888889 +tom carson 29.64125 +tom davidson 30.61142857142857 +tom ellison 23.568 +tom ellison 27.884999999999998 +tom ellison 32.02625 +tom falkner 13.72 +tom falkner 19.849999999999998 +tom hernandez 16.720000000000002 +tom hernandez 29.974285714285713 +tom ichabod 20.113333333333333 +tom johnson 27.077142857142857 +tom johnson 32.90500000000001 +tom king 21.855 +tom laertes 17.981666666666666 +tom laertes 21.80857142857143 +tom miller 18.922 +tom miller 21.239999999999995 +tom miller 22.396666666666665 +tom nixon 27.005000000000003 +tom ovid 34.84571428571429 +tom polk 29.521666666666665 +tom polk 29.805 +tom quirinius 24.764285714285712 +tom quirinius 36.46857142857143 +tom robinson 16.18 +tom robinson 18.07 +tom robinson 27.34125 +tom robinson 34.958571428571425 +tom steinbeck 30.613999999999997 +tom van buren 22.6475 +tom van buren 23.13555555555556 +tom van buren 24.4025 +tom white 27.715714285714284 +tom young 19.41 +tom young 24.63 +tom zipper 22.902 +ulysses brown 16.223333333333333 +ulysses carson 16.3475 +ulysses carson 22.057500000000005 +ulysses carson 28.256249999999998 +ulysses carson 28.27285714285714 +ulysses davidson 24.701249999999998 +ulysses ellison 29.52333333333333 +ulysses garcia 33.382 +ulysses hernandez 18.421818181818185 +ulysses hernandez 20.443749999999998 +ulysses hernandez 22.365 +ulysses ichabod 24.63 +ulysses ichabod 33.24333333333333 +ulysses johnson 33.43125 +ulysses king 27.083333333333332 +ulysses laertes 26.915000000000003 +ulysses laertes 27.305 +ulysses laertes 28.501111111111115 +ulysses miller 18.22 +ulysses miller 26.21857142857143 +ulysses nixon 30.65222222222222 +ulysses ovid 21.366666666666667 +ulysses polk 22.555714285714288 +ulysses polk 22.66625 +ulysses polk 25.11777777777778 +ulysses polk 27.141666666666666 +ulysses quirinius 33.07833333333334 +ulysses robinson 21.12375 +ulysses steinbeck 23.751428571428573 +ulysses steinbeck 25.931428571428572 +ulysses thompson 22.264444444444447 +ulysses underhill 20.812 +ulysses underhill 23.751428571428573 +ulysses underhill 25.071666666666662 +ulysses underhill 25.828333333333333 +ulysses underhill 25.865 +ulysses underhill 28.722499999999997 +ulysses underhill 35.268888888888895 +ulysses van buren 22.134999999999998 +ulysses white 15.296666666666667 +ulysses white 32.093333333333334 +ulysses xylophone 20.38714285714286 +ulysses xylophone 25.274 +ulysses xylophone 29.64125 +ulysses young 22.213333333333335 +ulysses young 22.90285714285714 +ulysses young 32.93125 +victor allen 24.82875 +victor allen 27.51 +victor brown 21.621250000000003 +victor brown 23.73 +victor brown 26.21857142857143 +victor brown 27.548571428571428 +victor davidson 22.391666666666666 +victor davidson 33.16428571428572 +victor davidson 35.197500000000005 +victor ellison 11.100000000000001 +victor ellison 30.96857142857143 +victor hernandez 10.4925 +victor hernandez 18.922 +victor hernandez 24.301250000000003 +victor hernandez 26.69857142857143 +victor hernandez 35.358333333333334 +victor johnson 16.580000000000002 +victor johnson 27.516666666666666 +victor johnson 32.106 +victor king 19.962857142857143 +victor king 33.01857142857143 +victor laertes 21.78142857142857 +victor laertes 33.10999999999999 +victor miller 21.93285714285714 +victor nixon 20.419999999999998 +victor nixon 33.69 +victor ovid 28.75857142857143 +victor polk 18.43111111111111 +victor quirinius 17.84777777777778 +victor quirinius 27.53 +victor robinson 19.37 +victor robinson 20.38714285714286 +victor steinbeck 20.818333333333335 +victor steinbeck 25.16 +victor steinbeck 30.503749999999997 +victor thompson 23.987142857142857 +victor van buren 27.009999999999998 +victor van buren 33.43125 +victor white 24.322857142857142 +victor white 28.287142857142857 +victor xylophone 11.807142857142859 +victor xylophone 13.988571428571428 +victor xylophone 16.720000000000002 +victor xylophone 19.686666666666667 +victor xylophone 37.20428571428572 +victor young 22.264444444444447 +victor zipper 24.854285714285716 +wendy allen 28.24142857142857 +wendy allen 29.675714285714285 +wendy allen 34.275 +wendy brown 22.482857142857142 +wendy brown 27.79714285714286 +wendy ellison 16.80888888888889 +wendy ellison 18.135 +wendy falkner 22.628888888888884 +wendy falkner 23.325000000000003 +wendy falkner 24.0375 +wendy garcia 19.307142857142853 +wendy garcia 21.761428571428574 +wendy garcia 24.63 +wendy garcia 24.854285714285716 +wendy hernandez 16.60875 +wendy ichabod 28.26 +wendy king 22.5 +wendy king 24.793333333333333 +wendy king 28.252857142857145 +wendy laertes 25.881428571428568 +wendy laertes 30.338333333333328 +wendy laertes 30.52857142857143 +wendy miller 15.478333333333332 +wendy miller 25.34333333333333 +wendy nixon 19.54714285714286 +wendy nixon 27.003333333333334 +wendy ovid 14.283750000000001 +wendy ovid 30.878333333333334 +wendy polk 21.69375 +wendy polk 24.63 +wendy quirinius 28.731428571428573 +wendy quirinius 29.74333333333333 +wendy robinson 16.720000000000002 +wendy robinson 23.834285714285716 +wendy robinson 29.911666666666672 +wendy steinbeck 29.272857142857145 +wendy thompson 18.17875 +wendy thompson 22.544285714285714 +wendy underhill 21.69625 +wendy underhill 27.077142857142857 +wendy underhill 30.03333333333333 +wendy van buren 28.624285714285715 +wendy van buren 29.28333333333333 +wendy white 24.4025 +wendy xylophone 16.84 +wendy xylophone 23.426666666666666 +wendy young 20.80125 +wendy young 32.693333333333335 +xavier allen 16.535 +xavier allen 17.398333333333337 +xavier allen 35.708333333333336 +xavier brown 20.787142857142857 +xavier brown 24.764285714285712 +xavier brown 31.784999999999997 +xavier carson 20.818333333333335 +xavier carson 32.106 +xavier davidson 16.862857142857145 +xavier davidson 20.53625 +xavier davidson 27.353333333333335 +xavier ellison 17.991666666666667 +xavier ellison 23.976666666666663 +xavier garcia 35.84428571428572 +xavier hernandez 22.654285714285713 +xavier hernandez 26.948333333333334 +xavier hernandez 28.075 +xavier ichabod 20.344285714285714 +xavier ichabod 20.818333333333335 +xavier johnson 15.754285714285714 +xavier johnson 19.490000000000002 +xavier king 29.246666666666666 +xavier king 29.521666666666665 +xavier laertes 19.294999999999998 +xavier ovid 28.51 +xavier polk 12.728333333333333 +xavier polk 19.37 +xavier polk 22.548333333333332 +xavier polk 28.465714285714284 +xavier quirinius 9.991428571428571 +xavier quirinius 24.156666666666666 +xavier quirinius 25.69666666666667 +xavier quirinius 25.828333333333333 +xavier thompson 23.961428571428574 +xavier underhill 21.830000000000002 +xavier white 19.331666666666667 +xavier white 35.345 +xavier xylophone 21.187142857142856 +xavier zipper 17.488333333333333 +yuri allen 15.705714285714285 +yuri allen 20.808 +yuri brown 19.53 +yuri brown 22.457142857142856 +yuri carson 25.699999999999996 +yuri carson 27.216666666666665 +yuri ellison 15.034999999999998 +yuri ellison 28.463333333333335 +yuri falkner 17.81833333333333 +yuri falkner 19.294999999999998 +yuri garcia 28.287142857142857 +yuri hernandez 32.395 +yuri johnson 25.828333333333333 +yuri johnson 27.301666666666666 +yuri johnson 29.578333333333337 +yuri king 19.921666666666663 +yuri laertes 16.18 +yuri laertes 30.519999999999996 +yuri nixon 16.383333333333333 +yuri nixon 25.828333333333333 +yuri polk 16.18 +yuri polk 20.503333333333334 +yuri polk 30.16333333333333 +yuri quirinius 20.311666666666667 +yuri quirinius 23.185 +yuri quirinius 24.828333333333333 +yuri steinbeck 19.331666666666667 +yuri steinbeck 28.50666666666667 +yuri thompson 35.27 +yuri underhill 23.042857142857144 +yuri underhill 28.786666666666665 +yuri white 30.72285714285714 +yuri xylophone 24.173333333333332 +zach allen 8.983333333333333 +zach brown 18.922 +zach brown 23.036666666666665 +zach brown 29.72666666666667 +zach brown 31.58285714285714 +zach brown 33.07833333333334 +zach carson 27.110000000000003 +zach ellison 18.168333333333333 +zach falkner 16.18 +zach falkner 30.83285714285714 +zach garcia 16.586666666666666 +zach garcia 22.53333333333333 +zach garcia 28.13166666666667 +zach garcia 34.84571428571429 +zach ichabod 17.535 +zach ichabod 30.72285714285714 +zach king 19.878333333333334 +zach king 25.643333333333334 +zach king 28.646666666666665 +zach miller 23.285 +zach miller 23.366666666666664 +zach miller 30.46833333333333 +zach ovid 23.94666666666667 +zach ovid 28.75166666666667 +zach ovid 28.763333333333335 +zach ovid 34.84571428571429 +zach quirinius 20.755 +zach robinson 21.546666666666667 +zach steinbeck 27.243333333333336 +zach steinbeck 30.073333333333334 +zach thompson 14.222 +zach thompson 24.755 +zach underhill 31.885 +zach white 20.208333333333332 +zach xylophone 10.485 +zach xylophone 20.113333333333333 +zach young 20.176666666666666 +zach zipper 21.709999999999997 +zach zipper 22.264999999999997 +zach zipper 34.01166666666667 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-5-ee44c5cdc80e1c832b702f9fb76d8145 b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-5-ee44c5cdc80e1c832b702f9fb76d8145 new file mode 100644 index 000000000000..a9ae190825a0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-5-ee44c5cdc80e1c832b702f9fb76d8145 @@ -0,0 +1,1049 @@ + 65560 + 65718 + 65740 +alice allen 65662 +alice allen 65720 +alice allen 65758 +alice brown 65696 +alice carson 65559 +alice davidson 65547 +alice falkner 65669 +alice garcia 65613 +alice hernandez 65737 +alice hernandez 65784 +alice johnson 65739 +alice king 65660 +alice king 65738 +alice king 131281 +alice laertes 65669 +alice laertes 65671 +alice miller 65590 +alice nixon 65586 +alice nixon 65595 +alice nixon 65604 +alice ovid 65737 +alice polk 65548 +alice quirinius 65636 +alice quirinius 65728 +alice robinson 65606 +alice robinson 65789 +alice steinbeck 65578 +alice steinbeck 65673 +alice steinbeck 65786 +alice underhill 65750 +alice van buren 65562 +alice xylophone 65585 +alice xylophone 65599 +alice xylophone 131247 +alice zipper 65553 +alice zipper 65662 +alice zipper 65766 +bob brown 65584 +bob brown 65777 +bob brown 65783 +bob carson 65713 +bob davidson 65664 +bob davidson 65693 +bob davidson 65768 +bob ellison 65591 +bob ellison 65624 +bob ellison 65721 +bob ellison 65760 +bob falkner 65789 +bob garcia 65585 +bob garcia 65598 +bob garcia 65673 +bob garcia 65754 +bob garcia 65782 +bob hernandez 131340 +bob ichabod 65549 +bob king 65715 +bob king 65757 +bob king 65783 +bob laertes 65602 +bob laertes 131447 +bob miller 65608 +bob ovid 65564 +bob ovid 65686 +bob ovid 196959 +bob ovid 196973 +bob polk 65594 +bob quirinius 65700 +bob steinbeck 65637 +bob van buren 65778 +bob white 65543 +bob white 65605 +bob xylophone 65574 +bob xylophone 65666 +bob young 65556 +bob zipper 65559 +bob zipper 65633 +bob zipper 65739 +calvin allen 65669 +calvin brown 65537 +calvin brown 131272 +calvin brown 197027 +calvin carson 65637 +calvin davidson 65541 +calvin davidson 65564 +calvin ellison 65667 +calvin falkner 65573 +calvin falkner 65596 +calvin falkner 65778 +calvin falkner 131397 +calvin falkner 131411 +calvin falkner 131433 +calvin garcia 131212 +calvin hernandez 131251 +calvin johnson 65731 +calvin laertes 65570 +calvin laertes 65684 +calvin nixon 65654 +calvin nixon 131386 +calvin nixon 131503 +calvin ovid 65554 +calvin ovid 65643 +calvin ovid 65715 +calvin ovid 196944 +calvin polk 65731 +calvin quirinius 65741 +calvin quirinius 65769 +calvin robinson 131320 +calvin steinbeck 131271 +calvin steinbeck 131326 +calvin steinbeck 131415 +calvin thompson 65560 +calvin thompson 131244 +calvin underhill 196944 +calvin van buren 65771 +calvin van buren 131138 +calvin white 65553 +calvin white 65561 +calvin xylophone 65575 +calvin xylophone 65596 +calvin xylophone 262686 +calvin young 65746 +calvin young 131168 +calvin zipper 65669 +calvin zipper 131476 +david allen 65588 +david allen 131222 +david brown 65637 +david brown 131303 +david davidson 65756 +david davidson 65778 +david davidson 65779 +david davidson 131342 +david ellison 65724 +david ellison 65724 +david ellison 131224 +david hernandez 197083 +david ichabod 131454 +david ichabod 197085 +david laertes 65762 +david nixon 65536 +david ovid 65623 +david ovid 196766 +david quirinius 65759 +david quirinius 65779 +david quirinius 131303 +david robinson 65762 +david robinson 65775 +david thompson 65550 +david underhill 65662 +david underhill 65751 +david underhill 131198 +david van buren 65634 +david van buren 262584 +david white 65678 +david xylophone 65537 +david xylophone 131426 +david xylophone 131447 +david young 65551 +david young 131255 +ethan allen 131460 +ethan brown 65539 +ethan brown 65617 +ethan brown 65685 +ethan brown 65685 +ethan brown 65722 +ethan brown 131483 +ethan carson 197189 +ethan ellison 65714 +ethan ellison 131302 +ethan falkner 131222 +ethan falkner 131333 +ethan garcia 131507 +ethan hernandez 65618 +ethan johnson 65536 +ethan king 131280 +ethan laertes 65562 +ethan laertes 65597 +ethan laertes 65628 +ethan laertes 65680 +ethan laertes 65760 +ethan laertes 131304 +ethan laertes 328329 +ethan miller 328296 +ethan nixon 65766 +ethan ovid 65697 +ethan polk 65589 +ethan polk 65615 +ethan polk 131206 +ethan polk 197082 +ethan quirinius 65591 +ethan quirinius 196912 +ethan quirinius 196957 +ethan robinson 65547 +ethan robinson 65659 +ethan underhill 65570 +ethan van buren 131252 +ethan white 65677 +ethan white 197039 +ethan xylophone 65595 +ethan zipper 65593 +ethan zipper 131365 +fred davidson 65595 +fred davidson 65721 +fred davidson 131221 +fred ellison 65548 +fred ellison 65691 +fred ellison 65771 +fred falkner 65637 +fred falkner 131474 +fred falkner 196920 +fred hernandez 131226 +fred ichabod 131109 +fred ichabod 131520 +fred johnson 131332 +fred king 65694 +fred king 197016 +fred laertes 131354 +fred miller 65536 +fred nixon 65560 +fred nixon 65612 +fred nixon 65705 +fred nixon 196929 +fred polk 65656 +fred polk 131231 +fred polk 262645 +fred polk 262733 +fred quirinius 131486 +fred quirinius 196950 +fred robinson 65623 +fred steinbeck 65544 +fred steinbeck 65755 +fred steinbeck 131253 +fred underhill 131188 +fred van buren 65561 +fred van buren 65745 +fred van buren 131380 +fred van buren 328270 +fred white 131136 +fred young 65594 +fred young 131551 +fred zipper 196885 +gabriella allen 65677 +gabriella allen 131283 +gabriella brown 65753 +gabriella brown 197180 +gabriella carson 65586 +gabriella davidson 65565 +gabriella ellison 65706 +gabriella ellison 131505 +gabriella falkner 65767 +gabriella falkner 131183 +gabriella falkner 131397 +gabriella garcia 131127 +gabriella hernandez 131304 +gabriella hernandez 131304 +gabriella ichabod 65559 +gabriella ichabod 65712 +gabriella ichabod 131297 +gabriella ichabod 131311 +gabriella ichabod 131460 +gabriella king 65657 +gabriella king 197031 +gabriella laertes 131543 +gabriella miller 131300 +gabriella ovid 65556 +gabriella ovid 131260 +gabriella polk 65790 +gabriella polk 131425 +gabriella steinbeck 65582 +gabriella steinbeck 131248 +gabriella thompson 131528 +gabriella thompson 197181 +gabriella thompson 262632 +gabriella van buren 65644 +gabriella van buren 131238 +gabriella white 65638 +gabriella young 65699 +gabriella young 65774 +gabriella zipper 65754 +gabriella zipper 196762 +holly allen 65596 +holly brown 131315 +holly brown 131368 +holly falkner 65720 +holly hernandez 65602 +holly hernandez 65686 +holly hernandez 131387 +holly hernandez 131554 +holly ichabod 65752 +holly ichabod 131308 +holly ichabod 131473 +holly johnson 65755 +holly johnson 131240 +holly johnson 131277 +holly king 131286 +holly king 131303 +holly laertes 196950 +holly miller 131381 +holly nixon 196941 +holly nixon 328184 +holly polk 197132 +holly polk 262782 +holly robinson 131241 +holly thompson 65578 +holly thompson 65713 +holly thompson 197092 +holly underhill 65654 +holly underhill 131323 +holly underhill 131385 +holly underhill 131504 +holly van buren 131449 +holly white 131092 +holly white 262734 +holly xylophone 196792 +holly young 65765 +holly young 131229 +holly zipper 131151 +holly zipper 131545 +irene allen 131109 +irene brown 65765 +irene brown 131368 +irene brown 393929 +irene carson 262770 +irene ellison 196956 +irene ellison 196982 +irene falkner 131287 +irene falkner 197046 +irene garcia 65660 +irene garcia 131286 +irene garcia 131375 +irene ichabod 65645 +irene ichabod 131442 +irene johnson 131179 +irene laertes 131324 +irene laertes 131381 +irene laertes 131407 +irene miller 262822 +irene nixon 197105 +irene nixon 262409 +irene nixon 262565 +irene ovid 65734 +irene ovid 196935 +irene ovid 262836 +irene polk 65551 +irene polk 131189 +irene polk 131189 +irene polk 196943 +irene polk 328365 +irene quirinius 131369 +irene quirinius 196998 +irene quirinius 262855 +irene robinson 131259 +irene steinbeck 65683 +irene thompson 262719 +irene underhill 131291 +irene underhill 131386 +irene van buren 131216 +irene van buren 262539 +irene xylophone 131348 +jessica brown 393772 +jessica carson 65747 +jessica carson 131207 +jessica carson 131232 +jessica davidson 65606 +jessica davidson 65675 +jessica davidson 196917 +jessica davidson 197030 +jessica ellison 131108 +jessica ellison 196885 +jessica falkner 131270 +jessica garcia 197059 +jessica garcia 328458 +jessica ichabod 197028 +jessica johnson 131177 +jessica johnson 197085 +jessica miller 197024 +jessica nixon 131549 +jessica nixon 196682 +jessica ovid 65751 +jessica ovid 196890 +jessica polk 459409 +jessica quirinius 131222 +jessica quirinius 131248 +jessica quirinius 131294 +jessica quirinius 393878 +jessica robinson 131174 +jessica thompson 131336 +jessica thompson 196927 +jessica underhill 131218 +jessica underhill 131267 +jessica underhill 197086 +jessica van buren 65615 +jessica white 65544 +jessica white 65594 +jessica white 197012 +jessica white 262435 +jessica white 262571 +jessica xylophone 196866 +jessica young 65711 +jessica young 131183 +jessica zipper 196897 +jessica zipper 262523 +jessica zipper 262695 +katie allen 196740 +katie brown 328113 +katie davidson 131371 +katie ellison 131248 +katie ellison 197182 +katie falkner 131441 +katie garcia 131384 +katie garcia 197051 +katie hernandez 131296 +katie ichabod 131495 +katie ichabod 197131 +katie ichabod 197275 +katie king 131252 +katie king 262588 +katie king 262861 +katie miller 65661 +katie miller 262723 +katie nixon 65669 +katie ovid 65681 +katie polk 65784 +katie polk 197249 +katie robinson 131251 +katie van buren 131237 +katie van buren 197141 +katie white 262510 +katie white 262860 +katie xylophone 197034 +katie young 65644 +katie young 328173 +katie young 393859 +katie zipper 65733 +katie zipper 328287 +luke allen 65776 +luke allen 131268 +luke allen 196819 +luke allen 196855 +luke allen 328011 +luke brown 196967 +luke davidson 65656 +luke davidson 131573 +luke ellison 65582 +luke ellison 131343 +luke ellison 197118 +luke falkner 196797 +luke falkner 196837 +luke garcia 65778 +luke garcia 393974 +luke ichabod 65629 +luke ichabod 262574 +luke johnson 131302 +luke johnson 131312 +luke johnson 131361 +luke laertes 131226 +luke laertes 131504 +luke laertes 197018 +luke laertes 197153 +luke laertes 197177 +luke miller 197052 +luke ovid 65569 +luke ovid 262745 +luke polk 65658 +luke polk 262627 +luke quirinius 131233 +luke robinson 65634 +luke robinson 262569 +luke thompson 196858 +luke underhill 65651 +luke underhill 131240 +luke underhill 328248 +luke van buren 131398 +luke white 65693 +luke xylophone 131312 +luke zipper 131297 +mike allen 196928 +mike brown 197149 +mike carson 65751 +mike carson 131284 +mike carson 393711 +mike davidson 196917 +mike davidson 262912 +mike ellison 65598 +mike ellison 131366 +mike ellison 131412 +mike ellison 131509 +mike ellison 262704 +mike falkner 328183 +mike garcia 131530 +mike garcia 328305 +mike garcia 328461 +mike hernandez 131301 +mike hernandez 328384 +mike ichabod 131157 +mike king 196965 +mike king 197091 +mike king 197121 +mike king 262471 +mike king 262527 +mike king 328279 +mike miller 131317 +mike nixon 131328 +mike nixon 262653 +mike polk 131240 +mike polk 196899 +mike polk 262885 +mike quirinius 525126 +mike steinbeck 65550 +mike steinbeck 131201 +mike steinbeck 131490 +mike steinbeck 262490 +mike van buren 131548 +mike van buren 262547 +mike white 197000 +mike white 197060 +mike white 262425 +mike white 328482 +mike young 196935 +mike young 196976 +mike young 328084 +mike zipper 131147 +mike zipper 197075 +mike zipper 328517 +nick allen 131192 +nick allen 197024 +nick brown 131503 +nick davidson 262686 +nick ellison 197119 +nick ellison 197119 +nick falkner 65583 +nick falkner 328561 +nick garcia 131318 +nick garcia 262755 +nick garcia 328281 +nick ichabod 131430 +nick ichabod 196812 +nick ichabod 328593 +nick johnson 131453 +nick johnson 262597 +nick laertes 196732 +nick miller 131490 +nick nixon 262547 +nick ovid 328266 +nick polk 196852 +nick quirinius 131438 +nick quirinius 328176 +nick robinson 131326 +nick robinson 196980 +nick steinbeck 131250 +nick thompson 65610 +nick underhill 65619 +nick van buren 196795 +nick xylophone 196972 +nick young 394136 +nick young 459634 +nick zipper 262954 +nick zipper 394218 +oscar allen 262674 +oscar brown 196916 +oscar carson 131099 +oscar carson 131330 +oscar carson 196731 +oscar carson 196733 +oscar carson 196918 +oscar davidson 262554 +oscar ellison 65630 +oscar ellison 197116 +oscar falkner 197145 +oscar garcia 328305 +oscar hernandez 197022 +oscar hernandez 328315 +oscar ichabod 131302 +oscar ichabod 131309 +oscar ichabod 196760 +oscar ichabod 196874 +oscar johnson 196942 +oscar johnson 197203 +oscar king 196793 +oscar king 196944 +oscar king 328236 +oscar laertes 131208 +oscar laertes 262522 +oscar laertes 262842 +oscar laertes 328364 +oscar nixon 65596 +oscar ovid 131228 +oscar ovid 262580 +oscar ovid 393817 +oscar polk 131078 +oscar polk 131260 +oscar quirinius 131103 +oscar quirinius 196748 +oscar quirinius 196829 +oscar quirinius 262838 +oscar robinson 196874 +oscar robinson 262803 +oscar robinson 393773 +oscar robinson 394087 +oscar steinbeck 328432 +oscar thompson 196826 +oscar thompson 196992 +oscar thompson 262593 +oscar thompson 459401 +oscar underhill 131301 +oscar van buren 131134 +oscar van buren 328162 +oscar van buren 394034 +oscar white 131457 +oscar white 262345 +oscar white 328538 +oscar white 459337 +oscar xylophone 65773 +oscar xylophone 262708 +oscar xylophone 262906 +oscar zipper 196904 +oscar zipper 262512 +oscar zipper 328262 +priscilla brown 196950 +priscilla brown 328237 +priscilla brown 328624 +priscilla carson 262488 +priscilla carson 262510 +priscilla carson 262703 +priscilla ichabod 131178 +priscilla ichabod 131303 +priscilla johnson 131224 +priscilla johnson 196906 +priscilla johnson 196994 +priscilla johnson 197184 +priscilla johnson 394171 +priscilla king 262692 +priscilla nixon 262691 +priscilla nixon 394188 +priscilla ovid 65541 +priscilla ovid 197067 +priscilla polk 394009 +priscilla quirinius 131306 +priscilla thompson 196875 +priscilla underhill 197084 +priscilla underhill 262701 +priscilla van buren 65685 +priscilla van buren 131368 +priscilla van buren 196814 +priscilla white 196893 +priscilla xylophone 131473 +priscilla xylophone 262597 +priscilla xylophone 262785 +priscilla young 131392 +priscilla young 262788 +priscilla zipper 393888 +priscilla zipper 394031 +quinn allen 197095 +quinn allen 394225 +quinn brown 131470 +quinn brown 131473 +quinn brown 262642 +quinn davidson 197079 +quinn davidson 197112 +quinn davidson 262510 +quinn davidson 459427 +quinn ellison 197268 +quinn ellison 328130 +quinn garcia 65604 +quinn garcia 131321 +quinn garcia 197067 +quinn garcia 328528 +quinn ichabod 65564 +quinn king 65649 +quinn king 196879 +quinn laertes 65542 +quinn laertes 196877 +quinn laertes 262466 +quinn nixon 196837 +quinn ovid 525126 +quinn quirinius 328235 +quinn robinson 131378 +quinn steinbeck 131484 +quinn steinbeck 262528 +quinn thompson 197030 +quinn thompson 262717 +quinn underhill 262791 +quinn underhill 328146 +quinn underhill 393824 +quinn van buren 197234 +quinn young 65647 +quinn zipper 131466 +quinn zipper 262658 +rachel allen 65661 +rachel allen 196935 +rachel brown 131220 +rachel brown 328076 +rachel brown 328320 +rachel brown 393835 +rachel brown 524988 +rachel carson 131259 +rachel carson 459393 +rachel davidson 262632 +rachel ellison 393845 +rachel falkner 196947 +rachel falkner 262474 +rachel falkner 394046 +rachel falkner 525086 +rachel johnson 65658 +rachel king 131354 +rachel king 196907 +rachel laertes 131391 +rachel laertes 197105 +rachel ovid 262664 +rachel ovid 328195 +rachel polk 328389 +rachel quirinius 262779 +rachel robinson 262491 +rachel robinson 262862 +rachel robinson 590712 +rachel thompson 197034 +rachel thompson 328158 +rachel thompson 394094 +rachel underhill 197033 +rachel white 131399 +rachel white 197190 +rachel young 196967 +rachel zipper 328223 +rachel zipper 394149 +sarah carson 131379 +sarah carson 196870 +sarah carson 262491 +sarah ellison 197095 +sarah falkner 131262 +sarah falkner 328251 +sarah garcia 196963 +sarah garcia 197030 +sarah garcia 459657 +sarah ichabod 262504 +sarah ichabod 262766 +sarah johnson 131409 +sarah johnson 262783 +sarah johnson 328591 +sarah johnson 394043 +sarah king 196998 +sarah king 328416 +sarah miller 196893 +sarah ovid 131199 +sarah robinson 262868 +sarah robinson 394066 +sarah steinbeck 262650 +sarah white 197059 +sarah white 262579 +sarah xylophone 131336 +sarah young 394123 +sarah zipper 262818 +tom brown 196848 +tom brown 328268 +tom carson 197328 +tom carson 262517 +tom carson 656251 +tom davidson 262864 +tom ellison 196974 +tom ellison 328416 +tom ellison 393921 +tom falkner 393809 +tom falkner 459407 +tom hernandez 262525 +tom hernandez 328085 +tom ichabod 197048 +tom johnson 328321 +tom johnson 393865 +tom king 196951 +tom laertes 262657 +tom laertes 459805 +tom miller 131278 +tom miller 131459 +tom miller 262633 +tom nixon 262588 +tom ovid 262595 +tom polk 328470 +tom polk 328584 +tom quirinius 262597 +tom quirinius 262681 +tom robinson 196978 +tom robinson 328481 +tom robinson 459857 +tom robinson 525095 +tom steinbeck 262426 +tom van buren 131389 +tom van buren 328095 +tom van buren 328313 +tom white 328128 +tom young 131080 +tom young 393692 +tom zipper 197167 +ulysses brown 196815 +ulysses carson 131277 +ulysses carson 262450 +ulysses carson 262937 +ulysses carson 328311 +ulysses davidson 262750 +ulysses ellison 262445 +ulysses garcia 328445 +ulysses hernandez 131414 +ulysses hernandez 196871 +ulysses hernandez 394370 +ulysses ichabod 393834 +ulysses ichabod 459582 +ulysses johnson 262966 +ulysses king 131363 +ulysses laertes 262739 +ulysses laertes 328412 +ulysses laertes 328462 +ulysses miller 262661 +ulysses miller 328360 +ulysses nixon 394194 +ulysses ovid 328289 +ulysses polk 65563 +ulysses polk 197046 +ulysses polk 328294 +ulysses polk 590698 +ulysses quirinius 525643 +ulysses robinson 394160 +ulysses steinbeck 196783 +ulysses steinbeck 262778 +ulysses thompson 262607 +ulysses underhill 131214 +ulysses underhill 196937 +ulysses underhill 197027 +ulysses underhill 262623 +ulysses underhill 262623 +ulysses underhill 262648 +ulysses underhill 262836 +ulysses van buren 196944 +ulysses white 197033 +ulysses white 393988 +ulysses xylophone 262695 +ulysses xylophone 328151 +ulysses xylophone 328747 +ulysses young 196903 +ulysses young 394037 +ulysses young 459782 +victor allen 197189 +victor allen 262651 +victor brown 262544 +victor brown 262799 +victor brown 327900 +victor brown 591265 +victor davidson 197173 +victor davidson 262486 +victor davidson 328274 +victor ellison 328618 +victor ellison 393962 +victor hernandez 197041 +victor hernandez 197132 +victor hernandez 262771 +victor hernandez 328261 +victor hernandez 459902 +victor johnson 131155 +victor johnson 131169 +victor johnson 394168 +victor king 131486 +victor king 328509 +victor laertes 262573 +victor laertes 328435 +victor miller 196784 +victor nixon 196987 +victor nixon 394249 +victor ovid 196882 +victor polk 262462 +victor quirinius 65620 +victor quirinius 328301 +victor robinson 328334 +victor robinson 394031 +victor steinbeck 65661 +victor steinbeck 262560 +victor steinbeck 262750 +victor thompson 65548 +victor van buren 197173 +victor van buren 328261 +victor white 262588 +victor white 328039 +victor xylophone 131203 +victor xylophone 262596 +victor xylophone 328191 +victor xylophone 393913 +victor xylophone 459542 +victor young 131258 +victor zipper 131349 +wendy allen 131402 +wendy allen 196954 +wendy allen 328359 +wendy brown 328365 +wendy brown 459501 +wendy ellison 262718 +wendy ellison 328191 +wendy falkner 197009 +wendy falkner 262430 +wendy falkner 328177 +wendy garcia 65746 +wendy garcia 393974 +wendy garcia 459883 +wendy garcia 459926 +wendy hernandez 65650 +wendy ichabod 262665 +wendy king 262545 +wendy king 328229 +wendy king 393951 +wendy laertes 262739 +wendy laertes 262794 +wendy laertes 328315 +wendy miller 131377 +wendy miller 328161 +wendy nixon 131258 +wendy nixon 196893 +wendy ovid 196952 +wendy ovid 459594 +wendy polk 328520 +wendy polk 394310 +wendy quirinius 328703 +wendy quirinius 394360 +wendy robinson 131316 +wendy robinson 394030 +wendy robinson 459665 +wendy steinbeck 262645 +wendy thompson 262725 +wendy thompson 393865 +wendy underhill 328445 +wendy underhill 394295 +wendy underhill 460068 +wendy van buren 65699 +wendy van buren 196964 +wendy white 328135 +wendy xylophone 262894 +wendy xylophone 525344 +wendy young 197017 +wendy young 721936 +xavier allen 197025 +xavier allen 525393 +xavier allen 525839 +xavier brown 197058 +xavier brown 262626 +xavier brown 328388 +xavier carson 196990 +xavier carson 328415 +xavier davidson 65644 +xavier davidson 262745 +xavier davidson 393825 +xavier ellison 197095 +xavier ellison 328447 +xavier garcia 262590 +xavier hernandez 196847 +xavier hernandez 197077 +xavier hernandez 393838 +xavier ichabod 262600 +xavier ichabod 328157 +xavier johnson 197084 +xavier johnson 262785 +xavier king 196919 +xavier king 262774 +xavier laertes 262770 +xavier ovid 328414 +xavier polk 196844 +xavier polk 328474 +xavier polk 394013 +xavier polk 590931 +xavier quirinius 65650 +xavier quirinius 131140 +xavier quirinius 328382 +xavier quirinius 459669 +xavier thompson 393799 +xavier underhill 197012 +xavier white 196858 +xavier white 262712 +xavier xylophone 131250 +xavier zipper 394070 +yuri allen 131129 +yuri allen 459977 +yuri brown 262640 +yuri brown 393858 +yuri carson 459799 +yuri carson 591063 +yuri ellison 197085 +yuri ellison 459558 +yuri falkner 196857 +yuri falkner 525350 +yuri garcia 328378 +yuri hernandez 262588 +yuri johnson 393861 +yuri johnson 394444 +yuri johnson 525638 +yuri king 525526 +yuri laertes 131551 +yuri laertes 459611 +yuri nixon 262644 +yuri nixon 393936 +yuri polk 328197 +yuri polk 328404 +yuri polk 328481 +yuri quirinius 131092 +yuri quirinius 196898 +yuri quirinius 525159 +yuri steinbeck 394037 +yuri steinbeck 525180 +yuri thompson 459710 +yuri underhill 328325 +yuri underhill 459781 +yuri white 131252 +yuri xylophone 262809 +zach allen 394026 +zach brown 262789 +zach brown 262789 +zach brown 459521 +zach brown 459846 +zach brown 590938 +zach carson 262320 +zach ellison 262757 +zach falkner 262608 +zach falkner 262608 +zach garcia 262818 +zach garcia 328314 +zach garcia 393686 +zach garcia 394011 +zach ichabod 262518 +zach ichabod 262563 +zach king 196780 +zach king 196905 +zach king 459991 +zach miller 196923 +zach miller 393813 +zach miller 393892 +zach ovid 196876 +zach ovid 262643 +zach ovid 328023 +zach ovid 459615 +zach quirinius 262471 +zach robinson 196967 +zach steinbeck 131394 +zach steinbeck 459294 +zach thompson 131340 +zach thompson 525538 +zach underhill 131304 +zach white 65733 +zach xylophone 262810 +zach xylophone 459455 +zach young 393615 +zach zipper 197130 +zach zipper 262496 +zach zipper 393937 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-6-4d78f7b1d172d20c91f5867bc13a42a0 b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-6-4d78f7b1d172d20c91f5867bc13a42a0 new file mode 100644 index 000000000000..b3f08818f491 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-6-4d78f7b1d172d20c91f5867bc13a42a0 @@ -0,0 +1,1049 @@ +0.08 0.07999999821186066 +0.1 0.10000000149011612 +0.13 0.12999999523162842 +0.15 0.15000000596046448 +0.27 0.27000001072883606 +0.28 0.2800000011920929 +0.43 0.4300000071525574 +0.52 0.5199999809265137 +0.56 0.5600000023841858 +0.6 0.6000000238418579 +0.61 0.6100000143051147 +0.79 0.7900000214576721 +0.84 0.8399999737739563 +0.98 0.9800000190734863 +1.02 1.2899999916553497 +1.08 1.0800000429153442 +1.08 1.0800000429153442 +1.12 1.1200000047683716 +1.21 2.0000000596046448 +1.25 1.25 +1.27 1.2699999809265137 +1.29 1.2899999618530273 +1.31 1.309999942779541 +1.58 1.5800000429153442 +1.87 1.8700000047683716 +1.91 1.909999966621399 +1.92 3.1699999570846558 +2.07 2.069999933242798 +2.18 2.180000066757202 +2.2 2.200000047683716 +2.35 2.3499999046325684 +2.6 2.5999999046325684 +2.79 2.7899999618530273 +2.92 2.9200000762939453 +2.96 2.9600000381469727 +2.96 2.9600000381469727 +2.97 2.9700000286102295 +3.0 3.0 +3.21 3.340000033378601 +3.28 4.399999976158142 +3.33 3.3299999237060547 +3.61 3.609999895095825 +3.62 3.619999885559082 +3.82 3.819999933242798 +3.86 3.859999895095825 +3.96 3.9600000381469727 +3.97 3.9700000286102295 +4.17 7.7799999713897705 +4.32 4.320000171661377 +4.35 4.349999904632568 +4.41 4.409999847412109 +4.46 4.460000038146973 +4.47 4.46999979019165 +4.57 4.570000171661377 +4.59 4.590000152587891 +4.71 4.710000038146973 +4.72 4.71999979019165 +4.79 4.789999961853027 +4.8 4.800000190734863 +4.92 4.920000076293945 +5.08 5.079999923706055 +5.24 9.559999942779541 +5.28 5.28000020980835 +5.4 5.400000095367432 +5.44 5.440000057220459 +5.45 5.449999809265137 +5.51 5.510000228881836 +5.54 5.539999961853027 +5.62 5.619999885559082 +5.67 5.670000076293945 +5.85 5.849999904632568 +5.88 5.880000114440918 +6.29 6.289999961853027 +6.55 6.550000190734863 +6.57 11.160000324249268 +6.63 9.59000015258789 +6.67 6.670000076293945 +6.72 6.71999979019165 +6.74 6.739999771118164 +6.84 6.840000152587891 +6.87 6.869999885559082 +7.05 7.050000190734863 +7.06 11.769999980926514 +7.11 7.110000133514404 +7.54 7.539999961853027 +7.56 7.559999942779541 +7.79 7.789999961853027 +7.82 7.820000171661377 +7.96 7.960000038146973 +7.96 7.960000038146973 +7.98 7.980000019073486 +8.07 8.069999694824219 +8.07 8.069999694824219 +8.32 8.319999694824219 +8.37 11.339999914169312 +8.42 11.760000109672546 +8.45 8.449999809265137 +8.45 8.449999809265137 +8.45 10.319999814033508 +8.45 16.009999752044678 +8.57 8.569999694824219 +8.61 8.609999656677246 +8.67 8.670000076293945 +8.71 8.710000038146973 +8.79 8.789999961853027 +8.91 13.480000019073486 +9.04 9.039999961853027 +9.13 9.130000114440918 +9.19 15.479999542236328 +9.22 9.220000267028809 +9.25 9.25 +9.26 9.260000228881836 +9.35 12.350000381469727 +9.48 9.479999542236328 +9.56 12.480000495910645 +9.57 9.569999694824219 +9.57 9.569999694824219 +9.68 9.680000305175781 +9.7 9.699999809265137 +9.71 17.25 +9.74 9.739999771118164 +9.8 9.800000190734863 +9.81 9.8100004196167 +9.93 9.930000305175781 +10.09 10.09000015258789 +10.09 10.09000015258789 +10.13 15.640000343322754 +10.16 18.139999866485596 +10.17 14.970000267028809 +10.19 10.1899995803833 +10.2 10.199999809265137 +10.22 10.220000267028809 +10.25 18.859999656677246 +10.26 10.260000228881836 +10.29 10.289999961853027 +10.6 10.600000381469727 +10.66 10.65999984741211 +10.67 10.670000076293945 +10.73 10.729999542236328 +11.15 20.27999973297119 +11.18 13.360000371932983 +11.19 18.979999542236328 +11.22 11.220000267028809 +11.34 16.62000036239624 +11.55 11.550000190734863 +11.57 11.569999694824219 +11.68 17.080000400543213 +11.82 11.819999694824219 +11.89 11.890000343322754 +11.91 11.90999984741211 +12.02 12.020000457763672 +12.16 12.15999984741211 +12.19 20.149999618530273 +12.32 12.319999694824219 +12.42 16.27999997138977 +12.44 12.4399995803833 +12.45 16.799999713897705 +12.46 12.460000038146973 +12.5 15.460000038146973 +12.54 12.539999961853027 +12.85 12.850000381469727 +12.9 12.899999618530273 +13.01 28.47000026702881 +13.1 23.030000686645508 +13.15 20.96999979019165 +13.35 13.350000381469727 +13.87 13.869999885559082 +13.89 13.890000343322754 +13.94 13.9399995803833 +13.99 13.989999771118164 +14.13 24.22000026702881 +14.21 14.210000038146973 +14.3 29.270000457763672 +14.44 14.4399995803833 +14.84 14.84000015258789 +14.92 14.920000076293945 +14.92 25.18000030517578 +14.93 30.40999984741211 +15.1 17.700000286102295 +15.15 24.40999984741211 +15.18 26.730000495910645 +15.22 15.220000267028809 +15.26 15.260000228881836 +15.3 25.5 +15.37 15.369999885559082 +15.45 15.449999809265137 +15.63 28.110000610351562 +15.75 15.75 +15.81 15.8100004196167 +15.86 25.079999923706055 +15.9 21.34999942779541 +15.92 15.920000076293945 +16.08 22.75 +16.09 16.520000159740448 +16.24 27.809999465942383 +16.25 19.419999957084656 +16.48 16.479999542236328 +16.69 16.690000534057617 +16.99 16.989999771118164 +16.99 42.489999771118164 +17.16 21.12999987602234 +17.37 31.360000610351562 +17.74 19.049999713897705 +17.79 45.60000038146973 +17.87 18.710000813007355 +18.2 18.200000762939453 +18.5 31.350000381469727 +18.56 18.559999465942383 +18.63 26.589999198913574 +18.63 30.389999270439148 +18.86 18.96000061184168 +18.89 18.889999389648438 +18.93 18.93000030517578 +19.0 35.62000036239624 +19.03 19.030000686645508 +19.06 19.059999465942383 +19.06 19.059999465942383 +19.13 45.719998359680176 +19.14 19.139999389648438 +19.28 27.600000381469727 +19.69 36.49000024795532 +20.07 46.80000019073486 +20.38 51.72999954223633 +20.64 20.639999389648438 +20.67 31.830000400543213 +20.79 20.790000915527344 +20.81 20.809999465942383 +20.82 21.419999718666077 +20.82 26.359999656677246 +21.18 21.18000030517578 +21.19 21.190000534057617 +21.23 21.229999542236328 +21.28 29.350000381469727 +21.32 21.31999969482422 +21.45 40.510000228881836 +21.49 30.739999771118164 +21.61 37.0600004196167 +21.7 27.58000087738037 +21.8 21.799999237060547 +21.94 23.940000593662262 +22.01 28.850000381469727 +22.08 22.079999923706055 +22.12 22.1200008392334 +22.12 22.1200008392334 +22.25 22.25 +22.27 22.270000457763672 +22.36 22.360000610351562 +22.68 22.68000030517578 +22.78 47.19000053405762 +22.85 33.070000648498535 +22.85 43.980000257492065 +22.94 35.38000011444092 +23.07 23.06999969482422 +23.13 28.799999237060547 +23.17 44.58999979496002 +23.19 23.190000534057617 +23.44 23.440000534057617 +23.45 24.74000072479248 +23.6 33.16000032424927 +23.77 23.770000457763672 +23.96 23.959999084472656 +24.02 24.020000457763672 +24.28 43.310001373291016 +24.49 42.62999963760376 +24.52 32.59000015258789 +24.73 45.369998931884766 +24.79 24.790000915527344 +24.8 34.369998931884766 +24.83 36.05000019073486 +24.86 65.3700008392334 +25.11 44.170000076293945 +25.28 25.280000686645508 +25.37 48.05000114440918 +25.42 40.78999996185303 +25.55 26.62999927997589 +25.67 37.69000053405762 +25.88 61.49999952316284 +26.08 26.079999923706055 +26.39 34.959999084472656 +26.43 26.43000030517578 +26.47 31.389999389648438 +26.49 26.489999771118164 +26.49 48.56999969482422 +26.64 64.32999992370605 +26.71 36.999999046325684 +26.73 45.69000015407801 +26.76 26.760000228881836 +27.07 28.649999737739563 +27.12 32.20000076293945 +27.3 70.61000061035156 +27.31 56.579999923706055 +27.63 27.6299991607666 +27.66 27.65999984741211 +27.72 46.60999870300293 +27.87 27.8700008392334 +28.11 44.59000015258789 +28.31 52.079999923706055 +28.45 74.05000114440918 +28.5 35.36999988555908 +28.56 40.71999931335449 +28.69 28.690000534057617 +28.71 55.46999931335449 +28.79 28.790000915527344 +28.89 56.489999771118164 +28.95 33.410000801086426 +29.02 56.64999961853027 +29.24 99.85000038146973 +29.36 62.52000093460083 +29.4 72.02999925613403 +29.41 64.77999973297119 +29.54 29.540000915527344 +29.59 37.37000012397766 +29.78 66.77999973297119 +30.25 30.32999999821186 +30.36 30.360000610351562 +30.37 31.660000830888748 +30.61 30.610000610351562 +30.62 102.65000009536743 +30.63 30.6299991607666 +30.65 60.19000053405762 +30.71 49.849998474121094 +30.81 55.989999771118164 +31.01 31.010000228881836 +31.15 31.149999618530273 +31.4 31.399999618530273 +31.61 31.610000610351562 +31.67 40.46000003814697 +31.77 42.09000027179718 +31.86 31.860000610351562 +31.91 78.51999855041504 +32.01 60.47999858856201 +32.18 58.61000061035156 +32.2 53.55000019073486 +32.23 42.89999961853027 +32.25 59.83000087738037 +32.37 62.99999809265137 +32.41 32.40999984741211 +32.47 41.14000129699707 +32.52 95.0400013923645 +32.75 56.19000053405762 +32.89 80.07999992370605 +32.92 47.7599983215332 +33.36 45.27000045776367 +33.52 60.010000228881836 +33.55 63.939998507499695 +33.58 54.55000162124634 +33.67 33.66999816894531 +33.76 112.27999687194824 +33.83 59.110002517700195 +33.85 47.719998359680176 +33.87 37.48999881744385 +34.03 71.51999759674072 +34.21 71.57999920845032 +34.35 34.349998474121094 +34.41 59.20000076293945 +34.58 34.58000183105469 +34.73 34.72999954223633 +34.97 45.160000801086426 +35.0 35.0 +35.08 36.060001850128174 +35.13 39.600000858306885 +35.17 64.01999855041504 +35.17 66.52999877929688 +35.56 37.63000130653381 +35.62 80.20999872684479 +35.65 56.83000183105469 +35.68 52.20000046491623 +35.72 98.71999931335449 +35.8 79.96999931335449 +35.89 81.2599983215332 +36.22 55.20000076293945 +36.26 36.2599983215332 +36.58 64.45000267028809 +36.7 73.76000118255615 +36.79 36.790000915527344 +36.89 71.8499984741211 +36.95 36.95000076293945 +37.07 37.06999969482422 +37.1 51.039998054504395 +37.14 53.41999936103821 +37.14 61.15999984741211 +37.24 47.04000186920166 +37.59 50.94000053405762 +37.6 84.39999866485596 +37.72 57.14000117778778 +37.78 68.10999877750874 +37.8 57.94999885559082 +37.85 50.38999843597412 +37.9 77.50000238418579 +38.05 40.24999928474426 +38.05 47.859999656677246 +38.3 98.48999977111816 +38.33 112.09000301361084 +38.57 55.81999969482422 +38.62 73.99999904632568 +38.79 95.44000053405762 +38.85 97.45999908447266 +38.88 48.58000087738037 +38.94 71.3499984741211 +39.01 39.0099983215332 +39.03 57.73999959230423 +39.18 99.6599988937378 +39.21 71.79999923706055 +39.34 39.34000015258789 +39.69 55.69999837875366 +39.81 74.81000137329102 +39.82 39.81999969482422 +39.83 119.80000114440918 +39.87 62.22999954223633 +39.9 64.98000144958496 +39.98 39.97999954223633 +40.0 84.59000015258789 +40.04 101.20000076293945 +40.17 80.41999745368958 +40.24 67.9000015258789 +40.42 78.04999947547913 +40.44 117.94000101089478 +40.78 49.8199987411499 +40.8 40.79999923706055 +40.98 92.01999759674072 +41.2 58.280001163482666 +41.29 41.290000915527344 +41.29 112.87000012397766 +41.31 53.08000135421753 +41.34 53.230000495910645 +41.34 115.33999919891357 +41.36 41.36000061035156 +41.44 85.41999888420105 +41.45 91.29999923706055 +41.62 41.619998931884766 +41.68 109.58000183105469 +41.71 83.06999969482422 +41.81 89.8600025177002 +41.85 76.21999740600586 +41.87 47.48999881744385 +41.89 41.88999938964844 +42.24 72.6500015258789 +42.31 52.91000175476074 +42.42 154.69999504089355 +42.48 71.12999927997589 +42.51 61.069997787475586 +42.55 87.71000003814697 +42.56 71.91000175476074 +42.67 71.45999908447266 +42.76 42.7599983215332 +42.85 85.33999824523926 +43.01 109.53999710083008 +43.02 46.84000039100647 +43.13 43.130001068115234 +43.16 58.079999923706055 +43.17 43.16999816894531 +43.19 156.05999875068665 +43.31 64.10000228881836 +43.37 56.84999895095825 +43.57 43.71999970078468 +43.71 108.69000053405762 +43.73 108.50999927520752 +43.92 48.319998145103455 +44.1 67.1299991607666 +44.22 103.33000373840332 +44.27 74.88000106811523 +44.43 106.65999984741211 +44.57 59.00999927520752 +45.06 45.060001373291016 +45.1 45.099998474121094 +45.19 45.189998626708984 +45.19 117.10000038146973 +45.24 109.34000396728516 +45.34 129.73999881744385 +45.35 109.28999698162079 +45.42 77.61999893188477 +45.45 45.45000076293945 +45.56 137.57999897003174 +45.59 82.54000091552734 +45.68 55.25 +45.92 90.97999954223633 +45.99 47.07000172138214 +46.02 82.27999877929688 +46.09 46.09000015258789 +46.1 98.17999839782715 +46.15 72.7800008058548 +46.18 74.29000091552734 +46.21 52.75999927520752 +46.27 85.61000061035156 +46.43 106.26000118255615 +46.45 110.90000343322754 +46.62 78.0099983215332 +46.8 80.46999740600586 +46.86 62.08000087738037 +46.87 70.80999952554703 +46.88 106.08000183105469 +46.97 88.1100025177002 +47.08 148.28000259399414 +47.27 50.60000038146973 +47.32 118.12999922037125 +47.57 90.32999801635742 +47.59 104.17000007629395 +47.69 99.88999909162521 +47.88 47.880001068115234 +48.01 91.72999802231789 +48.08 79.69000244140625 +48.11 48.11000061035156 +48.15 65.8500018119812 +48.22 105.07000017166138 +48.23 139.52999877929688 +48.25 48.25 +48.28 98.669997215271 +48.37 185.9499979019165 +48.45 48.45000076293945 +48.45 94.54000091552734 +48.52 146.69999885559082 +48.59 89.30999946594238 +49.12 49.119998931884766 +49.28 123.56999969482422 +49.44 110.93999814987183 +49.68 73.63999938964844 +49.77 50.33000046014786 +49.78 66.46999931335449 +50.02 63.380000829696655 +50.08 156.16000366210938 +50.09 106.28000068664551 +50.26 72.05999755859375 +50.28 50.279998779296875 +50.31 117.44000053405762 +50.32 90.13999938964844 +50.4 96.11999988555908 +50.41 98.72999799251556 +50.66 55.37999963760376 +50.7 131.11999821662903 +50.83 98.69000148773193 +50.92 53.70999813079834 +50.96 103.71999835968018 +51.25 67.0 +51.29 87.35000276565552 +51.29 124.93000030517578 +51.72 97.17000198364258 +51.79 139.90000343322754 +51.84 168.94000053405762 +51.85 171.64999961853027 +52.17 206.86999320983887 +52.23 177.1599998474121 +52.44 88.48999881744385 +52.5 105.41000175476074 +52.53 64.98999881744385 +52.72 52.720001220703125 +52.73 74.04999923706055 +52.85 89.63999938964844 +52.87 130.87999725341797 +53.02 100.50999927520752 +53.06 259.9299945831299 +53.18 53.18000030517578 +53.27 53.27000045776367 +53.59 53.59000015258789 +53.78 139.38999938964844 +53.93 57.890000343322754 +53.94 63.529998779296875 +54.1 152.7699956893921 +54.31 77.38000106811523 +54.34 125.46999943256378 +54.43 132.04999923706055 +54.44 103.01999950408936 +54.47 186.52000045776367 +54.73 63.179999351501465 +54.75 112.82999992370605 +54.83 110.82000160217285 +54.99 160.40000343322754 +55.1 161.35999965667725 +55.18 215.58000373840332 +55.2 126.65999984741211 +55.39 137.6699981689453 +55.51 74.55999803543091 +55.63 96.43000030517578 +55.99 187.10999989509583 +56.04 150.5800018310547 +56.07 118.15000057220459 +56.1 135.79000091552734 +56.15 144.64000034332275 +56.33 61.77000188827515 +56.62 78.88999938964844 +56.68 154.13999938964844 +56.81 169.64000129699707 +57.08 69.98000144958496 +57.11 168.0100040435791 +57.12 100.28999710083008 +57.23 65.9399995803833 +57.25 133.46999740600586 +57.29 112.54000091552734 +57.35 110.89999866485596 +57.37 115.109998524189 +57.46 147.78999710083008 +57.64 112.19000101089478 +57.67 57.66999816894531 +57.89 111.15999984741211 +57.93 68.02000045776367 +58.0 123.9399995803833 +58.08 58.08000183105469 +58.09 206.37000274658203 +58.13 105.84999942779541 +58.43 165.0900001525879 +58.52 167.0299997329712 +58.66 136.04000091552734 +58.67 205.36999702453613 +58.75 90.41000083088875 +58.86 165.14000129699707 +59.07 87.86999893188477 +59.16 224.25 +59.21 90.35999870300293 +59.34 127.44999893009663 +59.43 106.50000202655792 +59.45 67.90000057220459 +59.45 197.11999893188477 +59.5 149.63999938964844 +59.55 61.459999203681946 +59.61 85.97000026702881 +59.62 113.3299970626831 +59.68 73.89000034332275 +59.68 94.40999984741211 +59.7 193.1699981689453 +59.71 60.22999906539917 +59.83 145.17000007629395 +59.87 228.80999946594238 +59.99 134.04000282287598 +60.02 66.76000022888184 +60.06 60.060001373291016 +60.12 113.34999942779541 +60.13 214.27000045776367 +60.22 108.10000228881836 +60.26 105.94999847561121 +60.26 165.32999849319458 +60.53 66.37999868392944 +60.6 82.8499984741211 +60.71 72.04999899864197 +60.85 132.36999607086182 +61.21 160.86999797821045 +61.7 127.55000257492065 +61.86 248.9700005054474 +61.88 112.15999984741211 +61.92 125.29999899864197 +61.94 119.6099967956543 +62.14 110.59000015258789 +62.2 149.91000080108643 +62.23 111.3499984741211 +62.3 158.41999912261963 +62.39 110.95999908447266 +62.52 123.97999966144562 +62.72 123.78999900817871 +62.74 153.10000038146973 +62.85 167.01999855041504 +62.9 256.0699996948242 +62.92 129.3899974822998 +63.12 93.47999954223633 +63.33 135.38999938964844 +63.35 116.93999862670898 +63.42 172.76000213623047 +63.51 123.51999855041504 +63.9 135.70000076293945 +64.0 191.55000257492065 +64.22 86.97000122070312 +64.25 131.25 +64.3 122.3800048828125 +64.36 85.59000015258789 +64.46 134.44000053405762 +64.65 143.54000091552734 +64.67 121.15999794006348 +64.77 214.40999603271484 +64.87 194.61000156402588 +64.95 324.87999153137207 +65.02 175.6099967956543 +65.02 259.6299982070923 +65.38 168.7100009918213 +65.43 112.27000069618225 +65.43 289.6800003051758 +65.44 192.89000137150288 +65.55 66.16000306606293 +65.62 139.67000198364258 +65.7 65.69999694824219 +65.72 77.54000091552734 +66.17 177.10999631881714 +66.17 200.60999870300293 +66.36 131.73000144958496 +66.51 83.50000190734863 +66.61 66.61000061035156 +66.61 78.93000030517578 +66.67 129.84999752044678 +66.89 99.96000003814697 +67.12 67.12000274658203 +67.18 234.21000003814697 +67.26 77.9900016784668 +67.38 178.72999572753906 +67.45 197.29999446868896 +67.48 268.0900020599365 +67.59 272.95999336242676 +67.94 125.89000129699707 +67.98 123.36000299453735 +68.01 124.84000396728516 +68.04 166.76000022888184 +68.22 181.57000064849854 +68.25 113.52000045776367 +68.25 163.2900013923645 +68.32 247.04999542236328 +68.41 157.72000312805176 +68.5 156.36999893188477 +68.81 93.02999782562256 +68.85 160.57999649643898 +68.89 89.69999885559082 +68.95 140.99999594688416 +68.96 192.4799976348877 +69.32 156.29000091552734 +69.53 239.17000007629395 +69.74 246.84999418258667 +69.8 101.63000345230103 +69.88 220.45999908447266 +69.96 83.84999942779541 +69.97 169.86000031232834 +70.0 236.76000022888184 +70.04 196.70000076293945 +70.06 92.17999839782715 +70.24 213.77999877929688 +70.35 247.5099983215332 +70.38 179.95999908447266 +70.39 164.79999923706055 +70.52 181.33999824523926 +70.53 70.52999877929688 +70.56 110.53999710083008 +70.85 223.94999885559082 +70.89 149.81999969482422 +70.93 160.79000282287598 +71.01 92.20000267028809 +71.07 112.36000061035156 +71.13 135.149995803833 +71.19 136.1800012588501 +71.26 318.7700004577637 +71.31 231.88999405503273 +71.32 145.21000003814697 +71.35 145.909996509552 +71.5 217.409996509552 +71.54 71.54000091552734 +71.55 90.41000270843506 +71.68 72.9500002861023 +71.68 227.97000122070312 +71.78 189.71999979019165 +71.8 103.66000366210938 +71.89 180.57999992370605 +72.04 156.63000106811523 +72.18 181.46999728679657 +72.51 264.0600047111511 +72.53 265.42000015079975 +72.56 132.61999893188477 +72.62 205.2400016784668 +72.79 72.79000091552734 +72.98 286.76000213623047 +73.18 123.78000068664551 +73.32 175.96999979019165 +73.48 144.8300018310547 +73.63 320.47999143600464 +73.65 114.11000156402588 +73.68 161.0300030708313 +73.88 291.28999376296997 +73.93 156.77999877929688 +74.0 179.8499994277954 +74.02 89.27999687194824 +74.15 74.1500015258789 +74.19 74.47000244259834 +74.19 122.44000244140625 +74.3 159.72000193595886 +74.42 219.62999820709229 +74.45 203.839994430542 +74.52 271.6399955749512 +74.53 342.6200008392334 +74.59 157.12999725341797 +74.62 163.89999961853027 +74.72 139.05000114440918 +74.78 230.83999752998352 +75.03 174.98999881744385 +75.1 214.62999725341797 +75.19 149.6600048840046 +75.29 93.4900016784668 +75.35 256.91999912261963 +75.42 153.46999764442444 +75.66 225.4800033569336 +75.73 236.60000133514404 +75.83 199.4000015258789 +75.88 243.89000129699707 +76.05 179.770001411438 +76.1 257.43999671936035 +76.28 168.48000144958496 +76.28 177.9100022315979 +76.33 363.09000396728516 +76.52 207.39999389648438 +76.69 212.87000370025635 +76.7 307.5399944782257 +76.71 132.17999839782715 +76.72 166.42000007629395 +76.72 258.1899985074997 +76.74 102.81999778747559 +76.92 401.7999897003174 +76.93 384.4699947834015 +77.02 296.6499948501587 +77.1 187.63999557495117 +77.36 281.19999504089355 +77.42 217.32000160217285 +77.57 189.66000270843506 +77.66 267.380003452301 +77.81 178.31999683380127 +77.84 350.7999897003174 +77.89 249.5399990081787 +77.97 178.2599983215332 +78.21 189.1099977493286 +78.26 284.6300048828125 +78.28 306.25 +78.3 168.71000388264656 +78.31 161.80999946594238 +78.62 275.3200035095215 +78.64 94.55999946594238 +78.73 310.6199974119663 +78.89 175.31999969482422 +78.98 293.25000381469727 +79.12 276.419997215271 +79.19 194.30000096559525 +79.21 237.62999820709229 +79.38 151.22999572753906 +79.42 124.6099967956543 +79.48 116.55000305175781 +79.48 200.64000129699707 +79.49 294.1199951171875 +79.54 145.23999786376953 +79.55 254.54000186920166 +79.75 215.13999938964844 +79.83 294.9700012207031 +79.96 79.95999908447266 +79.97 266.4900016784668 +79.99 219.6599998474121 +80.23 248.71000480651855 +80.3 133.4800033569336 +80.3 305.7800064086914 +80.46 249.17000296711922 +80.52 136.21999502182007 +80.58 261.16000175476074 +80.6 281.23999977111816 +80.71 275.0100000500679 +80.74 240.4599997997284 +80.84 142.60999822616577 +80.92 340.5499963760376 +80.96 372.2499928474426 +80.97 93.32000160217285 +80.99 317.74999809265137 +81.17 220.55999755859375 +81.32 158.86000061035156 +81.32 206.16000366210938 +81.47 198.91000175476074 +81.58 173.76000022888184 +81.64 207.5300006866455 +81.66 261.5100030899048 +82.24 243.59999752044678 +82.3 157.18000411987305 +82.34 214.0699977874756 +82.41 135.83000302314758 +82.52 240.23999977111816 +82.55 139.6900042295456 +82.56 175.8799991607666 +82.72 232.63000202178955 +82.97 101.9000015258789 +83.08 181.57000160217285 +83.27 83.2699966430664 +83.33 348.75000198185444 +83.4 196.729998588562 +83.54 299.12000465393066 +83.57 139.38999938964844 +83.58 163.54000091552734 +83.87 263.64000415802 +83.92 432.67000015079975 +83.93 358.9400003552437 +84.03 240.19000244140625 +84.23 209.53000235557556 +84.31 157.0999984741211 +84.38 220.1699981689453 +84.4 443.3400018811226 +84.69 249.4900016784668 +84.72 105.00000095367432 +84.83 157.4800033569336 +85.0 221.21999502182007 +85.03 283.9400005340576 +85.1 217.27999687194824 +85.14 266.7100009918213 +85.23 277.7100009918213 +85.49 261.4599976539612 +85.49 325.9499976634979 +85.51 165.59000205993652 +85.74 341.80999755859375 +85.76 160.57000350952148 +85.87 357.5099983215332 +85.9 202.45000457763672 +86.0 322.60000133514404 +86.22 152.8300018310547 +86.23 152.99000358581543 +86.63 381.5999984741211 +86.69 138.42000198364258 +86.92 147.14999723434448 +86.93 97.59000015258789 +86.93 218.18000030517578 +87.14 219.50999546051025 +87.22 211.01000022888184 +87.4 137.73000198602676 +87.48 153.86000204086304 +87.57 307.0799951553345 +87.61 321.82000064849854 +87.67 248.4600009918213 +87.83 130.96000289916992 +87.94 134.9800043106079 +87.99 227.0399990081787 +88.02 406.7899971008301 +88.05 97.53000259399414 +88.07 210.45000457763672 +88.17 245.6500015258789 +88.22 369.4600009918213 +88.36 305.6399974822998 +88.47 213.9400006532669 +88.48 211.84000635147095 +88.55 199.71000289916992 +88.77 187.49999463558197 +88.78 337.23999977111816 +88.8 254.1300015449524 +88.91 276.40999829769135 +89.01 98.69000244140625 +89.03 431.6499996185303 +89.1 113.83999919891357 +89.15 234.32000160217285 +89.28 343.82000064849854 +89.38 496.16999435424805 +89.53 153.05999755859375 +89.55 89.55000305175781 +89.55 96.27000284194946 +89.55 172.62000274658203 +89.8 205.1400022506714 +89.81 225.64000058174133 +89.93 89.93000030517578 +90.05 90.05000305175781 +90.05 247.1500015258789 +90.07 137.82999801635742 +90.12 234.76000308990479 +90.2 208.32999616861343 +90.25 386.8999948501587 +90.28 334.17000007629395 +90.35 158.36999893188477 +90.38 148.65999841690063 +90.51 268.7700004577637 +90.56 298.8899937272072 +90.69 181.1000051498413 +90.69 523.360002592206 +90.73 260.59000366926193 +90.77 203.03999733924866 +91.05 263.67000579833984 +91.16 298.55999755859375 +91.42 196.41999912261963 +91.48 144.56000471115112 +91.53 301.9800033569336 +91.61 215.59000027179718 +91.63 432.17999362945557 +91.78 330.9499988555908 +91.88 277.8299951553345 +91.97 205.4900016784668 +92.05 317.69000363349915 +92.11 307.70000088214874 +92.33 92.33000183105469 +92.37 226.410005569458 +92.4 265.1600036621094 +92.55 295.59000039100647 +92.61 249.38999938964844 +92.82 182.87000274658203 +92.96 185.98999691009521 +92.98 494.779993057251 +93.03 165.80999958515167 +93.09 456.1800003051758 +93.11 427.2800006866455 +93.61 465.8599934577942 +93.64 314.1999969482422 +93.73 475.3300018310547 +94.08 187.56000137329102 +94.15 362.9200019836426 +94.25 308.1900006532669 +94.27 351.70999336242676 +94.31 216.75 +94.33 193.02000427246094 +94.34 320.75000190734863 +94.38 260.18999683856964 +94.43 174.89999771118164 +94.54 340.19000244140625 +94.66 192.19000625610352 +94.68 189.23999977111816 +94.68 344.1700019836426 +94.72 274.6800003051758 +95.07 417.67000102996826 +95.11 393.99999433755875 +95.28 140.37999725341797 +95.33 244.99000671505928 +95.34 192.92999649047852 +95.38 230.36000156402588 +95.48 406.1000007688999 +95.53 248.58999633789062 +95.53 252.71000289916992 +95.81 134.81999588012695 +95.81 551.9899978637695 +95.84 274.1599931716919 +96.09 191.52999687194824 +96.23 398.2100067138672 +96.25 261.8400020599365 +96.29 199.10999870300293 +96.38 323.4199962615967 +96.62 309.4900064468384 +96.68 336.87000274658203 +96.73 272.0500030517578 +96.78 190.27000045776367 +96.91 180.18000030517578 +96.94 216.54999923706055 +97.09 428.03999519348145 +97.24 346.41000083088875 +97.26 373.67999935150146 +97.39 257.9600028991699 +97.46 449.1699924468994 +97.51 132.09000396728516 +97.56 97.55999755859375 +97.57 592.3499927520752 +97.65 196.34000301361084 +97.68 258.7100033760071 +97.71 175.70000076293945 +97.81 297.20999908447266 +97.83 396.38999938964844 +97.87 254.50000381469727 +98.18 105.29000043869019 +98.22 438.4100036621094 +98.23 525.5100040435791 +98.31 335.93999576568604 +98.48 286.0400047302246 +98.51 132.86000061035156 +98.57 404.3500061035156 +98.72 263.8600025177002 +98.96 288.1999988555908 +99.13 243.6900019645691 +99.15 210.11000061035156 +99.21 347.7999954223633 +99.24 537.6500015258789 +99.29 291.48000717163086 +99.36 106.41000080108643 +99.62 417.3700008392334 +99.65 185.62000179290771 +99.67 180.92999649047852 +99.68 230.6400032043457 +99.91 367.2900071144104 +99.92 376.32999646663666 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-7-20fdc99aa046b2c41d9b85ab338c749c b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-7-20fdc99aa046b2c41d9b85ab338c749c new file mode 100644 index 000000000000..1a4528978b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-7-20fdc99aa046b2c41d9b85ab338c749c @@ -0,0 +1,1049 @@ + 65560 3.1 + 65718 2.38 + 65740 3.04 +alice allen 65662 1.55 +alice allen 65720 1.79 +alice allen 65758 1.98 +alice brown 65696 0.13 +alice carson 65559 4.2 +alice davidson 65547 1.51 +alice falkner 65669 4.19 +alice garcia 65613 0.72 +alice hernandez 65737 0.92 +alice hernandez 65784 2.09 +alice johnson 65739 2.55 +alice king 65660 3.84 +alice king 65734 2.96 +alice king 65738 2.14 +alice laertes 65669 0.28 +alice laertes 65671 1.16 +alice miller 65590 4.95 +alice nixon 65586 2.98 +alice nixon 65595 2.36 +alice nixon 65604 2.45 +alice ovid 65737 3.2 +alice polk 65548 1.23 +alice quirinius 65636 0.23 +alice quirinius 65728 0.82 +alice robinson 65606 3.99 +alice robinson 65789 4.35 +alice steinbeck 65578 4.72 +alice steinbeck 65673 3.97 +alice steinbeck 65786 3.92 +alice underhill 65750 2.06 +alice van buren 65562 2.43 +alice xylophone 65578 2.22 +alice xylophone 65585 2.11 +alice xylophone 65599 2.92 +alice zipper 65553 3.78 +alice zipper 65662 2.61 +alice zipper 65766 3.12 +bob brown 65584 2.09 +bob brown 65777 1.62 +bob brown 65783 2.4 +bob carson 65713 3.87 +bob davidson 65664 4.25 +bob davidson 65693 3.54 +bob davidson 65768 2.91 +bob ellison 65591 2.23 +bob ellison 65624 1.69 +bob ellison 65721 1.69 +bob ellison 65760 1.69 +bob falkner 65789 0.21 +bob garcia 65585 3.91 +bob garcia 65598 2.64 +bob garcia 65673 3.26 +bob garcia 65754 3.29 +bob garcia 65782 2.86 +bob hernandez 65557 3.72 +bob ichabod 65549 1.39 +bob king 65715 2.76 +bob king 65757 1.71 +bob king 65783 2.24 +bob laertes 65602 4.98 +bob laertes 65663 3.56 +bob miller 65608 4.95 +bob ovid 65564 1.23 +bob ovid 65619 1.53 +bob ovid 65686 1.84 +bob ovid 65726 2.38 +bob polk 65594 0.28 +bob quirinius 65700 3.82 +bob steinbeck 65637 0.22 +bob van buren 65778 2.89 +bob white 65543 4.75 +bob white 65605 2.89 +bob xylophone 65574 1.7 +bob xylophone 65666 2.51 +bob young 65556 0.95 +bob zipper 65559 3.18 +bob zipper 65633 3.2 +bob zipper 65739 3.24 +calvin allen 65669 2.3 +calvin brown 65537 1.1 +calvin brown 65580 2.82 +calvin brown 65677 2.57 +calvin carson 65637 1.33 +calvin davidson 65541 1.98 +calvin davidson 65564 1.6 +calvin ellison 65667 1.85 +calvin falkner 65573 3.52 +calvin falkner 65596 3.31 +calvin falkner 65738 2.36 +calvin falkner 65762 2.26 +calvin falkner 65778 2.7 +calvin falkner 65784 2.98 +calvin garcia 65664 2.9 +calvin hernandez 65578 2.08 +calvin johnson 65731 4.69 +calvin laertes 65570 4.7 +calvin laertes 65684 3.68 +calvin nixon 65654 2.74 +calvin nixon 65724 3.27 +calvin nixon 65749 2.88 +calvin ovid 65554 3.34 +calvin ovid 65643 2.38 +calvin ovid 65663 2.74 +calvin ovid 65715 2.47 +calvin polk 65731 4.36 +calvin quirinius 65741 4.0 +calvin quirinius 65769 2.2 +calvin robinson 65581 3.32 +calvin steinbeck 65680 1.29 +calvin steinbeck 65762 1.3 +calvin steinbeck 65779 1.98 +calvin thompson 65560 4.08 +calvin thompson 65640 3.33 +calvin underhill 65732 2.41 +calvin van buren 65552 1.05 +calvin van buren 65771 1.33 +calvin white 65553 4.7 +calvin white 65561 2.68 +calvin xylophone 65575 4.6 +calvin xylophone 65596 4.77 +calvin xylophone 65713 3.23 +calvin young 65574 0.27 +calvin young 65746 0.9 +calvin zipper 65669 4.4 +calvin zipper 65739 2.29 +david allen 65588 3.86 +david allen 65617 3.18 +david brown 65637 1.17 +david brown 65760 1.01 +david davidson 65559 1.37 +david davidson 65756 1.57 +david davidson 65778 1.89 +david davidson 65779 2.43 +david ellison 65634 3.23 +david ellison 65724 2.95 +david ellison 65724 2.95 +david hernandez 65763 1.15 +david ichabod 65699 1.67 +david ichabod 65715 1.26 +david laertes 65762 1.38 +david nixon 65536 1.27 +david ovid 65623 0.23 +david ovid 65628 1.15 +david quirinius 65697 1.14 +david quirinius 65759 1.65 +david quirinius 65779 1.93 +david robinson 65762 3.51 +david robinson 65775 3.38 +david thompson 65550 3.3 +david underhill 65602 0.12 +david underhill 65662 2.29 +david underhill 65751 2.43 +david van buren 65625 1.55 +david van buren 65634 3.25 +david white 65678 0.17 +david xylophone 65537 1.07 +david xylophone 65670 0.72 +david xylophone 65764 0.94 +david young 65551 4.51 +david young 65694 2.74 +ethan allen 65747 3.61 +ethan brown 65539 2.9 +ethan brown 65617 1.59 +ethan brown 65685 2.17 +ethan brown 65685 2.17 +ethan brown 65722 2.64 +ethan brown 65733 2.75 +ethan carson 65742 2.84 +ethan ellison 65714 4.87 +ethan ellison 65732 3.9 +ethan falkner 65577 3.61 +ethan falkner 65614 1.95 +ethan garcia 65736 4.63 +ethan hernandez 65618 0.46 +ethan johnson 65536 1.76 +ethan king 65614 0.92 +ethan laertes 65562 2.89 +ethan laertes 65597 3.45 +ethan laertes 65628 3.15 +ethan laertes 65643 3.53 +ethan laertes 65680 3.18 +ethan laertes 65745 3.36 +ethan laertes 65760 3.01 +ethan miller 65712 1.97 +ethan nixon 65766 4.1 +ethan ovid 65697 3.81 +ethan polk 65589 0.7 +ethan polk 65615 1.92 +ethan polk 65622 2.0 +ethan polk 65622 2.0 +ethan quirinius 65542 4.64 +ethan quirinius 65591 3.97 +ethan quirinius 65706 2.88 +ethan robinson 65547 2.2 +ethan robinson 65659 2.17 +ethan underhill 65570 2.45 +ethan van buren 65572 1.11 +ethan white 65677 3.42 +ethan white 65788 4.09 +ethan xylophone 65595 4.66 +ethan zipper 65593 2.1 +ethan zipper 65680 2.53 +fred davidson 65552 0.37 +fred davidson 65595 2.31 +fred davidson 65721 2.65 +fred ellison 65548 4.32 +fred ellison 65691 3.44 +fred ellison 65771 3.1 +fred falkner 65637 4.59 +fred falkner 65648 2.79 +fred falkner 65783 2.35 +fred hernandez 65541 3.87 +fred ichabod 65572 1.45 +fred ichabod 65789 1.68 +fred johnson 65758 3.86 +fred king 65694 4.05 +fred king 65745 4.5 +fred laertes 65769 3.89 +fred miller 65536 2.59 +fred nixon 65560 2.52 +fred nixon 65612 2.0 +fred nixon 65703 1.35 +fred nixon 65705 1.16 +fred polk 65603 2.9 +fred polk 65656 1.65 +fred polk 65701 1.75 +fred polk 65706 1.4 +fred quirinius 65697 1.91 +fred quirinius 65782 3.24 +fred robinson 65623 2.6 +fred steinbeck 65544 2.0 +fred steinbeck 65651 3.32 +fred steinbeck 65755 3.51 +fred underhill 65629 2.56 +fred van buren 65537 4.97 +fred van buren 65561 3.28 +fred van buren 65745 3.24 +fred van buren 65789 3.27 +fred white 65589 3.16 +fred young 65594 0.34 +fred young 65773 1.56 +fred zipper 65553 0.37 +gabriella allen 65646 1.68 +gabriella allen 65677 1.6 +gabriella brown 65704 0.02 +gabriella brown 65753 1.86 +gabriella carson 65586 0.37 +gabriella davidson 65565 3.45 +gabriella ellison 65706 1.15 +gabriella ellison 65716 2.06 +gabriella falkner 65623 2.09 +gabriella falkner 65711 2.48 +gabriella falkner 65767 1.82 +gabriella garcia 65571 3.17 +gabriella hernandez 65587 0.74 +gabriella hernandez 65717 0.96 +gabriella ichabod 65559 0.63 +gabriella ichabod 65633 2.42 +gabriella ichabod 65702 3.27 +gabriella ichabod 65712 3.6 +gabriella ichabod 65717 3.56 +gabriella king 65651 3.59 +gabriella king 65657 2.17 +gabriella laertes 65781 2.81 +gabriella miller 65646 3.47 +gabriella ovid 65556 1.23 +gabriella ovid 65583 1.95 +gabriella polk 65701 3.58 +gabriella polk 65790 2.08 +gabriella steinbeck 65582 3.6 +gabriella steinbeck 65653 2.7 +gabriella thompson 65682 1.78 +gabriella thompson 65755 3.21 +gabriella thompson 65766 2.71 +gabriella van buren 65581 3.36 +gabriella van buren 65644 2.6 +gabriella white 65638 4.55 +gabriella young 65699 4.13 +gabriella young 65774 3.58 +gabriella zipper 65540 0.96 +gabriella zipper 65754 2.13 +holly allen 65596 0.05 +holly brown 65599 3.23 +holly brown 65619 3.4 +holly falkner 65720 4.06 +holly hernandez 65602 3.67 +holly hernandez 65686 3.39 +holly hernandez 65750 3.2 +holly hernandez 65788 2.95 +holly ichabod 65711 4.73 +holly ichabod 65749 3.54 +holly ichabod 65752 3.27 +holly johnson 65655 4.19 +holly johnson 65662 3.84 +holly johnson 65755 2.65 +holly king 65549 3.61 +holly king 65648 2.37 +holly laertes 65664 4.14 +holly miller 65653 3.84 +holly nixon 65539 4.09 +holly nixon 65658 3.04 +holly polk 65743 2.1 +holly polk 65751 2.78 +holly robinson 65564 0.24 +holly thompson 65538 2.39 +holly thompson 65578 1.49 +holly thompson 65713 1.54 +holly underhill 65634 4.69 +holly underhill 65654 3.08 +holly underhill 65721 3.14 +holly underhill 65759 2.61 +holly van buren 65727 0.04 +holly white 65536 4.59 +holly white 65602 4.13 +holly xylophone 65544 1.49 +holly young 65606 4.39 +holly young 65765 3.81 +holly zipper 65607 4.12 +holly zipper 65755 3.3 +irene allen 65556 3.45 +irene brown 65633 4.8 +irene brown 65650 3.77 +irene brown 65765 3.53 +irene carson 65590 2.83 +irene ellison 65659 3.15 +irene ellison 65696 2.0 +irene falkner 65620 0.99 +irene falkner 65661 1.41 +irene garcia 65660 1.55 +irene garcia 65711 1.72 +irene garcia 65787 1.57 +irene ichabod 65645 0.95 +irene ichabod 65722 1.49 +irene johnson 65583 4.51 +irene laertes 65664 2.5 +irene laertes 65710 1.5 +irene laertes 65722 2.01 +irene miller 65730 4.33 +irene nixon 65631 2.36 +irene nixon 65643 3.43 +irene nixon 65653 2.43 +irene ovid 65691 3.24 +irene ovid 65734 3.17 +irene ovid 65753 3.18 +irene polk 65551 4.51 +irene polk 65575 2.97 +irene polk 65579 3.58 +irene polk 65595 2.82 +irene polk 65610 2.99 +irene quirinius 65724 3.5 +irene quirinius 65769 3.85 +irene quirinius 65773 4.21 +irene robinson 65554 2.67 +irene steinbeck 65683 1.48 +irene thompson 65688 0.06 +irene underhill 65591 3.61 +irene underhill 65787 4.01 +irene van buren 65579 4.26 +irene van buren 65589 4.37 +irene xylophone 65775 4.81 +jessica brown 65588 2.87 +jessica carson 65553 1.02 +jessica carson 65672 1.82 +jessica carson 65747 1.91 +jessica davidson 65549 4.48 +jessica davidson 65606 2.72 +jessica davidson 65675 2.23 +jessica davidson 65727 2.12 +jessica ellison 65567 3.0 +jessica ellison 65663 3.15 +jessica falkner 65584 2.11 +jessica garcia 65676 2.13 +jessica garcia 65789 3.54 +jessica ichabod 65704 3.48 +jessica johnson 65607 3.55 +jessica johnson 65720 3.0 +jessica miller 65733 2.9 +jessica nixon 65590 2.18 +jessica nixon 65774 3.1 +jessica ovid 65582 3.23 +jessica ovid 65751 3.14 +jessica polk 65637 1.12 +jessica quirinius 65562 3.06 +jessica quirinius 65608 1.75 +jessica quirinius 65712 1.51 +jessica quirinius 65716 1.37 +jessica robinson 65576 1.11 +jessica thompson 65581 4.94 +jessica thompson 65675 3.56 +jessica underhill 65656 3.97 +jessica underhill 65702 3.01 +jessica underhill 65783 3.5 +jessica van buren 65615 2.15 +jessica white 65544 1.89 +jessica white 65570 1.92 +jessica white 65594 2.67 +jessica white 65673 2.1 +jessica white 65779 2.04 +jessica xylophone 65562 0.49 +jessica young 65623 0.5 +jessica young 65711 2.37 +jessica zipper 65600 1.71 +jessica zipper 65657 1.0 +jessica zipper 65778 0.9 +katie allen 65542 1.3 +katie brown 65590 3.06 +katie davidson 65619 1.89 +katie ellison 65675 1.92 +katie ellison 65699 2.55 +katie falkner 65728 2.42 +katie garcia 65625 4.18 +katie garcia 65747 4.33 +katie hernandez 65550 2.1 +katie ichabod 65658 1.84 +katie ichabod 65726 2.41 +katie ichabod 65757 3.2 +katie king 65629 0.86 +katie king 65647 2.09 +katie king 65776 2.74 +katie miller 65541 0.66 +katie miller 65661 1.39 +katie nixon 65669 2.06 +katie ovid 65681 1.61 +katie polk 65746 4.74 +katie polk 65784 2.57 +katie robinson 65697 4.74 +katie van buren 65643 0.61 +katie van buren 65730 2.79 +katie white 65620 0.42 +katie white 65719 1.38 +katie xylophone 65585 3.02 +katie young 65644 2.64 +katie young 65746 2.45 +katie young 65764 2.66 +katie zipper 65568 2.33 +katie zipper 65733 1.97 +luke allen 65547 2.04 +luke allen 65552 1.49 +luke allen 65576 2.14 +luke allen 65681 2.2 +luke allen 65776 1.92 +luke brown 65719 4.3 +luke davidson 65656 3.37 +luke davidson 65791 3.9 +luke ellison 65582 0.23 +luke ellison 65664 0.51 +luke ellison 65779 0.87 +luke falkner 65589 2.22 +luke falkner 65618 1.22 +luke garcia 65687 4.73 +luke garcia 65778 3.56 +luke ichabod 65629 3.07 +luke ichabod 65654 3.58 +luke johnson 65545 4.33 +luke johnson 65716 3.01 +luke johnson 65718 3.17 +luke laertes 65608 3.79 +luke laertes 65657 3.89 +luke laertes 65685 2.82 +luke laertes 65730 2.96 +luke laertes 65756 3.19 +luke miller 65752 4.7 +luke ovid 65569 4.4 +luke ovid 65693 2.58 +luke polk 65645 0.57 +luke polk 65658 2.73 +luke quirinius 65655 4.1 +luke robinson 65634 4.9 +luke robinson 65772 4.19 +luke thompson 65626 0.15 +luke underhill 65553 1.28 +luke underhill 65571 0.84 +luke underhill 65651 1.14 +luke van buren 65678 0.34 +luke white 65693 0.91 +luke xylophone 65597 2.27 +luke zipper 65641 4.63 +mike allen 65706 3.06 +mike brown 65654 3.57 +mike carson 65698 4.46 +mike carson 65700 3.89 +mike carson 65751 3.58 +mike davidson 65658 2.06 +mike davidson 65759 3.34 +mike ellison 65598 3.96 +mike ellison 65606 3.28 +mike ellison 65718 3.38 +mike ellison 65738 2.56 +mike ellison 65760 3.03 +mike falkner 65609 4.85 +mike garcia 65571 1.82 +mike garcia 65600 1.42 +mike garcia 65770 1.92 +mike hernandez 65548 1.42 +mike hernandez 65672 1.75 +mike ichabod 65621 3.73 +mike king 65563 4.34 +mike king 65586 3.75 +mike king 65591 3.09 +mike king 65642 2.69 +mike king 65769 2.36 +mike king 65776 2.55 +mike miller 65549 3.96 +mike nixon 65619 0.09 +mike nixon 65704 2.15 +mike polk 65619 4.13 +mike polk 65658 4.27 +mike polk 65704 3.77 +mike quirinius 65717 2.81 +mike steinbeck 65550 3.18 +mike steinbeck 65564 2.58 +mike steinbeck 65573 2.12 +mike steinbeck 65749 1.72 +mike van buren 65620 0.09 +mike van buren 65770 0.88 +mike white 65648 1.72 +mike white 65685 1.12 +mike white 65769 1.74 +mike white 65778 2.05 +mike young 65545 1.69 +mike young 65581 0.92 +mike young 65736 1.84 +mike zipper 65552 4.8 +mike zipper 65695 4.16 +mike zipper 65779 4.22 +nick allen 65641 3.78 +nick allen 65786 3.74 +nick brown 65724 4.61 +nick davidson 65601 0.88 +nick ellison 65691 4.04 +nick ellison 65745 3.22 +nick falkner 65583 4.44 +nick falkner 65676 3.08 +nick garcia 65695 1.05 +nick garcia 65712 2.18 +nick garcia 65720 1.94 +nick ichabod 65572 2.62 +nick ichabod 65681 2.99 +nick ichabod 65737 3.55 +nick johnson 65585 0.56 +nick johnson 65784 0.42 +nick laertes 65624 0.16 +nick miller 65757 4.23 +nick nixon 65650 0.7 +nick ovid 65719 3.93 +nick polk 65716 3.66 +nick quirinius 65588 2.88 +nick quirinius 65723 2.42 +nick robinson 65547 0.21 +nick robinson 65675 0.57 +nick steinbeck 65689 4.11 +nick thompson 65610 2.32 +nick underhill 65619 2.73 +nick van buren 65603 1.84 +nick xylophone 65644 2.06 +nick young 65654 2.06 +nick young 65660 2.28 +nick zipper 65757 3.8 +nick zipper 65765 2.04 +oscar allen 65644 3.01 +oscar brown 65614 3.95 +oscar carson 65537 4.29 +oscar carson 65548 2.89 +oscar carson 65549 1.95 +oscar carson 65624 2.56 +oscar carson 65697 3.03 +oscar davidson 65556 0.6 +oscar ellison 65630 1.47 +oscar ellison 65630 1.47 +oscar falkner 65692 3.44 +oscar garcia 65751 3.71 +oscar hernandez 65683 3.32 +oscar hernandez 65707 2.25 +oscar ichabod 65536 1.8 +oscar ichabod 65562 1.18 +oscar ichabod 65637 1.91 +oscar ichabod 65763 1.96 +oscar johnson 65645 1.69 +oscar johnson 65778 1.59 +oscar king 65541 3.96 +oscar king 65550 4.31 +oscar king 65787 3.54 +oscar laertes 65625 2.26 +oscar laertes 65690 2.55 +oscar laertes 65756 2.85 +oscar laertes 65790 2.39 +oscar nixon 65596 3.12 +oscar ovid 65536 2.94 +oscar ovid 65615 2.95 +oscar ovid 65659 3.41 +oscar polk 65541 1.12 +oscar polk 65643 1.62 +oscar quirinius 65541 3.5 +oscar quirinius 65560 2.87 +oscar quirinius 65689 3.35 +oscar quirinius 65720 3.03 +oscar robinson 65537 0.29 +oscar robinson 65658 0.29 +oscar robinson 65687 1.5 +oscar robinson 65782 1.78 +oscar steinbeck 65709 4.96 +oscar thompson 65542 0.48 +oscar thompson 65698 2.07 +oscar thompson 65727 2.01 +oscar thompson 65738 1.8 +oscar underhill 65626 3.49 +oscar van buren 65581 2.33 +oscar van buren 65635 2.45 +oscar van buren 65705 2.68 +oscar white 65552 3.05 +oscar white 65564 2.58 +oscar white 65671 2.56 +oscar white 65735 2.47 +oscar xylophone 65773 1.51 +oscar xylophone 65773 1.51 +oscar xylophone 65775 1.82 +oscar zipper 65568 4.77 +oscar zipper 65740 3.81 +oscar zipper 65777 3.26 +priscilla brown 65670 2.91 +priscilla brown 65690 2.83 +priscilla brown 65749 2.07 +priscilla carson 65658 1.43 +priscilla carson 65687 2.97 +priscilla carson 65755 2.87 +priscilla ichabod 65627 4.95 +priscilla ichabod 65759 3.41 +priscilla johnson 65543 3.85 +priscilla johnson 65633 2.98 +priscilla johnson 65668 2.23 +priscilla johnson 65681 1.98 +priscilla johnson 65755 1.94 +priscilla king 65646 1.93 +priscilla nixon 65564 0.31 +priscilla nixon 65600 2.35 +priscilla ovid 65541 3.8 +priscilla ovid 65790 2.37 +priscilla polk 65747 3.1 +priscilla quirinius 65672 0.81 +priscilla thompson 65654 2.04 +priscilla underhill 65715 0.82 +priscilla underhill 65729 1.33 +priscilla van buren 65607 0.23 +priscilla van buren 65685 0.81 +priscilla van buren 65749 1.84 +priscilla white 65652 4.79 +priscilla xylophone 65538 3.56 +priscilla xylophone 65763 2.48 +priscilla xylophone 65774 1.84 +priscilla young 65585 2.92 +priscilla young 65658 3.77 +priscilla zipper 65622 4.62 +priscilla zipper 65726 2.67 +quinn allen 65657 3.02 +quinn allen 65708 3.35 +quinn brown 65691 4.37 +quinn brown 65700 3.28 +quinn brown 65733 3.27 +quinn davidson 65549 0.03 +quinn davidson 65714 1.02 +quinn davidson 65776 2.19 +quinn davidson 65779 2.66 +quinn ellison 65705 0.54 +quinn ellison 65778 2.74 +quinn garcia 65568 1.04 +quinn garcia 65604 0.79 +quinn garcia 65610 1.85 +quinn garcia 65773 1.67 +quinn ichabod 65564 0.65 +quinn king 65558 2.38 +quinn king 65649 1.53 +quinn laertes 65542 2.33 +quinn laertes 65560 2.03 +quinn laertes 65627 2.08 +quinn nixon 65659 0.72 +quinn ovid 65699 2.87 +quinn quirinius 65747 0.86 +quinn robinson 65627 4.14 +quinn steinbeck 65578 4.92 +quinn steinbeck 65763 3.18 +quinn thompson 65643 4.62 +quinn thompson 65774 2.43 +quinn underhill 65549 3.89 +quinn underhill 65694 2.39 +quinn underhill 65767 2.48 +quinn van buren 65725 0.53 +quinn young 65647 3.54 +quinn zipper 65579 1.83 +quinn zipper 65693 2.1 +rachel allen 65661 4.66 +rachel allen 65709 4.53 +rachel brown 65586 0.43 +rachel brown 65587 2.69 +rachel brown 65587 2.69 +rachel brown 65610 2.67 +rachel brown 65693 3.07 +rachel carson 65677 4.58 +rachel carson 65682 4.54 +rachel davidson 65755 2.28 +rachel ellison 65761 0.05 +rachel falkner 65616 1.2 +rachel falkner 65681 2.01 +rachel falkner 65693 2.56 +rachel falkner 65764 2.91 +rachel johnson 65658 3.02 +rachel king 65604 2.32 +rachel king 65643 2.4 +rachel laertes 65562 4.43 +rachel laertes 65624 2.8 +rachel ovid 65721 0.3 +rachel ovid 65736 0.92 +rachel polk 65686 2.56 +rachel quirinius 65787 2.95 +rachel robinson 65544 0.06 +rachel robinson 65717 1.55 +rachel robinson 65724 2.27 +rachel thompson 65648 2.49 +rachel thompson 65662 3.16 +rachel thompson 65733 2.51 +rachel underhill 65667 2.29 +rachel white 65615 1.99 +rachel white 65717 3.08 +rachel young 65727 1.75 +rachel zipper 65757 2.82 +rachel zipper 65785 3.62 +sarah carson 65679 1.04 +sarah carson 65693 0.85 +sarah carson 65694 2.06 +sarah ellison 65611 1.64 +sarah falkner 65606 0.77 +sarah falkner 65680 2.85 +sarah garcia 65563 4.89 +sarah garcia 65638 4.78 +sarah garcia 65661 3.63 +sarah ichabod 65667 3.94 +sarah ichabod 65671 2.33 +sarah johnson 65659 3.51 +sarah johnson 65716 4.21 +sarah johnson 65731 3.81 +sarah johnson 65751 3.37 +sarah king 65650 1.05 +sarah king 65699 0.99 +sarah miller 65557 0.2 +sarah ovid 65550 3.21 +sarah robinson 65677 4.9 +sarah robinson 65763 2.99 +sarah steinbeck 65721 2.82 +sarah white 65622 0.07 +sarah white 65747 2.29 +sarah xylophone 65678 0.15 +sarah young 65595 2.15 +sarah zipper 65550 2.22 +tom brown 65593 1.64 +tom brown 65675 2.83 +tom carson 65539 4.38 +tom carson 65624 4.28 +tom carson 65780 4.03 +tom davidson 65780 2.4 +tom ellison 65578 3.3 +tom ellison 65670 4.04 +tom ellison 65756 3.51 +tom falkner 65574 1.09 +tom falkner 65583 2.05 +tom hernandez 65575 2.35 +tom hernandez 65632 2.64 +tom ichabod 65588 1.48 +tom johnson 65536 4.68 +tom johnson 65789 4.6 +tom king 65576 2.87 +tom laertes 65617 1.51 +tom laertes 65701 1.93 +tom miller 65594 1.14 +tom miller 65603 1.25 +tom miller 65704 1.26 +tom nixon 65672 0.04 +tom ovid 65628 1.95 +tom polk 65652 2.54 +tom polk 65742 2.26 +tom quirinius 65563 4.51 +tom quirinius 65783 4.17 +tom robinson 65626 3.12 +tom robinson 65632 3.61 +tom robinson 65691 3.13 +tom robinson 65758 3.45 +tom steinbeck 65666 1.34 +tom van buren 65621 0.66 +tom van buren 65652 2.71 +tom van buren 65669 3.47 +tom white 65548 2.13 +tom young 65544 3.73 +tom young 65546 2.02 +tom zipper 65789 1.51 +ulysses brown 65735 2.92 +ulysses carson 65602 0.34 +ulysses carson 65643 2.11 +ulysses carson 65703 1.43 +ulysses carson 65716 1.41 +ulysses davidson 65750 3.04 +ulysses ellison 65575 4.39 +ulysses garcia 65666 4.2 +ulysses hernandez 65651 1.75 +ulysses hernandez 65702 2.42 +ulysses hernandez 65786 2.53 +ulysses ichabod 65551 0.33 +ulysses ichabod 65566 2.19 +ulysses johnson 65776 4.79 +ulysses king 65649 4.46 +ulysses laertes 65691 4.55 +ulysses laertes 65711 3.54 +ulysses laertes 65781 3.66 +ulysses miller 65610 0.24 +ulysses miller 65637 1.08 +ulysses nixon 65603 1.85 +ulysses ovid 65656 3.17 +ulysses polk 65563 1.32 +ulysses polk 65580 3.05 +ulysses polk 65612 3.46 +ulysses polk 65777 3.75 +ulysses quirinius 65786 2.13 +ulysses robinson 65744 1.97 +ulysses steinbeck 65611 2.74 +ulysses steinbeck 65680 2.64 +ulysses thompson 65788 1.51 +ulysses underhill 65570 0.38 +ulysses underhill 65616 0.8 +ulysses underhill 65620 2.09 +ulysses underhill 65623 2.69 +ulysses underhill 65641 2.54 +ulysses underhill 65713 2.9 +ulysses underhill 65785 2.97 +ulysses van buren 65684 1.42 +ulysses white 65654 0.14 +ulysses white 65675 1.51 +ulysses xylophone 65623 2.3 +ulysses xylophone 65636 2.69 +ulysses xylophone 65781 3.22 +ulysses young 65675 1.34 +ulysses young 65736 2.01 +ulysses young 65748 2.24 +victor allen 65684 0.83 +victor allen 65707 2.31 +victor brown 65550 4.57 +victor brown 65555 3.54 +victor brown 65622 2.61 +victor brown 65673 2.34 +victor davidson 65579 0.61 +victor davidson 65628 1.52 +victor davidson 65783 2.25 +victor ellison 65641 1.32 +victor ellison 65782 2.59 +victor hernandez 65571 3.62 +victor hernandez 65659 3.68 +victor hernandez 65708 3.35 +victor hernandez 65735 2.88 +victor hernandez 65775 2.62 +victor johnson 65606 3.03 +victor johnson 65607 2.3 +victor johnson 65607 2.3 +victor king 65721 4.09 +victor king 65743 2.45 +victor laertes 65638 1.46 +victor laertes 65644 2.38 +victor miller 65570 0.1 +victor nixon 65709 0.74 +victor nixon 65791 1.73 +victor ovid 65649 4.93 +victor polk 65625 1.04 +victor quirinius 65620 1.32 +victor quirinius 65651 3.15 +victor robinson 65596 0.92 +victor robinson 65673 1.76 +victor steinbeck 65618 2.87 +victor steinbeck 65661 2.19 +victor steinbeck 65686 2.81 +victor thompson 65548 1.59 +victor van buren 65664 4.44 +victor van buren 65774 4.06 +victor white 65548 4.67 +victor white 65601 3.87 +victor xylophone 65549 3.8 +victor xylophone 65618 2.13 +victor xylophone 65644 1.59 +victor xylophone 65677 1.89 +victor xylophone 65755 2.27 +victor young 65628 3.16 +victor zipper 65743 3.98 +wendy allen 65628 3.8 +wendy allen 65711 3.44 +wendy allen 65782 2.4 +wendy brown 65580 4.67 +wendy brown 65657 4.68 +wendy ellison 65545 1.51 +wendy ellison 65603 1.6 +wendy falkner 65595 0.58 +wendy falkner 65604 0.82 +wendy falkner 65635 1.59 +wendy garcia 65659 3.47 +wendy garcia 65746 2.35 +wendy garcia 65747 1.67 +wendy garcia 65777 1.32 +wendy hernandez 65650 2.26 +wendy ichabod 65730 0.44 +wendy king 65586 4.46 +wendy king 65664 4.25 +wendy king 65670 2.94 +wendy laertes 65566 3.13 +wendy laertes 65683 3.99 +wendy laertes 65727 3.57 +wendy miller 65582 1.53 +wendy miller 65626 1.4 +wendy nixon 65611 0.26 +wendy nixon 65746 2.27 +wendy ovid 65589 4.75 +wendy ovid 65643 3.42 +wendy polk 65656 0.62 +wendy polk 65692 1.36 +wendy quirinius 65766 1.35 +wendy quirinius 65767 2.76 +wendy robinson 65622 0.85 +wendy robinson 65715 2.13 +wendy robinson 65774 1.8 +wendy steinbeck 65612 0.07 +wendy thompson 65650 2.27 +wendy thompson 65737 3.2 +wendy underhill 65662 4.55 +wendy underhill 65758 2.84 +wendy underhill 65775 2.54 +wendy van buren 65680 1.1 +wendy van buren 65699 1.0 +wendy white 65705 0.5 +wendy xylophone 65687 0.46 +wendy xylophone 65773 1.39 +wendy young 65674 0.48 +wendy young 65685 1.31 +xavier allen 65611 1.53 +xavier allen 65618 2.07 +xavier allen 65771 2.33 +xavier brown 65600 0.89 +xavier brown 65704 0.58 +xavier brown 65723 1.55 +xavier carson 65731 4.42 +xavier carson 65758 3.91 +xavier davidson 65644 1.84 +xavier davidson 65664 3.4 +xavier davidson 65755 2.67 +xavier ellison 65541 1.47 +xavier ellison 65654 2.49 +xavier garcia 65672 2.76 +xavier hernandez 65541 0.96 +xavier hernandez 65544 1.47 +xavier hernandez 65766 1.28 +xavier ichabod 65597 4.76 +xavier ichabod 65663 2.99 +xavier johnson 65654 1.34 +xavier johnson 65744 3.06 +xavier king 65590 2.68 +xavier king 65601 1.4 +xavier laertes 65743 0.75 +xavier ovid 65788 2.06 +xavier polk 65587 0.99 +xavier polk 65653 1.15 +xavier polk 65675 1.9 +xavier polk 65696 1.93 +xavier quirinius 65599 1.66 +xavier quirinius 65650 1.94 +xavier quirinius 65656 2.46 +xavier quirinius 65737 1.92 +xavier thompson 65608 1.65 +xavier underhill 65710 0.13 +xavier white 65703 0.98 +xavier white 65732 2.22 +xavier xylophone 65572 1.0 +xavier zipper 65561 0.94 +yuri allen 65565 2.03 +yuri allen 65682 1.14 +yuri brown 65538 2.73 +yuri brown 65688 2.02 +yuri carson 65670 3.06 +yuri carson 65769 3.43 +yuri ellison 65570 1.05 +yuri ellison 65581 1.68 +yuri falkner 65658 2.85 +yuri falkner 65681 2.14 +yuri garcia 65639 3.41 +yuri hernandez 65706 1.64 +yuri johnson 65587 1.27 +yuri johnson 65697 1.44 +yuri johnson 65712 2.29 +yuri king 65721 0.33 +yuri laertes 65637 4.3 +yuri laertes 65773 2.15 +yuri nixon 65635 4.02 +yuri nixon 65740 4.18 +yuri polk 65607 0.08 +yuri polk 65713 0.37 +yuri polk 65742 1.25 +yuri quirinius 65544 2.58 +yuri quirinius 65617 2.1 +yuri quirinius 65695 1.91 +yuri steinbeck 65592 4.89 +yuri steinbeck 65679 3.24 +yuri thompson 65676 2.67 +yuri underhill 65718 2.86 +yuri underhill 65750 2.51 +yuri white 65659 4.59 +yuri xylophone 65714 2.53 +zach allen 65667 0.88 +zach brown 65559 4.88 +zach brown 65588 4.53 +zach brown 65691 3.49 +zach brown 65759 3.4 +zach brown 65762 3.55 +zach carson 65572 2.03 +zach ellison 65748 1.76 +zach falkner 65620 0.34 +zach falkner 65627 0.25 +zach garcia 65544 0.99 +zach garcia 65623 2.84 +zach garcia 65629 3.01 +zach garcia 65786 2.55 +zach ichabod 65599 3.36 +zach ichabod 65612 1.92 +zach king 65556 2.36 +zach king 65702 1.52 +zach king 65773 2.58 +zach miller 65584 1.6 +zach miller 65665 0.99 +zach miller 65719 1.55 +zach ovid 65578 1.51 +zach ovid 65703 1.92 +zach ovid 65750 2.63 +zach ovid 65784 2.72 +zach quirinius 65691 2.95 +zach robinson 65599 2.87 +zach steinbeck 65602 2.45 +zach steinbeck 65695 1.86 +zach thompson 65636 0.25 +zach thompson 65696 0.51 +zach underhill 65573 3.97 +zach white 65733 2.31 +zach xylophone 65542 1.69 +zach xylophone 65780 0.88 +zach young 65576 1.82 +zach zipper 65579 4.5 +zach zipper 65649 4.02 +zach zipper 65676 3.12 diff --git a/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-8-45a1d7c2aba45d761e19ff4dfdf5463e b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-8-45a1d7c2aba45d761e19ff4dfdf5463e new file mode 100644 index 000000000000..84b934fad85b --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing_windowspec.q (deterministic)-8-45a1d7c2aba45d761e19ff4dfdf5463e @@ -0,0 +1,1049 @@ + 65560 20.0 + 65718 20.0 + 65740 20.0 +alice allen 65662 20.0 +alice allen 65720 20.0 +alice allen 65758 20.0 +alice brown 65696 20.0 +alice carson 65559 20.0 +alice davidson 65547 20.0 +alice falkner 65669 20.0 +alice garcia 65613 20.0 +alice hernandez 65737 20.0 +alice hernandez 65784 20.0 +alice johnson 65739 20.0 +alice king 65660 20.0 +alice king 65734 20.0 +alice king 65738 20.0 +alice laertes 65669 20.0 +alice laertes 65671 20.0 +alice miller 65590 20.0 +alice nixon 65586 20.0 +alice nixon 65595 20.0 +alice nixon 65604 20.0 +alice ovid 65737 20.0 +alice polk 65548 20.0 +alice quirinius 65636 20.0 +alice quirinius 65728 20.0 +alice robinson 65606 20.0 +alice robinson 65789 20.0 +alice steinbeck 65578 20.0 +alice steinbeck 65673 20.0 +alice steinbeck 65786 20.0 +alice underhill 65750 20.0 +alice van buren 65562 20.0 +alice xylophone 65578 20.0 +alice xylophone 65585 20.0 +alice xylophone 65599 20.0 +alice zipper 65553 20.0 +alice zipper 65662 20.0 +alice zipper 65766 20.0 +bob brown 65584 20.0 +bob brown 65777 20.0 +bob brown 65783 20.0 +bob carson 65713 20.0 +bob davidson 65664 20.0 +bob davidson 65693 20.0 +bob davidson 65768 20.0 +bob ellison 65591 20.0 +bob ellison 65624 20.0 +bob ellison 65721 20.0 +bob ellison 65760 20.0 +bob falkner 65789 20.0 +bob garcia 65585 20.0 +bob garcia 65598 20.0 +bob garcia 65673 20.0 +bob garcia 65754 20.0 +bob garcia 65782 20.0 +bob hernandez 65557 20.0 +bob ichabod 65549 20.0 +bob king 65715 20.0 +bob king 65757 20.0 +bob king 65783 20.0 +bob laertes 65602 20.0 +bob laertes 65663 20.0 +bob miller 65608 20.0 +bob ovid 65564 20.0 +bob ovid 65619 20.0 +bob ovid 65686 20.0 +bob ovid 65726 20.0 +bob polk 65594 20.0 +bob quirinius 65700 20.0 +bob steinbeck 65637 20.0 +bob van buren 65778 20.0 +bob white 65543 20.0 +bob white 65605 20.0 +bob xylophone 65574 20.0 +bob xylophone 65666 20.0 +bob young 65556 20.0 +bob zipper 65559 20.0 +bob zipper 65633 20.0 +bob zipper 65739 20.0 +calvin allen 65669 20.0 +calvin brown 65537 20.0 +calvin brown 65580 20.0 +calvin brown 65677 20.0 +calvin carson 65637 20.0 +calvin davidson 65541 20.0 +calvin davidson 65564 20.0 +calvin ellison 65667 20.0 +calvin falkner 65573 20.0 +calvin falkner 65596 20.0 +calvin falkner 65738 20.0 +calvin falkner 65762 20.0 +calvin falkner 65778 20.0 +calvin falkner 65784 20.0 +calvin garcia 65664 20.0 +calvin hernandez 65578 20.0 +calvin johnson 65731 20.0 +calvin laertes 65570 20.0 +calvin laertes 65684 20.0 +calvin nixon 65654 20.0 +calvin nixon 65724 20.0 +calvin nixon 65749 20.0 +calvin ovid 65554 20.0 +calvin ovid 65643 20.0 +calvin ovid 65663 20.0 +calvin ovid 65715 20.0 +calvin polk 65731 20.0 +calvin quirinius 65741 20.0 +calvin quirinius 65769 20.0 +calvin robinson 65581 20.0 +calvin steinbeck 65680 20.0 +calvin steinbeck 65762 20.0 +calvin steinbeck 65779 20.0 +calvin thompson 65560 20.0 +calvin thompson 65640 20.0 +calvin underhill 65732 20.0 +calvin van buren 65552 20.0 +calvin van buren 65771 20.0 +calvin white 65553 20.0 +calvin white 65561 20.0 +calvin xylophone 65575 20.0 +calvin xylophone 65596 20.0 +calvin xylophone 65713 20.0 +calvin young 65574 20.0 +calvin young 65746 20.0 +calvin zipper 65669 20.0 +calvin zipper 65739 20.0 +david allen 65588 20.0 +david allen 65617 20.0 +david brown 65637 20.0 +david brown 65760 20.0 +david davidson 65559 20.0 +david davidson 65756 20.0 +david davidson 65778 20.0 +david davidson 65779 20.0 +david ellison 65634 20.0 +david ellison 65724 20.0 +david ellison 65724 20.0 +david hernandez 65763 20.0 +david ichabod 65699 20.0 +david ichabod 65715 20.0 +david laertes 65762 20.0 +david nixon 65536 20.0 +david ovid 65623 20.0 +david ovid 65628 20.0 +david quirinius 65697 20.0 +david quirinius 65759 20.0 +david quirinius 65779 20.0 +david robinson 65762 20.0 +david robinson 65775 20.0 +david thompson 65550 20.0 +david underhill 65602 20.0 +david underhill 65662 20.0 +david underhill 65751 20.0 +david van buren 65625 20.0 +david van buren 65634 20.0 +david white 65678 20.0 +david xylophone 65537 20.0 +david xylophone 65670 20.0 +david xylophone 65764 20.0 +david young 65551 20.0 +david young 65694 20.0 +ethan allen 65747 20.0 +ethan brown 65539 20.0 +ethan brown 65617 20.0 +ethan brown 65685 20.0 +ethan brown 65685 20.0 +ethan brown 65722 20.0 +ethan brown 65733 20.0 +ethan carson 65742 20.0 +ethan ellison 65714 20.0 +ethan ellison 65732 20.0 +ethan falkner 65577 20.0 +ethan falkner 65614 20.0 +ethan garcia 65736 20.0 +ethan hernandez 65618 20.0 +ethan johnson 65536 20.0 +ethan king 65614 20.0 +ethan laertes 65562 20.0 +ethan laertes 65597 20.0 +ethan laertes 65628 20.0 +ethan laertes 65643 20.0 +ethan laertes 65680 20.0 +ethan laertes 65745 20.0 +ethan laertes 65760 20.0 +ethan miller 65712 20.0 +ethan nixon 65766 20.0 +ethan ovid 65697 20.0 +ethan polk 65589 20.0 +ethan polk 65615 20.0 +ethan polk 65622 20.0 +ethan polk 65622 20.0 +ethan quirinius 65542 20.0 +ethan quirinius 65591 20.0 +ethan quirinius 65706 20.0 +ethan robinson 65547 20.0 +ethan robinson 65659 20.0 +ethan underhill 65570 20.0 +ethan van buren 65572 20.0 +ethan white 65677 20.0 +ethan white 65788 20.0 +ethan xylophone 65595 20.0 +ethan zipper 65593 20.0 +ethan zipper 65680 20.0 +fred davidson 65552 20.0 +fred davidson 65595 20.0 +fred davidson 65721 20.0 +fred ellison 65548 20.0 +fred ellison 65691 20.0 +fred ellison 65771 20.0 +fred falkner 65637 20.0 +fred falkner 65648 20.0 +fred falkner 65783 20.0 +fred hernandez 65541 20.0 +fred ichabod 65572 20.0 +fred ichabod 65789 20.0 +fred johnson 65758 20.0 +fred king 65694 20.0 +fred king 65745 20.0 +fred laertes 65769 20.0 +fred miller 65536 20.0 +fred nixon 65560 20.0 +fred nixon 65612 20.0 +fred nixon 65703 20.0 +fred nixon 65705 20.0 +fred polk 65603 20.0 +fred polk 65656 20.0 +fred polk 65701 20.0 +fred polk 65706 20.0 +fred quirinius 65697 20.0 +fred quirinius 65782 20.0 +fred robinson 65623 20.0 +fred steinbeck 65544 20.0 +fred steinbeck 65651 20.0 +fred steinbeck 65755 20.0 +fred underhill 65629 20.0 +fred van buren 65537 20.0 +fred van buren 65561 20.0 +fred van buren 65745 20.0 +fred van buren 65789 20.0 +fred white 65589 20.0 +fred young 65594 20.0 +fred young 65773 20.0 +fred zipper 65553 20.0 +gabriella allen 65646 20.0 +gabriella allen 65677 20.0 +gabriella brown 65704 20.0 +gabriella brown 65753 20.0 +gabriella carson 65586 20.0 +gabriella davidson 65565 20.0 +gabriella ellison 65706 20.0 +gabriella ellison 65716 20.0 +gabriella falkner 65623 20.0 +gabriella falkner 65711 20.0 +gabriella falkner 65767 20.0 +gabriella garcia 65571 20.0 +gabriella hernandez 65587 20.0 +gabriella hernandez 65717 20.0 +gabriella ichabod 65559 20.0 +gabriella ichabod 65633 20.0 +gabriella ichabod 65702 20.0 +gabriella ichabod 65712 20.0 +gabriella ichabod 65717 20.0 +gabriella king 65651 20.0 +gabriella king 65657 20.0 +gabriella laertes 65781 20.0 +gabriella miller 65646 20.0 +gabriella ovid 65556 20.0 +gabriella ovid 65583 20.0 +gabriella polk 65701 20.0 +gabriella polk 65790 20.0 +gabriella steinbeck 65582 20.0 +gabriella steinbeck 65653 20.0 +gabriella thompson 65682 20.0 +gabriella thompson 65755 20.0 +gabriella thompson 65766 20.0 +gabriella van buren 65581 20.0 +gabriella van buren 65644 20.0 +gabriella white 65638 20.0 +gabriella young 65699 20.0 +gabriella young 65774 20.0 +gabriella zipper 65540 20.0 +gabriella zipper 65754 20.0 +holly allen 65596 20.0 +holly brown 65599 20.0 +holly brown 65619 20.0 +holly falkner 65720 20.0 +holly hernandez 65602 20.0 +holly hernandez 65686 20.0 +holly hernandez 65750 20.0 +holly hernandez 65788 20.0 +holly ichabod 65711 20.0 +holly ichabod 65749 20.0 +holly ichabod 65752 20.0 +holly johnson 65655 20.0 +holly johnson 65662 20.0 +holly johnson 65755 20.0 +holly king 65549 20.0 +holly king 65648 20.0 +holly laertes 65664 20.0 +holly miller 65653 20.0 +holly nixon 65539 20.0 +holly nixon 65658 20.0 +holly polk 65743 20.0 +holly polk 65751 20.0 +holly robinson 65564 20.0 +holly thompson 65538 20.0 +holly thompson 65578 20.0 +holly thompson 65713 20.0 +holly underhill 65634 20.0 +holly underhill 65654 20.0 +holly underhill 65721 20.0 +holly underhill 65759 20.0 +holly van buren 65727 20.0 +holly white 65536 20.0 +holly white 65602 20.0 +holly xylophone 65544 20.0 +holly young 65606 20.0 +holly young 65765 20.0 +holly zipper 65607 20.0 +holly zipper 65755 20.0 +irene allen 65556 20.0 +irene brown 65633 20.0 +irene brown 65650 20.0 +irene brown 65765 20.0 +irene carson 65590 20.0 +irene ellison 65659 20.0 +irene ellison 65696 20.0 +irene falkner 65620 20.0 +irene falkner 65661 20.0 +irene garcia 65660 20.0 +irene garcia 65711 20.0 +irene garcia 65787 20.0 +irene ichabod 65645 20.0 +irene ichabod 65722 20.0 +irene johnson 65583 20.0 +irene laertes 65664 20.0 +irene laertes 65710 20.0 +irene laertes 65722 20.0 +irene miller 65730 20.0 +irene nixon 65631 20.0 +irene nixon 65643 20.0 +irene nixon 65653 20.0 +irene ovid 65691 20.0 +irene ovid 65734 20.0 +irene ovid 65753 20.0 +irene polk 65551 20.0 +irene polk 65575 20.0 +irene polk 65579 20.0 +irene polk 65595 20.0 +irene polk 65610 20.0 +irene quirinius 65724 20.0 +irene quirinius 65769 20.0 +irene quirinius 65773 20.0 +irene robinson 65554 20.0 +irene steinbeck 65683 20.0 +irene thompson 65688 20.0 +irene underhill 65591 20.0 +irene underhill 65787 20.0 +irene van buren 65579 20.0 +irene van buren 65589 20.0 +irene xylophone 65775 20.0 +jessica brown 65588 20.0 +jessica carson 65553 20.0 +jessica carson 65672 20.0 +jessica carson 65747 20.0 +jessica davidson 65549 20.0 +jessica davidson 65606 20.0 +jessica davidson 65675 20.0 +jessica davidson 65727 20.0 +jessica ellison 65567 20.0 +jessica ellison 65663 20.0 +jessica falkner 65584 20.0 +jessica garcia 65676 20.0 +jessica garcia 65789 20.0 +jessica ichabod 65704 20.0 +jessica johnson 65607 20.0 +jessica johnson 65720 20.0 +jessica miller 65733 20.0 +jessica nixon 65590 20.0 +jessica nixon 65774 20.0 +jessica ovid 65582 20.0 +jessica ovid 65751 20.0 +jessica polk 65637 20.0 +jessica quirinius 65562 20.0 +jessica quirinius 65608 20.0 +jessica quirinius 65712 20.0 +jessica quirinius 65716 20.0 +jessica robinson 65576 20.0 +jessica thompson 65581 20.0 +jessica thompson 65675 20.0 +jessica underhill 65656 20.0 +jessica underhill 65702 20.0 +jessica underhill 65783 20.0 +jessica van buren 65615 20.0 +jessica white 65544 20.0 +jessica white 65570 20.0 +jessica white 65594 20.0 +jessica white 65673 20.0 +jessica white 65779 20.0 +jessica xylophone 65562 20.0 +jessica young 65623 20.0 +jessica young 65711 20.0 +jessica zipper 65600 20.0 +jessica zipper 65657 20.0 +jessica zipper 65778 20.0 +katie allen 65542 20.0 +katie brown 65590 20.0 +katie davidson 65619 20.0 +katie ellison 65675 20.0 +katie ellison 65699 20.0 +katie falkner 65728 20.0 +katie garcia 65625 20.0 +katie garcia 65747 20.0 +katie hernandez 65550 20.0 +katie ichabod 65658 20.0 +katie ichabod 65726 20.0 +katie ichabod 65757 20.0 +katie king 65629 20.0 +katie king 65647 20.0 +katie king 65776 20.0 +katie miller 65541 20.0 +katie miller 65661 20.0 +katie nixon 65669 20.0 +katie ovid 65681 20.0 +katie polk 65746 20.0 +katie polk 65784 20.0 +katie robinson 65697 20.0 +katie van buren 65643 20.0 +katie van buren 65730 20.0 +katie white 65620 20.0 +katie white 65719 20.0 +katie xylophone 65585 20.0 +katie young 65644 20.0 +katie young 65746 20.0 +katie young 65764 20.0 +katie zipper 65568 20.0 +katie zipper 65733 20.0 +luke allen 65547 20.0 +luke allen 65552 20.0 +luke allen 65576 20.0 +luke allen 65681 20.0 +luke allen 65776 20.0 +luke brown 65719 20.0 +luke davidson 65656 20.0 +luke davidson 65791 20.0 +luke ellison 65582 20.0 +luke ellison 65664 20.0 +luke ellison 65779 20.0 +luke falkner 65589 20.0 +luke falkner 65618 20.0 +luke garcia 65687 20.0 +luke garcia 65778 20.0 +luke ichabod 65629 20.0 +luke ichabod 65654 20.0 +luke johnson 65545 20.0 +luke johnson 65716 20.0 +luke johnson 65718 20.0 +luke laertes 65608 20.0 +luke laertes 65657 20.0 +luke laertes 65685 20.0 +luke laertes 65730 20.0 +luke laertes 65756 20.0 +luke miller 65752 20.0 +luke ovid 65569 20.0 +luke ovid 65693 20.0 +luke polk 65645 20.0 +luke polk 65658 20.0 +luke quirinius 65655 20.0 +luke robinson 65634 20.0 +luke robinson 65772 20.0 +luke thompson 65626 20.0 +luke underhill 65553 20.0 +luke underhill 65571 20.0 +luke underhill 65651 20.0 +luke van buren 65678 20.0 +luke white 65693 20.0 +luke xylophone 65597 20.0 +luke zipper 65641 20.0 +mike allen 65706 20.0 +mike brown 65654 20.0 +mike carson 65698 20.0 +mike carson 65700 20.0 +mike carson 65751 20.0 +mike davidson 65658 20.0 +mike davidson 65759 20.0 +mike ellison 65598 20.0 +mike ellison 65606 20.0 +mike ellison 65718 20.0 +mike ellison 65738 20.0 +mike ellison 65760 20.0 +mike falkner 65609 20.0 +mike garcia 65571 20.0 +mike garcia 65600 20.0 +mike garcia 65770 20.0 +mike hernandez 65548 20.0 +mike hernandez 65672 20.0 +mike ichabod 65621 20.0 +mike king 65563 20.0 +mike king 65586 20.0 +mike king 65591 20.0 +mike king 65642 20.0 +mike king 65769 20.0 +mike king 65776 20.0 +mike miller 65549 20.0 +mike nixon 65619 20.0 +mike nixon 65704 20.0 +mike polk 65619 20.0 +mike polk 65658 20.0 +mike polk 65704 20.0 +mike quirinius 65717 20.0 +mike steinbeck 65550 20.0 +mike steinbeck 65564 20.0 +mike steinbeck 65573 20.0 +mike steinbeck 65749 20.0 +mike van buren 65620 20.0 +mike van buren 65770 20.0 +mike white 65648 20.0 +mike white 65685 20.0 +mike white 65769 20.0 +mike white 65778 20.0 +mike young 65545 20.0 +mike young 65581 20.0 +mike young 65736 20.0 +mike zipper 65552 20.0 +mike zipper 65695 20.0 +mike zipper 65779 20.0 +nick allen 65641 20.0 +nick allen 65786 20.0 +nick brown 65724 20.0 +nick davidson 65601 20.0 +nick ellison 65691 20.0 +nick ellison 65745 20.0 +nick falkner 65583 20.0 +nick falkner 65676 20.0 +nick garcia 65695 20.0 +nick garcia 65712 20.0 +nick garcia 65720 20.0 +nick ichabod 65572 20.0 +nick ichabod 65681 20.0 +nick ichabod 65737 20.0 +nick johnson 65585 20.0 +nick johnson 65784 20.0 +nick laertes 65624 20.0 +nick miller 65757 20.0 +nick nixon 65650 20.0 +nick ovid 65719 20.0 +nick polk 65716 20.0 +nick quirinius 65588 20.0 +nick quirinius 65723 20.0 +nick robinson 65547 20.0 +nick robinson 65675 20.0 +nick steinbeck 65689 20.0 +nick thompson 65610 20.0 +nick underhill 65619 20.0 +nick van buren 65603 20.0 +nick xylophone 65644 20.0 +nick young 65654 20.0 +nick young 65660 20.0 +nick zipper 65757 20.0 +nick zipper 65765 20.0 +oscar allen 65644 20.0 +oscar brown 65614 20.0 +oscar carson 65537 20.0 +oscar carson 65548 20.0 +oscar carson 65549 20.0 +oscar carson 65624 20.0 +oscar carson 65697 20.0 +oscar davidson 65556 20.0 +oscar ellison 65630 20.0 +oscar ellison 65630 20.0 +oscar falkner 65692 20.0 +oscar garcia 65751 20.0 +oscar hernandez 65683 20.0 +oscar hernandez 65707 20.0 +oscar ichabod 65536 20.0 +oscar ichabod 65562 20.0 +oscar ichabod 65637 20.0 +oscar ichabod 65763 20.0 +oscar johnson 65645 20.0 +oscar johnson 65778 20.0 +oscar king 65541 20.0 +oscar king 65550 20.0 +oscar king 65787 20.0 +oscar laertes 65625 20.0 +oscar laertes 65690 20.0 +oscar laertes 65756 20.0 +oscar laertes 65790 20.0 +oscar nixon 65596 20.0 +oscar ovid 65536 20.0 +oscar ovid 65615 20.0 +oscar ovid 65659 20.0 +oscar polk 65541 20.0 +oscar polk 65643 20.0 +oscar quirinius 65541 20.0 +oscar quirinius 65560 20.0 +oscar quirinius 65689 20.0 +oscar quirinius 65720 20.0 +oscar robinson 65537 20.0 +oscar robinson 65658 20.0 +oscar robinson 65687 20.0 +oscar robinson 65782 20.0 +oscar steinbeck 65709 20.0 +oscar thompson 65542 20.0 +oscar thompson 65698 20.0 +oscar thompson 65727 20.0 +oscar thompson 65738 20.0 +oscar underhill 65626 20.0 +oscar van buren 65581 20.0 +oscar van buren 65635 20.0 +oscar van buren 65705 20.0 +oscar white 65552 20.0 +oscar white 65564 20.0 +oscar white 65671 20.0 +oscar white 65735 20.0 +oscar xylophone 65773 20.0 +oscar xylophone 65773 20.0 +oscar xylophone 65775 20.0 +oscar zipper 65568 20.0 +oscar zipper 65740 20.0 +oscar zipper 65777 20.0 +priscilla brown 65670 20.0 +priscilla brown 65690 20.0 +priscilla brown 65749 20.0 +priscilla carson 65658 20.0 +priscilla carson 65687 20.0 +priscilla carson 65755 20.0 +priscilla ichabod 65627 20.0 +priscilla ichabod 65759 20.0 +priscilla johnson 65543 20.0 +priscilla johnson 65633 20.0 +priscilla johnson 65668 20.0 +priscilla johnson 65681 20.0 +priscilla johnson 65755 20.0 +priscilla king 65646 20.0 +priscilla nixon 65564 20.0 +priscilla nixon 65600 20.0 +priscilla ovid 65541 20.0 +priscilla ovid 65790 20.0 +priscilla polk 65747 20.0 +priscilla quirinius 65672 20.0 +priscilla thompson 65654 20.0 +priscilla underhill 65715 20.0 +priscilla underhill 65729 20.0 +priscilla van buren 65607 20.0 +priscilla van buren 65685 20.0 +priscilla van buren 65749 20.0 +priscilla white 65652 20.0 +priscilla xylophone 65538 20.0 +priscilla xylophone 65763 20.0 +priscilla xylophone 65774 20.0 +priscilla young 65585 20.0 +priscilla young 65658 20.0 +priscilla zipper 65622 20.0 +priscilla zipper 65726 20.0 +quinn allen 65657 20.0 +quinn allen 65708 20.0 +quinn brown 65691 20.0 +quinn brown 65700 20.0 +quinn brown 65733 20.0 +quinn davidson 65549 20.0 +quinn davidson 65714 20.0 +quinn davidson 65776 20.0 +quinn davidson 65779 20.0 +quinn ellison 65705 20.0 +quinn ellison 65778 20.0 +quinn garcia 65568 20.0 +quinn garcia 65604 20.0 +quinn garcia 65610 20.0 +quinn garcia 65773 20.0 +quinn ichabod 65564 20.0 +quinn king 65558 20.0 +quinn king 65649 20.0 +quinn laertes 65542 20.0 +quinn laertes 65560 20.0 +quinn laertes 65627 20.0 +quinn nixon 65659 20.0 +quinn ovid 65699 20.0 +quinn quirinius 65747 20.0 +quinn robinson 65627 20.0 +quinn steinbeck 65578 20.0 +quinn steinbeck 65763 20.0 +quinn thompson 65643 20.0 +quinn thompson 65774 20.0 +quinn underhill 65549 20.0 +quinn underhill 65694 20.0 +quinn underhill 65767 20.0 +quinn van buren 65725 20.0 +quinn young 65647 20.0 +quinn zipper 65579 20.0 +quinn zipper 65693 20.0 +rachel allen 65661 20.0 +rachel allen 65709 20.0 +rachel brown 65586 20.0 +rachel brown 65587 20.0 +rachel brown 65587 20.0 +rachel brown 65610 20.0 +rachel brown 65693 20.0 +rachel carson 65677 20.0 +rachel carson 65682 20.0 +rachel davidson 65755 20.0 +rachel ellison 65761 20.0 +rachel falkner 65616 20.0 +rachel falkner 65681 20.0 +rachel falkner 65693 20.0 +rachel falkner 65764 20.0 +rachel johnson 65658 20.0 +rachel king 65604 20.0 +rachel king 65643 20.0 +rachel laertes 65562 20.0 +rachel laertes 65624 20.0 +rachel ovid 65721 20.0 +rachel ovid 65736 20.0 +rachel polk 65686 20.0 +rachel quirinius 65787 20.0 +rachel robinson 65544 20.0 +rachel robinson 65717 20.0 +rachel robinson 65724 20.0 +rachel thompson 65648 20.0 +rachel thompson 65662 20.0 +rachel thompson 65733 20.0 +rachel underhill 65667 20.0 +rachel white 65615 20.0 +rachel white 65717 20.0 +rachel young 65727 20.0 +rachel zipper 65757 20.0 +rachel zipper 65785 20.0 +sarah carson 65679 20.0 +sarah carson 65693 20.0 +sarah carson 65694 20.0 +sarah ellison 65611 20.0 +sarah falkner 65606 20.0 +sarah falkner 65680 20.0 +sarah garcia 65563 20.0 +sarah garcia 65638 20.0 +sarah garcia 65661 20.0 +sarah ichabod 65667 20.0 +sarah ichabod 65671 20.0 +sarah johnson 65659 20.0 +sarah johnson 65716 20.0 +sarah johnson 65731 20.0 +sarah johnson 65751 20.0 +sarah king 65650 20.0 +sarah king 65699 20.0 +sarah miller 65557 20.0 +sarah ovid 65550 20.0 +sarah robinson 65677 20.0 +sarah robinson 65763 20.0 +sarah steinbeck 65721 20.0 +sarah white 65622 20.0 +sarah white 65747 20.0 +sarah xylophone 65678 20.0 +sarah young 65595 20.0 +sarah zipper 65550 20.0 +tom brown 65593 20.0 +tom brown 65675 20.0 +tom carson 65539 20.0 +tom carson 65624 20.0 +tom carson 65780 20.0 +tom davidson 65780 20.0 +tom ellison 65578 20.0 +tom ellison 65670 20.0 +tom ellison 65756 20.0 +tom falkner 65574 20.0 +tom falkner 65583 20.0 +tom hernandez 65575 20.0 +tom hernandez 65632 20.0 +tom ichabod 65588 20.0 +tom johnson 65536 20.0 +tom johnson 65789 20.0 +tom king 65576 20.0 +tom laertes 65617 20.0 +tom laertes 65701 20.0 +tom miller 65594 20.0 +tom miller 65603 20.0 +tom miller 65704 20.0 +tom nixon 65672 20.0 +tom ovid 65628 20.0 +tom polk 65652 20.0 +tom polk 65742 20.0 +tom quirinius 65563 20.0 +tom quirinius 65783 20.0 +tom robinson 65626 20.0 +tom robinson 65632 20.0 +tom robinson 65691 20.0 +tom robinson 65758 20.0 +tom steinbeck 65666 20.0 +tom van buren 65621 20.0 +tom van buren 65652 20.0 +tom van buren 65669 20.0 +tom white 65548 20.0 +tom young 65544 20.0 +tom young 65546 20.0 +tom zipper 65789 20.0 +ulysses brown 65735 20.0 +ulysses carson 65602 20.0 +ulysses carson 65643 20.0 +ulysses carson 65703 20.0 +ulysses carson 65716 20.0 +ulysses davidson 65750 20.0 +ulysses ellison 65575 20.0 +ulysses garcia 65666 20.0 +ulysses hernandez 65651 20.0 +ulysses hernandez 65702 20.0 +ulysses hernandez 65786 20.0 +ulysses ichabod 65551 20.0 +ulysses ichabod 65566 20.0 +ulysses johnson 65776 20.0 +ulysses king 65649 20.0 +ulysses laertes 65691 20.0 +ulysses laertes 65711 20.0 +ulysses laertes 65781 20.0 +ulysses miller 65610 20.0 +ulysses miller 65637 20.0 +ulysses nixon 65603 20.0 +ulysses ovid 65656 20.0 +ulysses polk 65563 20.0 +ulysses polk 65580 20.0 +ulysses polk 65612 20.0 +ulysses polk 65777 20.0 +ulysses quirinius 65786 20.0 +ulysses robinson 65744 20.0 +ulysses steinbeck 65611 20.0 +ulysses steinbeck 65680 20.0 +ulysses thompson 65788 20.0 +ulysses underhill 65570 20.0 +ulysses underhill 65616 20.0 +ulysses underhill 65620 20.0 +ulysses underhill 65623 20.0 +ulysses underhill 65641 20.0 +ulysses underhill 65713 20.0 +ulysses underhill 65785 20.0 +ulysses van buren 65684 20.0 +ulysses white 65654 20.0 +ulysses white 65675 20.0 +ulysses xylophone 65623 20.0 +ulysses xylophone 65636 20.0 +ulysses xylophone 65781 20.0 +ulysses young 65675 20.0 +ulysses young 65736 20.0 +ulysses young 65748 20.0 +victor allen 65684 20.0 +victor allen 65707 20.0 +victor brown 65550 20.0 +victor brown 65555 20.0 +victor brown 65622 20.0 +victor brown 65673 20.0 +victor davidson 65579 20.0 +victor davidson 65628 20.0 +victor davidson 65783 20.0 +victor ellison 65641 20.0 +victor ellison 65782 20.0 +victor hernandez 65571 20.0 +victor hernandez 65659 20.0 +victor hernandez 65708 20.0 +victor hernandez 65735 20.0 +victor hernandez 65775 20.0 +victor johnson 65606 20.0 +victor johnson 65607 20.0 +victor johnson 65607 20.0 +victor king 65721 20.0 +victor king 65743 20.0 +victor laertes 65638 20.0 +victor laertes 65644 20.0 +victor miller 65570 20.0 +victor nixon 65709 20.0 +victor nixon 65791 20.0 +victor ovid 65649 20.0 +victor polk 65625 20.0 +victor quirinius 65620 20.0 +victor quirinius 65651 20.0 +victor robinson 65596 20.0 +victor robinson 65673 20.0 +victor steinbeck 65618 20.0 +victor steinbeck 65661 20.0 +victor steinbeck 65686 20.0 +victor thompson 65548 20.0 +victor van buren 65664 20.0 +victor van buren 65774 20.0 +victor white 65548 20.0 +victor white 65601 20.0 +victor xylophone 65549 20.0 +victor xylophone 65618 20.0 +victor xylophone 65644 20.0 +victor xylophone 65677 20.0 +victor xylophone 65755 20.0 +victor young 65628 20.0 +victor zipper 65743 20.0 +wendy allen 65628 20.0 +wendy allen 65711 20.0 +wendy allen 65782 20.0 +wendy brown 65580 20.0 +wendy brown 65657 20.0 +wendy ellison 65545 20.0 +wendy ellison 65603 20.0 +wendy falkner 65595 20.0 +wendy falkner 65604 20.0 +wendy falkner 65635 20.0 +wendy garcia 65659 20.0 +wendy garcia 65746 20.0 +wendy garcia 65747 20.0 +wendy garcia 65777 20.0 +wendy hernandez 65650 20.0 +wendy ichabod 65730 20.0 +wendy king 65586 20.0 +wendy king 65664 20.0 +wendy king 65670 20.0 +wendy laertes 65566 20.0 +wendy laertes 65683 20.0 +wendy laertes 65727 20.0 +wendy miller 65582 20.0 +wendy miller 65626 20.0 +wendy nixon 65611 20.0 +wendy nixon 65746 20.0 +wendy ovid 65589 20.0 +wendy ovid 65643 20.0 +wendy polk 65656 20.0 +wendy polk 65692 20.0 +wendy quirinius 65766 20.0 +wendy quirinius 65767 20.0 +wendy robinson 65622 20.0 +wendy robinson 65715 20.0 +wendy robinson 65774 20.0 +wendy steinbeck 65612 20.0 +wendy thompson 65650 20.0 +wendy thompson 65737 20.0 +wendy underhill 65662 20.0 +wendy underhill 65758 20.0 +wendy underhill 65775 20.0 +wendy van buren 65680 20.0 +wendy van buren 65699 20.0 +wendy white 65705 20.0 +wendy xylophone 65687 20.0 +wendy xylophone 65773 20.0 +wendy young 65674 20.0 +wendy young 65685 20.0 +xavier allen 65611 20.0 +xavier allen 65618 20.0 +xavier allen 65771 20.0 +xavier brown 65600 20.0 +xavier brown 65704 20.0 +xavier brown 65723 20.0 +xavier carson 65731 20.0 +xavier carson 65758 20.0 +xavier davidson 65644 20.0 +xavier davidson 65664 20.0 +xavier davidson 65755 20.0 +xavier ellison 65541 20.0 +xavier ellison 65654 20.0 +xavier garcia 65672 20.0 +xavier hernandez 65541 20.0 +xavier hernandez 65544 20.0 +xavier hernandez 65766 20.0 +xavier ichabod 65597 20.0 +xavier ichabod 65663 20.0 +xavier johnson 65654 20.0 +xavier johnson 65744 20.0 +xavier king 65590 20.0 +xavier king 65601 20.0 +xavier laertes 65743 20.0 +xavier ovid 65788 20.0 +xavier polk 65587 20.0 +xavier polk 65653 20.0 +xavier polk 65675 20.0 +xavier polk 65696 20.0 +xavier quirinius 65599 20.0 +xavier quirinius 65650 20.0 +xavier quirinius 65656 20.0 +xavier quirinius 65737 20.0 +xavier thompson 65608 20.0 +xavier underhill 65710 20.0 +xavier white 65703 20.0 +xavier white 65732 20.0 +xavier xylophone 65572 20.0 +xavier zipper 65561 20.0 +yuri allen 65565 20.0 +yuri allen 65682 20.0 +yuri brown 65538 20.0 +yuri brown 65688 20.0 +yuri carson 65670 20.0 +yuri carson 65769 20.0 +yuri ellison 65570 20.0 +yuri ellison 65581 20.0 +yuri falkner 65658 20.0 +yuri falkner 65681 20.0 +yuri garcia 65639 20.0 +yuri hernandez 65706 20.0 +yuri johnson 65587 20.0 +yuri johnson 65697 20.0 +yuri johnson 65712 20.0 +yuri king 65721 20.0 +yuri laertes 65637 20.0 +yuri laertes 65773 20.0 +yuri nixon 65635 20.0 +yuri nixon 65740 20.0 +yuri polk 65607 20.0 +yuri polk 65713 20.0 +yuri polk 65742 20.0 +yuri quirinius 65544 20.0 +yuri quirinius 65617 20.0 +yuri quirinius 65695 20.0 +yuri steinbeck 65592 20.0 +yuri steinbeck 65679 20.0 +yuri thompson 65676 20.0 +yuri underhill 65718 20.0 +yuri underhill 65750 20.0 +yuri white 65659 20.0 +yuri xylophone 65714 20.0 +zach allen 65667 20.0 +zach brown 65559 20.0 +zach brown 65588 20.0 +zach brown 65691 20.0 +zach brown 65759 20.0 +zach brown 65762 20.0 +zach carson 65572 20.0 +zach ellison 65748 20.0 +zach falkner 65620 20.0 +zach falkner 65627 20.0 +zach garcia 65544 20.0 +zach garcia 65623 20.0 +zach garcia 65629 20.0 +zach garcia 65786 20.0 +zach ichabod 65599 20.0 +zach ichabod 65612 20.0 +zach king 65556 20.0 +zach king 65702 20.0 +zach king 65773 20.0 +zach miller 65584 20.0 +zach miller 65665 20.0 +zach miller 65719 20.0 +zach ovid 65578 20.0 +zach ovid 65703 20.0 +zach ovid 65750 20.0 +zach ovid 65784 20.0 +zach quirinius 65691 20.0 +zach robinson 65599 20.0 +zach steinbeck 65602 20.0 +zach steinbeck 65695 20.0 +zach thompson 65636 20.0 +zach thompson 65696 20.0 +zach underhill 65573 20.0 +zach white 65733 20.0 +zach xylophone 65542 20.0 +zach xylophone 65780 20.0 +zach young 65576 20.0 +zach zipper 65579 20.0 +zach zipper 65649 20.0 +zach zipper 65676 20.0 diff --git a/sql/hive/src/test/resources/hive-contrib-0.13.1.jar b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar new file mode 100644 index 000000000000..ce0740d9245a Binary files /dev/null and b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar differ diff --git a/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar new file mode 100644 index 000000000000..37af9aafad8a Binary files /dev/null and b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar differ diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index 5bc08062d30e..fea3404769d9 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -33,7 +33,7 @@ log4j.appender.FA.layout=org.apache.log4j.PatternLayout log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n # Set the logger level of File Appender to WARN -log4j.appender.FA.Threshold = INFO +log4j.appender.FA.Threshold = DEBUG # Some packages are noisy for no good reason. log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false @@ -48,9 +48,14 @@ log4j.logger.hive.log=OFF log4j.additivity.parquet.hadoop.ParquetRecordReader=false log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF + log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR - diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q index 9e036c1a91d3..e911fbf2d2c5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q @@ -5,6 +5,6 @@ SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY key, value)t ORDER BY ke SELECT key, value FROM src CLUSTER BY (key, value); -SELECT key, value FROM src ORDER BY (key ASC, value ASC); +SELECT key, value FROM src ORDER BY key ASC, value ASC; SELECT key, value FROM src SORT BY (key, value); SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY (key, value))t ORDER BY key, value; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q index b26a2e2799f7..a989800cbf85 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q @@ -1,42 +1,41 @@ +-- SORT_QUERY_RESULTS explain SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; CREATE TABLE union_out (id int); -insert overwrite table union_out +insert overwrite table union_out SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; -select * from union_out cluster by id; +select * from union_out; diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala new file mode 100644 index 000000000000..2590040f2ec1 --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.hive.HiveContext + +/** + * Entry point in test application for SPARK-8489. + * + * This file is not meant to be compiled during tests. It is already included + * in a pre-built "test.jar" located in the same directory as this file. + * This is included here for reference only and should NOT be modified without + * rebuilding the test jar itself. + * + * This is used in org.apache.spark.sql.hive.HiveSparkSubmitSuite. + */ +object Main { + def main(args: Array[String]) { + // scalastyle:off println + println("Running regression test for SPARK-8489.") + val sc = new SparkContext("local", "testing") + val hc = new HiveContext(sc) + // This line should not throw scala.reflect.internal.MissingRequirementError. + // See SPARK-8470 for more detail. + val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) + df.collect() + println("Regression test for SPARK-8489 success!") + // scalastyle:on println + sc.stop() + } +} + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala new file mode 100644 index 000000000000..b1681745c2ef --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Dummy class used in regression test SPARK-8489. */ +case class MyCoolClass(past: String, present: String, future: String) + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar new file mode 100644 index 000000000000..5944aa6076a5 Binary files /dev/null and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index ba391293884b..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util._ - - -/** - * *** DUPLICATED FROM sql/core. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class QueryTest extends PlanTest { - - /** - * Runs the plan and makes sure the answer contains all of the keywords, or the - * none of keywords are listed in the answer - * @param rdd the [[DataFrame]] to be executed - * @param exists true for make sure the keywords are listed in the output, otherwise - * to make sure none of the keyword are not listed in the output - * @param keywords keyword in string array - */ - def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") - } - } - } - - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. - */ - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case o => o - }) - } - if (!isSorted) converted.sortBy(_.toString) else converted - } - val sparkAnswer = try rdd.collect().toSeq catch { - case e: Exception => - fail( - s""" - |Exception thrown while executing query: - |${rdd.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) - } - - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" - |Results do not match for query: - |${rdd.logicalPlan} - |== Analyzed Plan == - |${rdd.queryExecution.analyzed} - |== Physical Plan == - |${rdd.queryExecution.executedPlan} - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) - } - } - - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 081d94b6fc02..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.scalatest.FunSuite - -/** - * *** DUPLICATED FROM sql/catalyst/plans. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class PlanTest extends FunSuite { - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) - val minId = if (list.isEmpty) 0 else list.min - plan transformAllExpressions { - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) - } - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 7c8b5205e239..39d315aaeab5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,28 +17,16 @@ package org.apache.spark.sql.hive +import java.io.File + import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId +import org.apache.spark.util.Utils class CachedTableSuite extends QueryTest { - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. - */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -69,7 +57,7 @@ class CachedTableSuite extends QueryTest { checkAnswer( sql("SELECT * FROM src s"), preCacheResults) - + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } @@ -92,12 +80,12 @@ class CachedTableSuite extends QueryTest { } test("Drop cached table") { - sql("CREATE TABLE test(a INT)") - cacheTable("test") - sql("SELECT * FROM test").collect() - sql("DROP TABLE test") - intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { - sql("SELECT * FROM test").collect() + sql("CREATE TABLE cachedTableTest(a INT)") + cacheTable("cachedTableTest") + sql("SELECT * FROM cachedTableTest").collect() + sql("DROP TABLE cachedTableTest") + intercept[AnalysisException] { + sql("SELECT * FROM cachedTableTest").collect() } } @@ -170,4 +158,49 @@ class CachedTableSuite extends QueryTest { assertCached(table("udfTest")) uncacheTable("udfTest") } + + test("REFRESH TABLE also needs to recache the data (data source tables)") { + val tempPath: File = Utils.createTempDir() + tempPath.delete() + table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) + sql("DROP TABLE IF EXISTS refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Cache the table. + sql("CACHE TABLE refreshTable") + assertCached(table("refreshTable")) + // Append new data. + table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) + // We are still using the old data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Refresh the table. + sql("REFRESH TABLE refreshTable") + // We are using the new data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + + // Drop the table and create it again. + sql("DROP TABLE refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + // Refresh the table. REFRESH TABLE command should not make a uncached + // table cached. + sql("REFRESH TABLE refreshTable") + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + + sql("DROP TABLE refreshTable") + Utils.deleteRecursively(tempPath) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala new file mode 100644 index 000000000000..34b2edb44b03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.net.URL + +import org.apache.spark.SparkFunSuite + +/** + * Verify that some classes load and that others are not found on the classpath. + * + * + * This is used to detect classpath and shading conflict, especially between + * Spark's required Kryo version and that which can be found in some Hive versions. + */ +class ClasspathDependenciesSuite extends SparkFunSuite { + private val classloader = this.getClass.getClassLoader + + private def assertLoads(classname: String): Unit = { + val resourceURL: URL = Option(findResource(classname)).getOrElse { + fail(s"Class $classname not found as ${resourceName(classname)}") + } + + logInfo(s"Class $classname at $resourceURL") + classloader.loadClass(classname) + } + + private def assertLoads(classes: String*): Unit = { + classes.foreach(assertLoads) + } + + private def findResource(classname: String): URL = { + val resource = resourceName(classname) + classloader.getResource(resource) + } + + private def resourceName(classname: String): String = { + classname.replace(".", "/") + ".class" + } + + private def assertClassNotFound(classname: String): Unit = { + Option(findResource(classname)).foreach { resourceURL => + fail(s"Class $classname found at $resourceURL") + } + + intercept[ClassNotFoundException] { + classloader.loadClass(classname) + } + } + + private def assertClassNotFound(classes: String*): Unit = { + classes.foreach(assertClassNotFound) + } + + private val KRYO = "com.esotericsoftware.kryo.Kryo" + + private val SPARK_HIVE = "org.apache.hive." + private val SPARK_SHADED = "org.spark-project.hive.shaded." + + test("shaded Protobuf") { + assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + } + + test("hive-common") { + assertLoads("org.apache.hadoop.hive.conf.HiveConf") + } + + test("hive-exec") { + assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") + } + + private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" + + test("unshaded kryo") { + assertLoads(KRYO, STD_INSTANTIATOR) + } + + test("Forbidden Dependencies") { + assertClassNotFound( + SPARK_HIVE + KRYO, + SPARK_SHADED + KRYO, + "org.apache.hive." + KRYO, + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR + ) + } + + test("parquet-hadoop-bundle") { + assertLoads( + "parquet.hadoop.ParquetOutputFormat", + "parquet.hadoop.ParquetInputFormat" + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala new file mode 100644 index 000000000000..30f5313d2b81 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.Try + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.{AnalysisException, QueryTest} + + +class ErrorPositionSuite extends QueryTest with BeforeAndAfter { + + before { + Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") + } + + positionTest("ambiguous attribute reference 1", + "SELECT a from dupAttributes", "a") + + positionTest("ambiguous attribute reference 2", + "SELECT a, b from dupAttributes", "a") + + positionTest("ambiguous attribute reference 3", + "SELECT b, a from dupAttributes", "a") + + positionTest("unresolved attribute 1", + "SELECT x FROM src", "x") + + positionTest("unresolved attribute 2", + "SELECT x FROM src", "x") + + positionTest("unresolved attribute 3", + "SELECT key, x FROM src", "x") + + positionTest("unresolved attribute 4", + """SELECT key, + |x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 5", + """SELECT key, + | x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 6", + """SELECT key, + | + | 1 + x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 7", + """SELECT key, + | + | 1 + x + 1 FROM src + """.stripMargin, "x") + + positionTest("multi-char unresolved attribute", + """SELECT key, + | + | 1 + abcd + 1 FROM src + """.stripMargin, "abcd") + + positionTest("unresolved attribute group by", + """SELECT key FROM src GROUP BY + |x + """.stripMargin, "x") + + positionTest("unresolved attribute order by", + """SELECT key FROM src ORDER BY + |x + """.stripMargin, "x") + + positionTest("unresolved attribute where", + """SELECT key FROM src + |WHERE x = true + """.stripMargin, "x") + + positionTest("unresolved attribute backticks", + "SELECT `x` FROM src", "`x`") + + positionTest("parse error", + "SELECT WHERE", "WHERE") + + positionTest("bad relation", + "SELECT * FROM badTable", "badTable") + + ignore("other expressions") { + positionTest("bad addition", + "SELECT 1 + array(1)", "1 + array") + } + + /** + * Creates a test that checks to see if the error thrown when analyzing a given query includes + * the location of the given token in the query string. + * + * @param name the name of the test + * @param query the query to analyze + * @param token a unique token in the string that should be indicated by the exception + */ + def positionTest(name: String, query: String, token: String): Unit = { + def parseTree = + Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + + test(name) { + val error = intercept[AnalysisException] { + quietly(sql(query)) + } + + assert(!error.getMessage.contains("Seq(")) + assert(!error.getMessage.contains("List(")) + + val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect { + case (l, i) if l.contains(token) => (l, i + 1) + }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query")) + val actualLine = error.line.getOrElse { + fail( + s"line not returned for error '${error.getMessage}' on token $token\n$parseTree" + ) + } + assert(actualLine === expectedLineNum, "wrong line") + + val expectedStart = line.indexOf(token) + val actualStart = error.startPosition.getOrElse { + fail( + s"start not returned for error on token $token\n" + + HiveQl.dumpTree(HiveQl.getAst(query)) + ) + } + assert(expectedStart === actualStart, + s"""Incorrect start position. + |== QUERY == + |$query + | + |== AST == + |$parseTree + | + |Actual: $actualStart, Expected: $expectedStart + |$line + |${" " * actualStart}^ + |0123456789 123456789 1234567890 + | 2 3 + """.stripMargin) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala new file mode 100644 index 000000000000..fb10f8583da9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.scalatest.BeforeAndAfterAll + +// TODO ideally we should put the test suite into the package `sql`, as +// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't +// support the `cube` or `rollup` yet. +class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { + private var testData: DataFrame = _ + + override def beforeAll() { + testData = Seq((1, 2), (2, 4)).toDF("a", "b") + TestHive.registerDataFrameAsTable(testData, "mytable") + } + + override def afterAll(): Unit = { + TestHive.dropTempTable("mytable") + } + + test("rollup") { + checkAnswer( + testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() + ) + + checkAnswer( + testData.rollup("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() + ) + } + + test("cube") { + checkAnswer( + testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() + ) + + checkAnswer( + testData.cube("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with cube").collect() + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala new file mode 100644 index 000000000000..52e782768cb7 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.hive.test.TestHive.implicits._ + + +class HiveDataFrameJoinSuite extends QueryTest { + + // We should move this into SQL package if we make case sensitivity configurable in SQL. + test("join - self join auto resolve ambiguity with case insensitivity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("Key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("Key")), + Row(2, "2", 2, "2") :: Nil) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala new file mode 100644 index 000000000000..c177cbdd991c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +class HiveDataFrameWindowSuite extends QueryTest { + + test("reuse window partitionBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.orderBy("value").partitionBy("key") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lead(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lead with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), + sql( + """SELECT + | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile(2).over(Window.partitionBy("value").orderBy("key")), + rowNumber().over(Window.partitionBy("value").orderBy("key")), + denseRank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cumeDist().over(Window.partitionBy("value").orderBy("key")), + percentRank().over(Window.partitionBy("value").orderBy("key"))), + sql( + s"""SELECT + |key, + |max(key) over (partition by value order by key), + |min(key) over (partition by value order by key), + |avg(key) over (partition by value order by key), + |count(key) over (partition by value order by key), + |sum(key) over (partition by value order by key), + |ntile(2) over (partition by value order by key), + |row_number() over (partition by value order by key), + |dense_rank() over (partition by value order by key), + |rank() over (partition by value order by key), + |cume_dist() over (partition by value order by key), + |percent_rank() over (partition by value order by key) + |FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows between") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens with unbounded") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) + .equalTo("2") + .as("last_v"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) + .as("avg_key1"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) + .as("avg_key2"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) + .as("avg_key3") + ), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) + | FROM window_table""".stripMargin).collect()) + } + + test("reverse sliding range frame") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window. + partitionBy($"category"). + orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse unbounded range frame") { + val df = Seq(1, 2, 4, 3, 2, 1). + map(Tuple1.apply). + toDF("value") + val window = Window.orderBy($"value".desc) + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Long.MinValue, 1)), + sum($"value").over(window.rangeBetween(1, Long.MaxValue))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: + Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2d3ff680125a..81a70b8d4226 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.util -import java.sql.Date import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.ql.udf.UDAFPercentile @@ -27,12 +26,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable -import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions.{Literal, Row} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row -class HiveInspectorSuite extends FunSuite with HiveInspectors { +class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { val udaf = new UDAFPercentile.PercentileLongEvaluator() udaf.init() @@ -46,8 +47,12 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { classOf[UDAFPercentile.State], ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] - val a = unwrap(state, soi).asInstanceOf[Row] - val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] + val a = unwrap(state, soi).asInstanceOf[InternalRow] + + val dt = new StructType() + .add("counts", MapType(LongType, LongType)) + .add("percentiles", ArrayType(DoubleType)) + val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") val sfPercentiles = soi.getStructFieldRef("percentiles") @@ -76,13 +81,13 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(0.asInstanceOf[Float]) :: Literal(0.asInstanceOf[Double]) :: Literal("0") :: - Literal(new Date(2014, 9, 23)) :: + Literal(java.sql.Date.valueOf("2014-09-23")) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: - Literal(Array[Byte](1,2,3)) :: - Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal(Row(1,2.0d,3.0f), + Literal(Array[Byte](1, 2, 3)) :: + Literal.create(Seq[Int](1, 2, 3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1 -> 2, 2 -> 1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1, 2.0d, 3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -112,26 +117,25 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { - dt1.zip(dt2).map { - case (dd1, dd2) => - assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info + dt1.zip(dt2).foreach { case (dd1, dd2) => + assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info } } def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { - row1.zip(row2).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2).foreach { case (r1, r2) => + checkValue(r1, r2) } } - def checkValues(row1: Seq[Any], row2: Row): Unit = { - row1.zip(row2.toSeq).map { - case (r1, r2) => checkValue(r1, r2) + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => + checkValue(r1, r2) } } @@ -142,8 +146,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { assert(r1.compare(r2) === 0) case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => - r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } - case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0) + r1.zip(r2).foreach { case (b1, b2) => assert(b1 === b2) } + // We don't support equality & ordering for map type, so skip it. + case (r1: MapData, r2: MapData) => case (r1, r2) => assert(r1 === r2) } } @@ -159,43 +164,45 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val writableOIs = dataTypes.map(toWritableInspector) val nullRow = data.map(d => null) - checkValues(nullRow, nullRow.zip(writableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) // struct couldn't be constant, sweep it out val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType]) + val constantTypes = constantExprs.map(_.dataType) val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType))) + val constantNullWritableOIs = + constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) - checkValues(constantData, constantData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) } test("wrap / unwrap primitive writable object inspector") { val writableOIs = dataTypes.map(toWritableInspector) - checkValues(row, row.zip(writableOIs).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(writableOIs).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } test("wrap / unwrap primitive java object inspector") { val ois = dataTypes.map(toInspector) - checkValues(row, row.zip(ois).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(ois).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } @@ -203,28 +210,37 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val dt = StructType(dataTypes.zipWithIndex.map { case (t, idx) => StructField(s"c_$idx", t) }) - - checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + val inspector = toInspector(dt) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) - val d = row(0) :: row(0) :: Nil - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + val d = new GenericArrayData(Array(row(0), row(0))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) - val d = Map(row(0) -> row(1)) - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + val d = ArrayBasedMapData(Array(row(0)), Array(row(1))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index aad48ada5264..574624d501f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,23 +17,146 @@ package org.apache.spark.sql.hive -import org.scalatest.FunSuite +import java.io.File -import org.apache.spark.sql.test.ExamplePointUDT -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.sources.DataSourceTest +import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.{Row, SaveMode, SQLContext} +import org.apache.spark.{Logging, SparkFunSuite} -class HiveMetastoreCatalogSuite extends FunSuite { - test("struct field should accept underscore in sub-column name") { - val metastr = "struct" +class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { - val datatype = HiveMetastoreTypes.toDataType(metastr) - assert(datatype.isInstanceOf[StructType]) + test("struct field should accept underscore in sub-column name") { + val hiveTypeStr = "struct" + val dateType = HiveMetastoreTypes.toDataType(hiveTypeStr) + assert(dateType.isInstanceOf[StructType]) } test("udt to metastore type conversion") { val udt = new ExamplePointUDT - assert(HiveMetastoreTypes.toMetastoreType(udt) === - HiveMetastoreTypes.toMetastoreType(udt.sqlType)) + assertResult(HiveMetastoreTypes.toMetastoreType(udt.sqlType)) { + HiveMetastoreTypes.toMetastoreType(udt) + } + } + + test("duplicated metastore relations") { + val df = sql("SELECT * FROM src") + logInfo(df.queryExecution.toString) + df.as('a).join(df.as('b), $"a.key" === $"b.key") + } +} + +class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { + override def _sqlContext: SQLContext = TestHive + import testImplicits._ + + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) + + Seq( + "parquet" -> ( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ), + + "orc" -> ( + "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcSerde" + ) + ).foreach { case (provider, (inputFormat, outputFormat, serde)) => + test(s"Persist non-partitioned $provider relation into metastore as managed table") { + withTable("t") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(!hiveTable.isPartitioned) + assert(hiveTable.tableType === ManagedTable) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + } + } + + test(s"Persist non-partitioned $provider relation into metastore as external table") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalFile + + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .option("path", path.toString) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + } + } + } + + test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalPath + + sql( + s"""CREATE TABLE t USING $provider + |OPTIONS (path '$path') + |AS SELECT 1 AS d1, "val_1" AS d2 + """.stripMargin) + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.isPartitioned === false) + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.partitionColumns.length === 0) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), Row(1, "val_1")) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + } + } + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala similarity index 80% rename from sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 581f66639949..fe0db5228de1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.{QueryTest, Row, SQLContext} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ + private val ctx = TestHive + override def _sqlContext: SQLContext = ctx test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -53,8 +52,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -63,12 +62,11 @@ class HiveParquetSuite extends QueryTest with ParquetTest { } } - test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala new file mode 100644 index 000000000000..79cf40aba4bf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} +import org.scalatest.BeforeAndAfterAll + + +class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { + override def beforeAll() { + if (SessionState.get() == null) { + SessionState.start(new HiveConf()) + } + } + + private def extractTableDesc(sql: String): (HiveTable, Boolean) = { + HiveQl.createPlan(sql).collect { + case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) + }.head + } + + test("Test CTAS #1") { + val s1 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |(viewTime INT, + |userid BIGINT, + |page_url STRING, + |referrer_url STRING, + |ip STRING COMMENT 'IP Address of the User', + |country STRING COMMENT 'country of origination') + |COMMENT 'This is the staging page view table' + |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s1) + assert(exists == true) + assert(desc.specifiedDatabase == Some("mydb")) + assert(desc.name == "page_view") + assert(desc.tableType == ExternalTable) + assert(desc.location == Some("/user/external/page_view")) + assert(desc.schema == + HiveColumn("viewtime", "int", null) :: + HiveColumn("userid", "bigint", null) :: + HiveColumn("page_url", "string", null) :: + HiveColumn("referrer_url", "string", null) :: + HiveColumn("ip", "string", "IP Address of the User") :: + HiveColumn("country", "string", "country of origination") :: Nil) + // TODO will be SQLText + assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.partitionColumns == + HiveColumn("dt", "string", "date type") :: + HiveColumn("hour", "string", "hour of the day") :: Nil) + assert(desc.serdeProperties == + Map((serdeConstants.SERIALIZATION_FORMAT, "\054"), (serdeConstants.FIELD_DELIM, "\054"))) + assert(desc.inputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + } + + test("Test CTAS #2") { + val s2 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |(viewTime INT, + |userid BIGINT, + |page_url STRING, + |referrer_url STRING, + |ip STRING COMMENT 'IP Address of the User', + |country STRING COMMENT 'country of origination') + |COMMENT 'This is the staging page view table' + |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s2) + assert(exists == true) + assert(desc.specifiedDatabase == Some("mydb")) + assert(desc.name == "page_view") + assert(desc.tableType == ExternalTable) + assert(desc.location == Some("/user/external/page_view")) + assert(desc.schema == + HiveColumn("viewtime", "int", null) :: + HiveColumn("userid", "bigint", null) :: + HiveColumn("page_url", "string", null) :: + HiveColumn("referrer_url", "string", null) :: + HiveColumn("ip", "string", "IP Address of the User") :: + HiveColumn("country", "string", "country of origination") :: Nil) + // TODO will be SQLText + assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.partitionColumns == + HiveColumn("dt", "string", "date type") :: + HiveColumn("hour", "string", "hour of the day") :: Nil) + assert(desc.serdeProperties == Map()) + assert(desc.inputFormat == Option("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.outputFormat == Option("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.serde == Option("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + } + + test("Test CTAS #3") { + val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.specifiedDatabase == None) + assert(desc.name == "page_view") + assert(desc.tableType == ManagedTable) + assert(desc.location == None) + assert(desc.schema == Seq.empty[HiveColumn]) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.serdeProperties == Map()) + assert(desc.inputFormat == Option("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + assert(desc.serde.isEmpty) + assert(desc.properties == Map()) + } + + test("Test CTAS #4") { + val s4 = + """CREATE TABLE page_view + |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin + intercept[AnalysisException] { + extractTableDesc(s4) + } + } + + test("Test CTAS #5") { + val s5 = """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin + val (desc, exists) = extractTableDesc(s5) + assert(exists == false) + assert(desc.specifiedDatabase == None) + assert(desc.name == "ctas2") + assert(desc.tableType == ManagedTable) + assert(desc.location == None) + assert(desc.schema == Seq.empty[HiveColumn]) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.inputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) + } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + HiveQl.parseSql(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala new file mode 100644 index 000000000000..dc2d85f48624 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File +import java.sql.Timestamp +import java.util.Date + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.util.{ResetSystemProperties, Utils} + +/** + * This suite tests spark-submit with applications using HiveContext. + */ +class HiveSparkSubmitSuite + extends SparkFunSuite + with Matchers + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. + with ResetSystemProperties + with Timeouts { + + // TODO: rewrite these or mark them as slow tests to be run sparingly + + def beforeAll() { + System.setProperty("spark.testing", "true") + } + + test("SPARK-8368: includes jars passed in through --jars") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") + val args = Seq( + "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSubmitClassLoaderTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("SPARK-8020: set sql conf in spark conf") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + unusedJar.toString) + runSparkSubmit(args) + } + + test("SPARK-8489: MissingRequirementError during reflection") { + // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates + // a HiveContext and uses it to create a data frame from an RDD using reflection. + // Before the fix in SPARK-8470, this results in a MissingRequirementError because + // the HiveContext code mistakenly overrides the class loader that contains user classes. + // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. + val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--class", "Main", + testJar) + runSparkSubmit(args) + } + + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + // This is copied from org.apache.spark.deploy.SparkSubmitSuite + private def runSparkSubmit(args: Seq[String]): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val history = ArrayBuffer.empty[String] + val commands = Seq("./bin/spark-submit") ++ args + val commandLine = commands.mkString("'", "' '", "'") + + val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) + val env = builder.environment() + env.put("SPARK_TESTING", "1") + env.put("SPARK_HOME", sparkHome) + + def captureOutput(source: String)(line: String): Unit = { + // This test suite has some weird behaviors when executed on Jenkins: + // + // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a + // timestamp to provide more diagnosis information. + // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print + // them out for debugging purposes. + val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" + // scalastyle:off println + println(logLine) + // scalastyle:on println + history += logLine + } + + val process = builder.start() + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() + + try { + val exitCode = failAfter(180.seconds) { process.waitFor() } + if (exitCode != 0) { + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail { + s"""spark-submit returned with exit code $exitCode. + |Command line: $commandLine + | + |$historyLog + """.stripMargin + } + } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." + + s"\n$historyLog", to) + case t: Throwable => throw t + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + +// This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. +// We test if we can load user jars in both driver and executors when HiveContext is used. +object SparkSubmitClassLoaderTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + logInfo("Testing load classes at the driver side.") + // First, we load classes at driver side. + try { + Utils.classForName(args(0)) + Utils.classForName(args(1)) + } catch { + case t: Throwable => + throw new Exception("Could not load user class from jar:\n", t) + } + // Second, we load classes at the executor side. + logInfo("Testing load classes at the executor side.") + val result = df.mapPartitions { x => + var exception: String = null + try { + Utils.classForName(args(0)) + Utils.classForName(args(1)) + } catch { + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") + } + Option(exception).toSeq.iterator + }.collect() + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) + } + + // Load a Hive UDF from the jar. + logInfo("Registering temporary Hive UDF provided in a jar.") + hiveContext.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Load a Hive SerDe from the jar. + logInfo("Creating a Hive table with a SerDe provided in a jar.") + hiveContext.sql( + """ + |CREATE TABLE t1(key int, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + // Actually use the loaded UDF and SerDe. + logInfo("Writing data into the table.") + hiveContext.sql( + "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = hiveContext.table("t1").orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"table t1 should have 10 rows instead of $count rows") + } + logInfo("Test finishes.") + sc.stop() + } +} + +// This object is used for testing SPARK-8020: https://issues.apache.org/jira/browse/SPARK-8020. +// We test if we can correctly set spark sql configurations when HiveContext is used. +object SparkSQLConfTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + // We override the SparkConf to add spark.sql.hive.metastore.version and + // spark.sql.hive.metastore.jars to the beginning of the conf entry array. + // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but + // before spark.sql.hive.metastore.jars get set, we will see the following exception: + // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only + // be used when hive execution version == hive metastore version. + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars + // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. + val conf = new SparkConf() { + override def getAll: Array[(String, String)] = { + def isMetastoreSetting(conf: String): Boolean = { + conf == "spark.sql.hive.metastore.version" || conf == "spark.sql.hive.metastore.jars" + } + // If there is any metastore settings, remove them. + val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1)) + + // Always add these two metastore settings at the beginning. + ("spark.sql.hive.metastore.version" -> "0.12") +: + ("spark.sql.hive.metastore.jars" -> "maven") +: + filteredSettings + } + + // For this simple test, we do not really clone this object. + override def clone: SparkConf = this + } + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Run a simple command to make sure all lazy vals in hiveContext get instantiated. + hiveContext.tables().collect() + sc.stop() + } +} + +object SPARK_9757 extends QueryTest with Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven")) + + val hiveContext = new TestHiveContext(sparkContext) + import hiveContext.implicits._ + + import org.apache.spark.sql.functions._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 4dd96bd5a1b7..d33e81227db8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,70 +19,86 @@ package org.apache.spark.sql.hive import java.io.File -import com.google.common.io.Files +import org.apache.hadoop.hive.conf.HiveConf +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) -class InsertIntoHiveTableSuite extends QueryTest { +case class ThreeCloumntable(key: Int, value: String, key1: String) + +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val testData = TestHive.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))) - testData.registerTempTable("testData") + (1 to 100).map(i => TestData(i, i.toString))).toDF() + + before { + // Since every we are doing tests for DDL statements, + // it is better to reset before every test. + TestHive.reset() + // Register the testData, which will be used in every test. + testData.registerTempTable("testData") + } test("insertInto() HiveTable") { - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) // Add more data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq + testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq ) // Now overwrite. - testData.insertInto("createAndInsertTest", overwrite = true) + testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) } test("Double create fails when allowExisting = false") { - createTable[TestData]("doubleCreateAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[org.apache.hadoop.hive.ql.metadata.HiveException] { - createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false) - } + val message = intercept[QueryExecutionException] { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + }.getMessage } test("Double create does not fail when allowExisting = true") { - createTable[TestData]("createAndInsertTest") - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -96,16 +112,43 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("SPARK-4203:random partition directory order") { - createTable[TestData]("tmp_table") - val tmpDir = Files.createTempDir() - sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + sql("CREATE TABLE tmp_table (key int, value string)") + val tmpDir = Utils.createTempDir() + val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + + sql( + s""" + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='1') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='2') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='3') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='4') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() - val folders = dir.filter(_.isDirectory).toList + val folders = dir.filter { e => e.isDirectory && !e.getName().startsWith(stagingDir) }.toList if (folders.isEmpty) { List(acc.reverse) } else { @@ -118,7 +161,7 @@ class InsertIntoHiveTableSuite extends QueryTest { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } @@ -127,7 +170,7 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -144,7 +187,7 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -161,7 +204,7 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = applySchema(rowRDD, schema) + val df = TestHive.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -172,4 +215,51 @@ class InsertIntoHiveTableSuite extends QueryTest { sql("DROP TABLE hiveTableWithStructValue") } + + test("SPARK-5498:partition schema does not match table schema") { + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val testDatawithNull = TestHive.sparkContext.parallelize( + (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() + + val tmpDir = Utils.createTempDir() + sql( + s""" + |CREATE TABLE table_with_partition(key int,value string) + |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (ds='1') SELECT key,value FROM testData + """.stripMargin) + + // test schema the same between partition and table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect().toSeq + ) + + // test difference type of field + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect().toSeq + ) + + // add column to table + sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), + testDatawithNull.collect().toSeq + ) + + // change column name to table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), + testData.collect().toSeq + ) + + sql("DROP TABLE table_with_partition") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala new file mode 100644 index 000000000000..d3388a9429e4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -0,0 +1,75 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.Row + +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + + override def beforeAll(): Unit = { + // The catalog in HiveContext is a case insensitive one. + catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") + sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") + sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") + } + + override def afterAll(): Unit = { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + } + + test("get all tables of current database") { + Seq(tables(), sql("SHOW TABLes")).foreach { + case allTables => + // We are using default DB. + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + checkAnswer( + allTables.filter("tableName = 'hivelisttablessuitetable'"), + Row("hivelisttablessuitetable", false)) + assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0) + } + } + + test("getting all tables with a database name") { + Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach { + case allTables => + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hiveindblisttablessuitetable'"), + Row("hiveindblisttablessuitetable", false)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 85795acb658e..20a50586d520 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -17,423 +17,832 @@ package org.apache.spark.sql.hive -import java.io.File +import java.io.{IOException, File} -import org.scalatest.BeforeAndAfterEach +import scala.collection.mutable.ArrayBuffer -import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.util +import org.apache.spark.Logging import org.apache.spark.sql._ -import org.apache.spark.util.Utils -import org.apache.spark.sql.types._ - -/* Implicits */ +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { - override def afterEach(): Unit = { - reset() - if (ctasPath.exists()) Utils.deleteRecursively(ctasPath) - } - - val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll + with Logging { + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext - test ("persistent JSON table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + var jsonFilePath: String = _ - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + override def beforeAll(): Unit = { + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - test ("persistent JSON table with a user specified schema") { - sql( - s""" - |CREATE TABLE jsonTable ( - |a string, - |b String, - |`c_!@(3)` int, - |`` Struct<`d!`:array, `=`:array>>) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - jsonFile(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } - test ("persistent JSON table with a user specified schema with a subset of fields") { - // This works because JSON objects are self-describing and JSONRelation can get needed - // field values based on field names. - sql( - s""" - |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val innerStruct = StructType( - StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))) :: Nil) - val expectedSchema = StructType( - StructField("", innerStruct, true) :: - StructField("b", StringType, true) :: Nil) - - assert(expectedSchema === table("jsonTable").schema) - - jsonFile(filePath).registerTempTable("expectedJsonTable") + test("persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + } + } + } - checkAnswer( - sql("SELECT b, ``.`=` FROM jsonTable"), - sql("SELECT b, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema with a subset of fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. + sql( + s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val innerStruct = StructType(Seq( + StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))))) + + val expectedSchema = StructType(Seq( + StructField("", innerStruct, true), + StructField("b", StringType, true))) + + assert(expectedSchema === table("jsonTable").schema) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable")) + } + } } test("resolve shortened provider names") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } test("drop table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) - - sql("DROP TABLE jsonTable") - - intercept[Exception] { - sql("SELECT * FROM jsonTable").collect() + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath)) + + sql("DROP TABLE jsonTable") + + intercept[Exception] { + sql("SELECT * FROM jsonTable").collect() + } + + assert( + new File(jsonFilePath).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") } - - assert( - (new File(filePath)).exists(), - "The table with specified path is considered as an external table, " + - "its data should not deleted after DROP TABLE.") } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json") - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) + withTempPath { tempDir => + withTable("jsonTable") { + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a1", "b1", "c1") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + // Schema is cached so the new column does not show. The updated values in existing columns + // will show. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1")) + + sql("REFRESH TABLE jsonTable") + + // Check that the refresh worked + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1", "c1")) + } + } + } - FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + test("drop, change, recreate") { + withTempPath { tempDir => + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql("DROP TABLE jsonTable") + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + // New table should reflect new schema. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b", "c")) + } + } + } - // Schema is cached so the new column does not show. The updated values in existing columns - // will show. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1")) + test("invalidate cache and reload") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable (`c_!@(3)` int) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - refreshTable("jsonTable") + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") - // Check that the refresh worked - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1", "c1")) - FileUtils.deleteDirectory(tempDir) - } + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json") - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) + // Discard the cached relation. + invalidateTable("jsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + invalidateTable("jsonTable") + val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - sql("DROP TABLE jsonTable") + assert(expectedSchema === table("jsonTable").schema) + } + } + } - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) + test("CTAS") { + withTempPath { tempPath => + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } + } - // New table should reflect new schema. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b", "c")) - FileUtils.deleteDirectory(tempDir) + test("CTAS with IF NOT EXISTS") { + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + + assert( + message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s"""CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } - test("invalidate cache and reload") { - sql( - s""" - |CREATE TABLE jsonTable (`c_!@(3)` int) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + test("CTAS a managed table") { + withTable("jsonTable", "ctasJsonTable", "loadedTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |AS + |SELECT * FROM jsonTable + """.stripMargin) - jsonFile(filePath).registerTempTable("expectedJsonTable") + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + sql( + s"""CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$expectedPath' + |) + """.stripMargin) - // Discard the cached relation. - invalidateTable("jsonTable") + assert(table("ctasJsonTable").schema === table("loadedTable").schema) - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable")) - invalidateTable("jsonTable") - val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + } + } - assert(expectedSchema === table("jsonTable").schema) + test("SPARK-5286 Fail to drop an invalid table when using the data source API") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path 'it is not a path at all!' + |) + """.stripMargin) + + sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) + } } - test("CTAS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${ctasPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("jsonTable").schema) + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { + withTable("savedJsonTable") { + // Save the df as a managed table (by not specifying the path). + (1 to 10) + .map(i => i -> s"str$i") + .toDF("a", "b") + .write + .format("json") + .saveAsTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) + + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) + + invalidateTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) + + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) + } + } - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + test("save table") { + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("savedJsonTable") { + val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + // Save the df as a managed table (by not specifying the path). + df.write.saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // We can overwrite it. + df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") + // TODO in ResolvedDataSource, will convert the schema into nullable = true + // hence the df.schema is not exactly the same as table("savedJsonTable").schema + // assert(df.schema === table("savedJsonTable").schema) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[IOException] { + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + } + + // Create an external table by specifying the path. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { + df.write + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + } + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer(read.json(tempPath.toString), df) + } + } } - test("CTAS with IF NOT EXISTS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${ctasPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - // Create the table again should trigger a AlreadyExistsException. - val message = intercept[RuntimeException] { - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${ctasPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - }.getMessage - assert(message.contains("Table ctasJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // The following statement should be fine if it has IF NOT EXISTS. - // It tries to create a table ctasJsonTable with a new schema. - // The actual table's schema and data should not be changed. - sql( - s""" - |CREATE TABLE IF NOT EXISTS ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${ctasPath}' - |) AS - |SELECT a FROM jsonTable - """.stripMargin) - - // Discard the cached relation. - invalidateTable("ctasJsonTable") - - // Schema should not be changed. - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - // Table data should not be changed. - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + test("create external table") { + withTempPath { tempPath => + withTable("savedJsonTable", "createdJsonTable") { + val df = read.json(sparkContext.parallelize((1 to 10).map { i => + s"""{ "a": $i, "b": "str$i" }""" + })) + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { + df.write + .format("json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + } + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer(sql("SELECT * FROM createdJsonTable"), df) + + assert( + intercept[AnalysisException] { + createExternalTable("createdJsonTable", jsonFilePath.toString) + }.getMessage.contains("Table createdJsonTable already exists."), + "We should complain that createdJsonTable already exists") + } + + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") + checkAnswer(read.json(tempPath.toString), df) + + // Try to specify the schema. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable")) + + sql("DROP TABLE createdJsonTable") + + assert( + intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage.contains("key not found: path"), + "We should complain that path is not specified.") + } + } + } } - test("CTAS a managed table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - new Path("/Users/yhuai/Desktop/whatever") - - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") - val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) - - // It is a managed table when we do not specify the location. - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - - sql( - s""" - |CREATE TABLE loadedTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${expectedPath}' - |) - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("loadedTable").schema) + test("scan a parquet table created through a CTAS statement") { + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "true") { + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") + } + } + } + } + } - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM loadedTable").collect() - ) + test("Pre insert nullability check (ArrayType)") { + withTable("arrayInParquet") { + { + val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("arrayInParquet") + } + + { + val df = (Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("arrayInParquet") + } + + (Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + + (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") + + refreshTable("arrayInParquet") + + checkAnswer( + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) + } + } - sql("DROP TABLE ctasJsonTable") - assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + test("Pre insert nullability check (MapType)") { + withTable("mapInParquet") { + { + val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("mapInParquet") + } + + { + val df = (Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("mapInParquet") + } + + (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + + (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") + + refreshTable("mapInParquet") + + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + } } - test("SPARK-5286 Fail to drop an invalid table when using the data source API") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path 'it is not a path at all!' - |) - """.stripMargin) - - sql("DROP TABLE jsonTable").collect().foreach(println) + test("SPARK-6024 wide schema support") { + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { + withTable("wide_schema") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) + + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + } } - test("save and load table") { - val originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + withTable(tableName) { + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + val hiveTable = HiveTable( + specifiedDatabase = Some("default"), + name = tableName, + schema = Seq.empty, + partitionColumns = Seq.empty, + properties = Map( + "spark.sql.sources.provider" -> "json", + "spark.sql.sources.schema" -> schema.json, + "EXTERNAL" -> "FALSE"), + tableType = ManagedTable, + serdeProperties = Map( + "path" -> catalog.hiveDefaultTableFilePath(tableName))) + + catalog.client.createTable(hiveTable) + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + } + } - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) + test("Saving partition columns information") { + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") + val tableName = s"partitionInfo_${System.currentTimeMillis()}" + + withTable(tableName) { + df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val actualPartitionColumns = + StructType( + metastoreTable.partitionColumns.map(c => + StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + // Make sure partition columns are correctly stored in metastore. + assert( + expectedPartitionColumns.sameType(actualPartitionColumns), + s"Partitions columns stored in metastore $actualPartitionColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } + } - df.saveAsTable("savedJsonTable") + test("insert into a table") { + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) + withTable("insertParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) + + intercept[AnalysisException] { + createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + } + + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) + + createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) + + intercept[AnalysisException] { + createDF(30, 39).write.saveAsTable("insertParquet") + } + + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) + + createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) + + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) + } + } - createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false) - assert(table("createdJsonTable").schema === df.schema) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - df.collect()) - - val message = intercept[RuntimeException] { - createTable("createdJsonTable", filePath.toString, false) - }.getMessage - assert(message.contains("Table createdJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - createTable("createdJsonTable", filePath.toString, true) - // createdJsonTable should be not changed. - assert(table("createdJsonTable").schema === df.schema) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - df.collect()) + test("SPARK-8156:create table to specific database by 'use dbname' ") { + + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + sqlContext.sql("""create database if not exists testdb8156""") + sqlContext.sql("""use testdb8156""") + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("ttt3") - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + checkAnswer( + sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + Row("ttt3", false)) + sqlContext.sql("""use default""") + sqlContext.sql("""drop database if exists testdb8156 CASCADE""") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala new file mode 100644 index 000000000000..997c667ec0d1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode} + +class MultiDatabaseSuite extends QueryTest with SQLTestUtils { + override val _sqlContext: HiveContext = TestHive + private val sqlContext = _sqlContext + + private val df = sqlContext.range(10).coalesce(1) + + private def checkTablePath(dbName: String, tableName: String): Unit = { + // val hiveContext = sqlContext.asInstanceOf[HiveContext] + val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName) + val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName + + assert(metastoreTable.serdeProperties("path") === expectedPath) + } + + test(s"saveAsTable() to non-default database - with USE - Overwrite") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + test(s"saveAsTable() to non-default database - without USE - Overwrite") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + test(s"createExternalTable() to non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + + sqlContext.createExternalTable("t", path, "parquet") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table("t"), df) + + sql( + s""" + |CREATE TABLE t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table("t1"), df) + } + } + } + } + + test(s"createExternalTable() to non-default database - without USE") { + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + sqlContext.createExternalTable(s"$db.t", path, "parquet") + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + sql( + s""" + |CREATE TABLE $db.t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table(s"$db.t1"), df) + } + } + } + + test(s"saveAsTable() to non-default database - with USE - Append") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + df.write.mode(SaveMode.Append).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df.unionAll(df)) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") + } + } + + test(s"saveAsTable() to non-default database - without USE - Append") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") + } + } + + test(s"insertInto() non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + } + + test(s"insertInto() non-default database - without USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + } + + assert(sqlContext.tableNames(db).contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test("Looks up tables in non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql("CREATE TABLE t (key INT)") + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + } + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + } + } + + test("Drops a table in a non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql(s"CREATE TABLE t (key INT)") + assert(sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(sqlContext.tableNames(db).contains("t")) + + activateDatabase(db) { + sql(s"DROP TABLE t") + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames(db).contains("t")) + } + } + + test("Refreshes a table in a non-default database - with USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + activateDatabase(db) { + sql( + s"""CREATE EXTERNAL TABLE t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql("ALTER TABLE t ADD PARTITION (p=1)") + sql("REFRESH TABLE t") + checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql("ALTER TABLE t ADD PARTITION (p=2)") + sqlContext.refreshTable("t") + checkAnswer( + sqlContext.table("t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + } + + test("Refreshes a table in a non-default database - without USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + sql( + s"""CREATE EXTERNAL TABLE $db.t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") + sql(s"REFRESH TABLE $db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") + sqlContext.refreshTable(s"$db.t") + checkAnswer( + sqlContext.table(s"$db.t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + + test("invalid database name and table names") { + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`t:a`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`table`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`t:a` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`table` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala new file mode 100644 index 000000000000..91d7a48208e8 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp +import java.util.{Locale, TimeZone} + +import org.apache.hadoop.hive.conf.HiveConf +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.{Row, SQLConf, SQLContext} + +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with BeforeAndAfterAll { + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext + + /** + * Set the staging directory (and hence path to ignore Parquet files under) + * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. + */ + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + protected override def beforeAll(): Unit = { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Locale.setDefault(Locale.US) + } + + override protected def afterAll(): Unit = { + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + } + + override protected def logParquetSchema(path: String): Unit = { + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |$schema + """.stripMargin) + } + + private def testParquetHiveCompatibility(row: Row, hiveTypes: String*): Unit = { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // Hive columns are always nullable, so here we append a all-null row. + val rows = row :: Row(Seq.fill(row.length)(null): _*) :: Nil + + // Don't convert Hive metastore Parquet tables to let Hive write those Parquet files. + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + val fields = hiveTypes.zipWithIndex.map { case (typ, index) => s" col_$index $typ" } + + val ddl = + s"""CREATE TABLE parquet_compat( + |${fields.mkString(",\n")} + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin + + logInfo( + s"""Creating testing Parquet table with the following DDL: + |$ddl + """.stripMargin) + + sqlContext.sql(ddl) + + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } + + logParquetSchema(path) + + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), rows) + } + } + } + } + + test("simple primitives") { + testParquetHiveCompatibility( + Row(true, 1.toByte, 2.toShort, 3, 4.toLong, 5.1f, 6.1d, "foo"), + "BOOLEAN", "TINYINT", "SMALLINT", "INT", "BIGINT", "FLOAT", "DOUBLE", "STRING") + } + + test("SPARK-10177 timestamp") { + testParquetHiveCompatibility(Row(Timestamp.valueOf("2015-08-24 00:31:00")), "TIMESTAMP") + } + + test("array") { + testParquetHiveCompatibility( + Row( + Seq[Integer](1: Integer, null, 2: Integer, null), + Seq[String]("foo", null, "bar", null), + Seq[Seq[Integer]]( + Seq[Integer](1: Integer, null), + Seq[Integer](2: Integer, null))), + "ARRAY", + "ARRAY", + "ARRAY>") + } + + test("map") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + (1: Integer) -> "foo", + (2: Integer) -> null)), + "MAP") + } + + // HIVE-11625: Parquet map entries with null keys are dropped by Hive + ignore("map entries with null keys") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + null.asInstanceOf[Integer] -> "bar", + null.asInstanceOf[Integer] -> null)), + "MAP") + } + + test("struct") { + testParquetHiveCompatibility( + Row(Row(1, Seq("foo", "bar", null))), + "STRUCT>") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala new file mode 100644 index 000000000000..1cc8a93e8341 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import com.google.common.io.Files +import org.apache.spark.sql.test.SQLTestUtils + +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.util.Utils + + +class QueryPartitionSuite extends QueryTest with SQLTestUtils { + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ + + protected def _sqlContext = ctx + + test("SPARK-5068: query data when path doesn't exist"){ + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { + val testData = ctx.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala new file mode 100644 index 000000000000..93dcb10f7a29 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer + +class SerializationSuite extends SparkFunSuite { + + test("[SPARK-5840] HiveContext should be serializable") { + val hiveContext = org.apache.spark.sql.hive.test.TestHive + hiveContext.hiveconf + val serializer = new JavaSerializer(new SparkConf()).newInstance() + val bytes = serializer.serialize(hiveContext) + val deSer = serializer.deserialize[AnyRef](bytes) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 6f07fd5a879c..e4fec7e2c8a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,13 +23,18 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - TestHive.reset() - TestHive.cacheTables = false + + private lazy val ctx: HiveContext = { + val ctx = org.apache.spark.sql.hive.test.TestHive + ctx.reset() + ctx.cacheTables = false + ctx + } + + import ctx.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -49,6 +54,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { } } + // Ensure session state is initialized. + ctx.parseSql("use default") + assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[HiveNativeCommand]) @@ -72,17 +80,13 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - // TODO: How does it works? needs to add it back for other hive version. - if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) - } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) @@ -110,7 +114,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -120,18 +124,18 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { - analyze("tempTable") + intercept[UnsupportedOperationException] { + ctx.analyze("tempTable") } - catalog.unregisterTable(Seq("tempTable")) + ctx.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { - val rdd = sql("""SELECT * FROM src""") - val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => + val df = sql("""SELECT * FROM src""") + val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") + assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -142,40 +146,40 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { after: () => Unit, query: String, expectedAnswer: Seq[Row], - ct: ClassTag[_]) = { + ct: ClassTag[_]): Unit = { before() - var rdd = sql(query) + var df = sql(query) // Assert src has a size smaller than the threshold. - val sizes = rdd.queryExecution.analyzed.collect { + val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold - && sizes(1) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. - var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.size === 1, - s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + s"actual query plans do not contain broadcast join: ${df.queryExecution}") - checkAnswer(rdd, expectedAnswer) // check correctness of output + checkAnswer(df, expectedAnswer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") - rdd = sql(query) - bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") + df = sql(query) + bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } after() @@ -199,45 +203,45 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin val answer = Row(86, "val_86") - var rdd = sql(leftSemiJoinQuery) + var df = sql(leftSemiJoinQuery) // Assert src has a size smaller than the threshold. - val sizes = rdd.queryExecution.analyzed.collect { + val sizes = df.queryExecution.analyzed.collect { case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold - && sizes(0) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. - var bhj = rdd.queryExecution.sparkPlan.collect { + var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } assert(bhj.size === 1, - s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + s"actual query plans do not contain broadcast join: ${df.queryExecution}") - checkAnswer(rdd, answer) // check correctness of output + checkAnswer(df, answer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") - rdd = sql(leftSemiJoinQuery) - bhj = rdd.queryExecution.sparkPlan.collect { + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") + df = sql(leftSemiJoinQuery) + bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = rdd.queryExecution.sparkPlan.collect { + val shj = df.queryExecution.sparkPlan.collect { case j: LeftSemiJoinHash => j } assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala new file mode 100644 index 000000000000..7ee1c8d13aa3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.QueryTest + +case class FunctionResult(f1: String, f2: String) + +class UDFSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + + test("UDF case insensitive") { + ctx.udf.register("random0", () => { Math.random() }) + ctx.udf.register("RANDOM1", () => { Math.random() }) + ctx.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala new file mode 100644 index 000000000000..5e7b93d45710 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.util.Collections + +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A set of tests for the filter conversion logic used when pushing partition pruning into the + * metastore + */ +class FiltersSuite extends SparkFunSuite with Logging { + private val shim = new Shim_v0_13 + + private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") + private val varCharCol = new FieldSchema() + varCharCol.setName("varchar") + varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) + testTable.setPartCols(Collections.singletonList(varCharCol)) + + filterTest("string filter", + (a("stringcol", StringType) > Literal("test")) :: Nil, + "stringcol > \"test\"") + + filterTest("string filter backwards", + (Literal("test") > a("stringcol", StringType)) :: Nil, + "\"test\" > stringcol") + + filterTest("int filter", + (a("intcol", IntegerType) === Literal(1)) :: Nil, + "intcol = 1") + + filterTest("int filter backwards", + (Literal(1) === a("intcol", IntegerType)) :: Nil, + "1 = intcol") + + filterTest("int and string filter", + (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, + "1 = intcol and \"a\" = strcol") + + filterTest("skip varchar", + (Literal("") === a("varchar", StringType)) :: Nil, + "") + + private def filterTest(name: String, filters: Seq[Expression], result: String) = { + test(name){ + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail( + s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala new file mode 100644 index 000000000000..f0bb77092c0c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.File + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.Utils + +/** + * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality + * is not fully tested. + */ +class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. + private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + + private def buildConf() = { + lazy val warehousePath = Utils.createTempDir() + lazy val metastorePath = Utils.createTempDir() + metastorePath.delete() + Map( + "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", + "hive.metastore.warehouse.dir" -> warehousePath.toString) + } + + test("success sanity check") { + val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, + buildConf(), + ivyPath).client + val db = new HiveDatabase("default", "") + badClient.createDatabase(db) + } + + private def getNestedMessages(e: Throwable): String = { + var causes = "" + var lastException = e + while (lastException != null) { + causes += lastException.toString + "\n" + lastException = lastException.getCause + } + causes + } + + private val emptyDir = Utils.createTempDir().getCanonicalPath + + private def partSpec = { + val hashMap = new java.util.LinkedHashMap[String, String] + hashMap.put("key", "1") + hashMap + } + + // Its actually pretty easy to mess things up and have all of your tests "pass" by accidentally + // connecting to an auto-populated, in-process metastore. Let's make sure we are getting the + // versions right by forcing a known compatibility failure. + // TODO: currently only works on mysql where we manually create the schema... + ignore("failure sanity check") { + val e = intercept[Throwable] { + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } + } + assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") + } + + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") + + private var client: ClientInterface = null + + versions.foreach { version => + test(s"$version: create client") { + client = null + System.gc() // Hack to avoid SEGV on some JVM versions. + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client + } + + test(s"$version: createDatabase") { + val db = HiveDatabase("default", "") + client.createDatabase(db) + } + + test(s"$version: createTable") { + val table = + HiveTable( + specifiedDatabase = Option("default"), + name = "src", + schema = Seq(HiveColumn("key", "int", "")), + partitionColumns = Seq.empty, + properties = Map.empty, + serdeProperties = Map.empty, + tableType = ManagedTable, + location = None, + inputFormat = + Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), + outputFormat = + Some(classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = + Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName())) + + client.createTable(table) + } + + test(s"$version: getTable") { + client.getTable("default", "src") + } + + test(s"$version: listTables") { + assert(client.listTables("default") === Seq("src")) + } + + test(s"$version: currentDatabase") { + assert(client.currentDatabase === "default") + } + + test(s"$version: getDatabase") { + client.getDatabase("default") + } + + test(s"$version: alterTable") { + client.alterTable(client.getTable("default", "src")) + } + + test(s"$version: set command") { + client.runSqlHive("SET spark.sql.test.key=1") + } + + test(s"$version: create partitioned table DDL") { + client.runSqlHive("CREATE TABLE src_part (value INT) PARTITIONED BY (key INT)") + client.runSqlHive("ALTER TABLE src_part ADD PARTITION (key = '1')") + } + + test(s"$version: getPartitions") { + client.getAllPartitions(client.getTable("default", "src_part")) + } + + test(s"$version: getPartitionsByFilter") { + client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( + AttributeReference("key", IntegerType, false)(NamedExpression.newExprId), + Literal(1)))) + } + + test(s"$version: loadPartition") { + client.loadPartition( + emptyDir, + "default.src_part", + partSpec, + false, + false, + false, + false) + } + + test(s"$version: loadTable") { + client.loadTable( + emptyDir, + "src", + false, + false) + } + + test(s"$version: loadDynamicPartitions") { + client.loadDynamicPartitions( + emptyDir, + "default.src_part", + partSpec, + false, + 1, + false, + false) + } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala new file mode 100644 index 000000000000..4886a8594836 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -0,0 +1,637 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} + +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext + import sqlContext.implicits._ + + var originalUseAggregate2: Boolean = _ + + override def beforeAll(): Unit = { + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") + val data1 = Seq[(Integer, Integer)]( + (1, 10), + (null, -60), + (1, 20), + (1, 30), + (2, 0), + (null, -10), + (2, -1), + (2, null), + (2, null), + (null, 100), + (3, null), + (null, null), + (3, null)).toDF("key", "value") + data1.write.saveAsTable("agg1") + + val data2 = Seq[(Integer, Integer, Integer)]( + (1, 10, -10), + (null, -60, 60), + (1, 30, -30), + (1, 30, 30), + (2, 1, 1), + (null, -10, 10), + (2, -1, null), + (2, 1, 1), + (2, null, 1), + (null, 100, -10), + (3, null, 3), + (null, null, null), + (3, null, null)).toDF("key", "value1", "value2") + data2.write.saveAsTable("agg2") + + val emptyDF = sqlContext.createDataFrame( + sqlContext.sparkContext.emptyRDD[Row], + StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) + emptyDF.registerTempTable("emptyTable") + + // Register UDAFs + sqlContext.udf.register("mydoublesum", new MyDoubleSum) + sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + } + + override def afterAll(): Unit = { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.dropTempTable("emptyTable") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) + } + + test("empty table") { + // If there is no GROUP BY clause and the table is empty, we will generate a single row. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key), + | COUNT(DISTINCT value) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil) + + // If there is a GROUP BY clause and the table is empty, there is no output. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(value), + | FIRST(value), + | LAST(value), + | MAX(value), + | MIN(value), + | SUM(value), + | COUNT(DISTINCT value) + |FROM emptyTable + |GROUP BY key + """.stripMargin), + Nil) + } + + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + + test("only do grouping") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT value1, key + |FROM agg2 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + } + + test("case in-sensitive resolution") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), kEY - 100 + |FROM agg1 + |GROUP BY Key - 100 + """.stripMargin), + Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT sum(distinct value1), kEY - 100, count(distinct value1) + |FROM agg2 + |GROUP BY Key - 100 + """.stripMargin), + Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT valUe * key - 100 + |FROM agg1 + |GROUP BY vAlue * keY - 100 + """.stripMargin), + Row(-90) :: + Row(-80) :: + Row(-70) :: + Row(-100) :: + Row(-102) :: + Row(null) :: Nil) + } + + test("test average no key in output") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil) + } + + test("test average") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + 1.5, key + 10 + |FROM agg1 + |GROUP BY key + 10 + """.stripMargin), + Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) FROM agg1 + """.stripMargin), + Row(11.125) :: Nil) + } + + test("udaf") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | mydoubleavg(value), + | avg(value - key), + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) :: + Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) :: + Row(3, null, null, null, null, null) :: + Row(null, null, 110.0, null, null, 10.0) :: Nil) + } + + test("non-AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value) FROM agg1 + """.stripMargin), + Row(89.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1, 20.0) :: + Row(-1.0, 2, -0.5) :: + Row(null, 3, null) :: + Row(30.0, null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoublesum(value + 1.5 * key), + | avg(value - key), + | key, + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(64.5, 19.0, 1, 55.5, 20.0) :: + Row(5.0, -2.5, 2, -7.0, -0.5) :: + Row(null, null, 3, null, null) :: + Row(null, null, null, null, 10.0) :: Nil) + } + + test("single distinct column set") { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) + } + + test("test count") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1) :: + Row(1, -60, 1, 1, null) :: + Row(2, 30, 2, 2, 1) :: + Row(2, 1, 2, 2, 2) :: + Row(1, -10, 1, 1, null) :: + Row(0, -1, 1, 1, 2) :: + Row(1, null, 1, 1, 2) :: + Row(1, 100, 1, 1, null) :: + Row(1, null, 2, 2, 3) :: + Row(0, null, 1, 1, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key, + | count(DISTINCT abs(value2)) + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1, 1) :: + Row(1, -60, 1, 1, null, 1) :: + Row(2, 30, 2, 2, 1, 1) :: + Row(2, 1, 2, 2, 2, 1) :: + Row(1, -10, 1, 1, null, 1) :: + Row(0, -1, 1, 1, 2, 0) :: + Row(1, null, 1, 1, 2, 1) :: + Row(1, 100, 1, 1, null, 1) :: + Row(1, null, 2, 2, 3, 1) :: + Row(0, null, 1, 1, null, 0) :: Nil) + } + + test("test Last implemented based on AggregateExpression1") { + // TODO: Remove this test once we remove AggregateExpression1. + import org.apache.spark.sql.functions._ + val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + + checkAnswer( + df.groupBy("i").agg(last("j")), + df + ) + } + } + + test("error handling") { + withSQLConf("spark.sql.useAggregate2" -> "false") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + } + + // TODO: once we support Hive UDAF in the new interface, + // we can remove the following two tests. + withSQLConf("spark.sql.useAggregate2" -> "true") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // This will fall back to the old aggregate + val newAggregateOperators = sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).queryExecution.executedPlan.collect { + case agg: aggregate.SortBasedAggregate => agg + case agg: aggregate.TungstenAggregate => agg + } + val message = + "We should fallback to the old aggregation code path if " + + "there is any aggregate function that cannot be converted to the new interface." + assert(newAggregateOperators.isEmpty, message) + } + } +} + +class SortBasedAggregationQuerySuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } +} + +class TungstenAggregationQuerySuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } +} + +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } + + override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + sqlContext.setConf( + "spark.sql.TungstenAggregate.testFallbackStartsAt", + fallbackStartsAt.toString) + + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 42a82c1fbf5c..a3f5921a0cb2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class BigDataBenchmarkSuite extends HiveComparisonTest { val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") + val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath val testTables = Seq( TestTable( "rankings", @@ -63,7 +64,7 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | searchWord STRING, | duration INT) | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," - | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), TestTable( "documents", @@ -83,7 +84,10 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") createQueryTest("query2", - "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + """ + |SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits + |GROUP BY SUBSTR(sourceIP, 1, 10) + """.stripMargin) createQueryTest("query3", """ @@ -113,8 +117,8 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { |CREATE TABLE url_counts_total AS | SELECT SUM(count) AS totalCount, destpage | FROM url_counts_partial GROUP BY destpage - |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic - |-- given different input splits. + |-- The following queries run, but generate different results in HIVE + |-- likely because the UDF is not deterministic given different input splits. |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial |-- SELECT COUNT(*) FROM url_counts_partial |-- SELECT * FROM url_counts_partial diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 23ece7e7cf6e..b0d3dd44daed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index f8a957d55d57..4d45249d9c6b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} +import scala.util.control.NonFatal -import org.apache.spark.Logging -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} -import org.apache.spark.sql.hive.DescribeCommand +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} + +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.test.TestHive /** @@ -40,7 +42,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -124,12 +126,12 @@ abstract class HiveComparisonTest protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") protected def getMd5(str: String): String = { val digest = java.security.MessageDigest.getInstance("MD5") - digest.update(str.getBytes("utf-8")) + digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes("utf-8")) new java.math.BigInteger(1, digest.digest).toString(16) } protected def prepareAnswer( - hiveQuery: TestHive.type#HiveQLQueryExecution, + hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { @@ -138,7 +140,7 @@ abstract class HiveComparisonTest case _ => plan.children.iterator.exists(isSorted) } - val orderedAnswer = hiveQuery.logical match { + val orderedAnswer = hiveQuery.analyzed match { // Clean out non-deterministic time schema info. // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. @@ -241,7 +243,10 @@ abstract class HiveComparisonTest // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val allQueries = sql.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq + val sqlWithoutComment = + sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") + val allQueries = + sqlWithoutComment.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq // TODO: DOCUMENT UNSUPPORTED val queryList = @@ -252,8 +257,9 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") .filterNot(_ contains "hive.exec.post.hooks") - if (allQueries != queryList) + if (allQueries != queryList) { logWarning(s"Simplifications made on unsupported operations for test $testCaseName") + } lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -269,7 +275,7 @@ abstract class HiveComparisonTest } val hiveCacheFiles = queryList.zipWithIndex.map { - case (queryString, i) => + case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" new File(answerCache, cachedAnswerName) } @@ -294,21 +300,26 @@ abstract class HiveComparisonTest hiveCachedResults } else { - val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) + val hiveQueries = queryList.map(new TestHive.QueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. + // Note this must only look at the logical plan as we might not be able to analyze if + // other DDL has not been executed yet. hiveQueries.foreach(_.logical) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { - case ((queryString, i), hiveQuery, cachedAnswerFile)=> + case ((queryString, i), hiveQuery, cachedAnswerFile) => try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. - if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) { sys.error("hive exec hooks not supported for tests.") + } - logWarning(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i + 1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { - case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _: ExplainCommand => + // No need to execute EXPLAIN queries as we don't check the output. + Nil case _ => TestHive.runSqlHive(queryString) } @@ -339,7 +350,7 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.HiveQLQueryExecution(queryString) + val query = new TestHive.QueryExecution(queryString) try { (query, prepareAnswer(query, query.stringResult())) } catch { case e: Throwable => val errorMessage = @@ -361,7 +372,11 @@ abstract class HiveComparisonTest // Check that the results match unless its an EXPLAIN query. val preparedHive = prepareAnswer(hiveQuery, hive) - if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { + // We will ignore the ExplainCommand, ShowFunctions, DescribeFunction + if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && + (!hiveQuery.logical.isInstanceOf[ShowFunctions]) && + (!hiveQuery.logical.isInstanceOf[DescribeFunction]) && + preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst @@ -373,11 +388,45 @@ abstract class HiveComparisonTest hiveCacheFiles.foreach(_.delete()) } + // If this query is reading other tables that were created during this test run + // also print out the query plans and results for those. + val computedTablesMessages: String = try { + val tablesRead = new TestHive.QueryExecution(query).executedPlan.collect { + case ts: HiveTableScan => ts.relation.tableName + }.toSet + + TestHive.reset() + val executions = queryList.map(new TestHive.QueryExecution(_)) + executions.foreach(_.toRdd) + val tablesGenerated = queryList.zip(executions).flatMap { + case (q, e) => e.executedPlan.collect { + case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => + (q, e, i) + } + } + + tablesGenerated.map { case (hiveql, execution, insert) => + s""" + |=== Generated Table === + |$hiveql + |$execution + |== Results == + |${insert.child.execute().collect().mkString("\n")} + """.stripMargin + }.mkString("\n") + + } catch { + case NonFatal(e) => + logError("Failed to compute generated tables", e) + s"Couldn't compute dependent tables: $e" + } + val errorMessage = s""" |Results do not match for $testCaseName: |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")} |$resultComparison + |$computedTablesMessages """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) @@ -391,21 +440,24 @@ abstract class HiveComparisonTest case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => if (System.getProperty("spark.hive.canarytest") != null) { - // When we encounter an error we check to see if the environment is still okay by running a simple query. - // If this fails then we halt testing since something must have gone seriously wrong. + // When we encounter an error we check to see if the environment is still + // okay by running a simple query. If this fails then we halt testing since + // something must have gone seriously wrong. try { - new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult() + new TestHive.QueryExecution("SELECT key FROM src").stringResult() TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") - // The testing setup traps exits so wait here for a long time so the developer can see when things started - // to go wrong. + logError(s"FATAL ERROR: Canary query threw $e This implies that the " + + "testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer + // can see when things started to go wrong. Thread.sleep(1000000) } } - // If the canary query didn't fail then the environment is still okay, so just throw the original exception. + // If the canary query didn't fail then the environment is still okay, + // so just throw the original exception. throw originalException } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 697211222b90..11d7a872dff0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. */ -class HiveExplainSuite extends QueryTest { +class HiveExplainSuite extends QueryTest with SQLTestUtils { + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext + test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") @@ -36,7 +41,7 @@ class HiveExplainSuite extends QueryTest { "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "Code Generation", "== RDD ==") + "Code Generation") } test("explain create table command") { @@ -74,4 +79,30 @@ class HiveExplainSuite extends QueryTest { "Limit", "src") } + + test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + withTempTable("jt") { + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + read.json(rdd).registerTempTable("jt") + val outputs = sql( + s""" + |EXPLAIN EXTENDED + |CREATE TABLE t1 + |AS + |SELECT * FROM jt + """.stripMargin).collect().map(_.mkString).mkString + + val shouldContain = + "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: + "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: + "CreateTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil + for (key <- shouldContain) { + assert(outputs.contains(key), s"$key doesn't exist in result") + } + + val physicalIndex = outputs.indexOf("== Physical Plan ==") + assert(!outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should not contain Subquery since it's eliminated by optimizer") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala new file mode 100644 index 000000000000..efbef68cd444 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * A set of tests that validates commands can also be queried by like a table + */ +class HiveOperatorQueryableSuite extends QueryTest { + test("SPARK-5324 query result of describe command") { + loadTestTable("src") + + // register a describe command to be a temp table + sql("desc src").registerTempTable("mydesc") + checkAnswer( + sql("desc mydesc"), + Seq( + Row("col_name", "string", "name of the column"), + Row("data_type", "string", "data type of the column"), + Row("comment", "string", "comment of the column"))) + + checkAnswer( + sql("select * from mydesc"), + Seq( + Row("key", "int", null), + Row("value", "string", null))) + + checkAnswer( + sql("select col_name, data_type, comment from mydesc"), + Seq( + Row("key", "int", null), + Row("value", "string", null))) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index c939e6e99d28..ba56a8a6b689 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,16 +17,36 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { import TestHive._ + import TestHive.implicits._ test("udf constant folding") { - val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } + + test("window expressions sharing the same partition by and order by clause") { + val df = Seq.empty[(Int, String, Int, Int)].toDF("id", "grp", "seq", "val") + val window = Window. + partitionBy($"grp"). + orderBy($"val") + val query = df.select( + $"id", + sum($"val").over(window.rowsBetween(-1, 1)), + sum($"val").over(window.rangeBetween(-1, 1)) + ) + val plan = query.queryExecution.analyzed + assert(plan.collect{ case w: logical.Window => w }.size === 1, + "Should have only 1 Window operator.") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 02518d516261..f7b37dae0a5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. - * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * that should be included. Additionally, there is support for whitelisting and blacklisting + * tests as development progresses. */ abstract class HiveQueryFileTest extends HiveComparisonTest { /** A list of tests deemed out of scope and thus completely disregarded */ @@ -54,15 +55,17 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { logDebug(s"Blacklisted test skipped $testCaseName") - } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { + } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || + runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) createQueryTest(testCaseName, queriesString) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. - if(System.getProperty(whiteListProperty) == null && !runAll) + if (System.getProperty(whiteListProperty) == null && !runAll) { ignore(testCaseName) {} + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4c53b10ba96e..83f9f3eaa3a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.scalatest.BeforeAndAfter - import scala.util.Try +import org.scalatest.BeforeAndAfter + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -37,12 +37,15 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault + import org.apache.spark.sql.hive.test.TestHive.implicits._ + override def beforeAll() { TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) @@ -55,20 +58,108 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) + sql("DROP TEMPORARY FUNCTION udtf_count2") } - test("SPARK-4908: concurent hive native commands") { + test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") - sql("SHOW TABLES") + sql("SHOW DATABASES") + } + } + + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM src group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #3", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for GroupingSet", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("insert table with generator with column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + createQueryTest("insert table with generator with multiple column names", + """ + | CREATE TABLE gen_tmp (key Int, value String); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3; + | SELECT key, value FROM gen_tmp ORDER BY key, value ASC; + """.stripMargin) + + createQueryTest("insert table with generator without column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + test("multiple generators in projection") { + intercept[AnalysisException] { + sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() + } + + intercept[AnalysisException] { + sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() } } createQueryTest("! operator", """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table + | SELECT 1 AS a UNION ALL SELECT 2 AS a) t |WHERE !(a>1) """.stripMargin) @@ -77,7 +168,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { lower("AA"), "10", repeat(lower("AA"), 3), "11", lower(repeat("AA", 3)), "12", - printf("Bb%d", 12), "13", + printf("bb%d", 12), "13", repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") createQueryTest("NaN to Decimal", @@ -120,71 +211,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |FROM src LIMIT 1; """.stripMargin) - createQueryTest("count distinct 0 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 0) table - """.stripMargin) - - createQueryTest("count distinct 1 value strings", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL - | SELECT 'b' AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values including null", - """ - |SELECT COUNT(DISTINCT a, 1) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 2L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -200,6 +226,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("having no references", "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + createQueryTest("no from clause", + "SELECT 1, +1, -1") + createQueryTest("boolean = number", """ |SELECT @@ -232,7 +261,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } createQueryTest("modulus", - "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), " + + "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") @@ -253,8 +283,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("Cast Timestamp to Timestamp in UDF", """ - | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) - | FROM src LIMIT 1 + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 1", + """ + | SELECT + | CAST(CAST('1970-01-01 22:00:00' AS timestamp) AS date) == + | CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) + | FROM src LIMIT 1 """.stripMargin) createQueryTest("Simple Average", @@ -282,7 +320,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { "SELECT * FROM src a JOIN src b ON a.key = b.key") createQueryTest("small.cartesian", - "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN " + + "(SELECT key FROM src WHERE key = 2) b") createQueryTest("length.udf", "SELECT length(\"test\") FROM src LIMIT 1") @@ -307,6 +346,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |DROP DATABASE IF EXISTS testdb CASCADE """.stripMargin) + createQueryTest("create table as with db name within backticks", + """ + |CREATE DATABASE IF NOT EXISTS testdb; + |CREATE TABLE `testdb`.`createdtable` AS SELECT * FROM default.src; + |SELECT * FROM testdb.createdtable; + |DROP DATABASE IF EXISTS testdb CASCADE + """.stripMargin) + createQueryTest("insert table with db name", """ |CREATE DATABASE IF NOT EXISTS testdb; @@ -326,6 +373,25 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |SELECT * FROM createdtable; """.stripMargin) + test("SPARK-7270: consider dynamic partition when comparing table output") { + sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + + val analyzedPlan = sql( + """ + |INSERT OVERWRITE table test_partition PARTITION (b=1, c) + |SELECT 'a', 'c' from ptest + """.stripMargin).queryExecution.analyzed + + assertResult(false, "Incorrect cast detected\n" + analyzedPlan) { + var hasCast = false + analyzedPlan.collect { + case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c } + } + hasCast + } + } + createQueryTest("transform", "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") @@ -361,7 +427,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |USING 'cat' AS (tKey, tValue) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) test("transform with SerDe2") { @@ -380,7 +446,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('avro.schema.literal'='{"namespace": "testing.hive.avro.serde","name": |"src","type": "record","fields": [{"name":"key","type":"int"}]}') |FROM small_src - """.stripMargin.replaceAll("\n", " ")).collect().head + """.stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head assert(expected(0) === res(0)) } @@ -392,7 +458,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' AS (tKey, tValue) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES ('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("transform with SerDe4", """ @@ -401,8 +467,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES |('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) - + """.stripMargin.replaceAll(System.lineSeparator(), " ")) + createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") @@ -418,10 +484,10 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view2", "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") - createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + // scalastyle:off createQueryTest("lateral view4", """ |create table src_lv1 (key string, value string); @@ -431,6 +497,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX """.stripMargin) + // scalastyle:on createQueryTest("lateral view5", "FROM src SELECT explode(array(key+3, key+4))") @@ -438,8 +505,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view6", "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + createQueryTest("Specify the udtf output", + "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") } test("DataFrame toString") { @@ -501,11 +572,77 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("select null from table", "SELECT null FROM src LIMIT 1") + createQueryTest("CTE feature #1", + "with q1 as (select key from src) select * from q1 where key = 5") + + createQueryTest("CTE feature #2", + """with q1 as (select * from src where key= 5), + |q2 as (select * from src s2 where key = 4) + |select value from q1 union all select value from q2 + """.stripMargin) + + createQueryTest("CTE feature #3", + """with q1 as (select key from src) + |from q1 + |select * where key = 4 + """.stripMargin) + + // test get_json_object again Hive, because the HiveCompatabilitySuite cannot handle result + // with newline in it. + createQueryTest("get_json_object #1", + "SELECT get_json_object(src_json.json, '$') FROM src_json") + + createQueryTest("get_json_object #2", + "SELECT get_json_object(src_json.json, '$.owner'), get_json_object(src_json.json, '$.store')" + + " FROM src_json") + + createQueryTest("get_json_object #3", + "SELECT get_json_object(src_json.json, '$.store.bicycle'), " + + "get_json_object(src_json.json, '$.store.book') FROM src_json") + + createQueryTest("get_json_object #4", + "SELECT get_json_object(src_json.json, '$.store.book[0]'), " + + "get_json_object(src_json.json, '$.store.book[*]') FROM src_json") + + createQueryTest("get_json_object #5", + "SELECT get_json_object(src_json.json, '$.store.book[0].category'), " + + "get_json_object(src_json.json, '$.store.book[*].category'), " + + "get_json_object(src_json.json, '$.store.book[*].isbn'), " + + "get_json_object(src_json.json, '$.store.book[*].reader') FROM src_json") + + createQueryTest("get_json_object #6", + "SELECT get_json_object(src_json.json, '$.store.book[*].reader[0].age'), " + + "get_json_object(src_json.json, '$.store.book[*].reader[*].age') FROM src_json") + + createQueryTest("get_json_object #7", + "SELECT get_json_object(src_json.json, '$.store.basket[0][1]'), " + + "get_json_object(src_json.json, '$.store.basket[*]'), " + + // Hive returns wrong result with [*][0], so this expression is change to make test pass + "get_json_object(src_json.json, '$.store.basket[0][0]'), " + + "get_json_object(src_json.json, '$.store.basket[0][*]'), " + + "get_json_object(src_json.json, '$.store.basket[*][*]'), " + + "get_json_object(src_json.json, '$.store.basket[0][2].b'), " + + "get_json_object(src_json.json, '$.store.basket[0][*].b') FROM src_json") + + createQueryTest("get_json_object #8", + "SELECT get_json_object(src_json.json, '$.non_exist_key'), " + + "get_json_object(src_json.json, '$..no_recursive'), " + + "get_json_object(src_json.json, '$.store.book[10]'), " + + "get_json_object(src_json.json, '$.store.book[0].non_exist_key'), " + + "get_json_object(src_json.json, '$.store.basket[*].non_exist_key'), " + + "get_json_object(src_json.json, '$.store.basket[0][*].non_exist_key') FROM src_json") + + createQueryTest("get_json_object #9", + "SELECT get_json_object(src_json.json, '$.zip code') FROM src_json") + + createQueryTest("get_json_object #10", + "SELECT get_json_object(src_json.json, '$.fb:testid') FROM src_json") + test("predicates contains an empty AttributeSet() references") { sql( """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 ) table + | SELECT 1 AS a FROM src LIMIT 1 ) t |WHERE abs(20141202) is not null """.stripMargin).collect() } @@ -540,7 +677,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerTempTable("REGisteredTABle") + testData.toDF().registerTempTable("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -548,7 +685,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: DataFrame) = { + def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } @@ -556,8 +693,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-1704: Explain commands as a DataFrame") { sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val rdd = sql("explain select key, count(value) from src group by key") - assert(isExplanation(rdd)) + val df = sql("explain select key, count(value) from src group by key") + assert(isExplanation(df)) TestHive.reset() } @@ -565,7 +702,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") + TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -583,30 +720,44 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("SPARK-5383 alias for udfs with multi output columns") { + assert( + sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") + .collect() + .size == 5) + + assert( + sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + .collect() + .size == 5) + } + test("SPARK-5367: resolve star expression in udf") { assert(sql("select concat(*) from src limit 5").collect().size == 5) assert(sql("select array(*) from src limit 5").collect().size == 5) + assert(sql("select concat(key, *) from src limit 5").collect().size == 5) + assert(sql("select array(key, *) from src limit 5").collect().size == 5) } test("Query Hive native command execution result") { - val tableName = "test_native_commands" + val databaseName = "test_native_commands" assertResult(0) { - sql(s"DROP TABLE IF EXISTS $tableName").count() + sql(s"DROP DATABASE IF EXISTS $databaseName").count() } assertResult(0) { - sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + sql(s"CREATE DATABASE $databaseName").count() } assert( - sql("SHOW TABLES") + sql("SHOW DATABASES") .select('result) .collect() .map(_.getString(0)) - .contains(tableName)) + .contains(databaseName)) - assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) TestHive.reset() } @@ -699,12 +850,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerTempTable("test_describe_commands2") + testData.toDF().registerTempTable("test_describe_commands2") assertResult( Array( - Row("a", "IntegerType", null), - Row("b", "StringType", null)) + Row("a", "int", ""), + Row("b", "string", "")) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) @@ -731,18 +882,22 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatibility - // current TestSerDe.jar is from 0.12.0 - if (HiveShim.version == "0.12.0") { - sql(s"ADD JAR $testJar") - sql( - """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' - |WITH serdeproperties('s1'='9') - """.stripMargin) - } sql("DROP TABLE alter1") } + test("ADD JAR command 2") { + // this is a test case from mapjoin_addjar.q + val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") + } + test("ADD FILE command") { val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") @@ -779,7 +934,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |DROP TABLE IF EXISTS dynamic_part_table; """.stripMargin) - test("Dynamic partition folder layout") { + ignore("Dynamic partition folder layout") { sql("DROP TABLE IF EXISTS dynamic_part_table") sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") sql("SET hive.exec.dynamic.partition.mode=nonstrict") @@ -800,7 +955,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .zip(parts) .map { case (k, v) => if (v == "NULL") { - s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultStrVal}" } else { s"$k=$v" } @@ -818,6 +973,22 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") { + sql(""" + |create table sc as select * + |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows) + |union all + |select '2011-01-11', '2011-01-11+15:18:26' from src tablesample (1 rows) + |union all + |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s + """.stripMargin) + sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") + sql("set hive.exec.dynamic.partition=true") + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("insert overwrite table sc_part partition(ts) select * from sc") + sql("drop table sc_part") + } + test("Partition spec validation") { sql("DROP TABLE IF EXISTS dp_test") sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") @@ -843,8 +1014,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") - sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") + sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") sql( """ @@ -921,14 +1092,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: DataFrame): Set[(String, String)] = - rdd.collect().map { + def collectResults(df: DataFrame): Set[Any] = + df.collect().map { case Row(key: String, value: String) => key -> value - case Row(KV(key, value)) => key -> value + case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc) }.toSet conf.clear() + val expectedConfs = conf.getAllDefinedConfs.toSet + assertResult(expectedConfs)(collectResults(sql("SET -v"))) + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... assert(sql("SET").collect().size == 0) @@ -939,16 +1112,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET -v")) - } // "SET key" assertResult(Set(testKey -> testVal)) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 422e843d2b0d..b08db6de2d2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -17,27 +17,37 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { - case class NestedData(a: Seq[NestedData2], B: NestedData2) - case class NestedData2(a: NestedData3, B: NestedData3) - case class NestedData3(a: Int, B: Int) - test("SPARK-3698: case insensitive test for nested data") { - sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested") + read.json(sparkContext.makeRDD( + """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } + test("SPARK-5278: check ambiguous reference to fields") { + read.json(sparkContext.makeRDD( + """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") + + // there are 2 filed matching field name "b", we should report Ambiguous reference error + val exception = intercept[AnalysisException] { + sql("SELECT a[0].b from nested").queryExecution.analyzed + } + assert(exception.getMessage.contains("Ambiguous reference to fields")) + } + createQueryTest("table.attr", "SELECT src.key FROM src ORDER BY key LIMIT 1") @@ -67,8 +77,8 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), @@ -78,18 +88,27 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("nestedRepeatedTest") + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) + .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } + createQueryTest("test ambiguousReferences resolved as hive", + """ + |CREATE TABLE t1(x INT); + |CREATE TABLE t2(a STRUCT, k INT); + |INSERT OVERWRITE TABLE t1 SELECT 1 FROM src LIMIT 1; + |INSERT OVERWRITE TABLE t2 SELECT named_struct("x",1),1 FROM src LIMIT 1; + |SELECT a.x FROM t1 a JOIN t2 b ON a.x = b.k; + """.stripMargin) + /** * Negative examples. Currently only left here for documentation purposes. * TODO(marmbrus): Test that catalyst fails on these queries. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7486bfa82b00..5586a793618b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -25,17 +25,25 @@ import org.apache.spark.sql.hive.test.TestHive * A set of tests that validates support for Hive SerDe. */ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { - - override def beforeAll() = { + override def beforeAll(): Unit = { + import TestHive._ + import org.apache.hadoop.hive.serde2.RegexSerDe + super.beforeAll() TestHive.cacheTables = false + sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin) + sql(s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales") } + // table sales is not a cache table, and will be clear after reset + createQueryTest("Read with RegexSerDe", "SELECT * FROM sales", false) + createQueryTest( "Read and write with LazySimpleSerDe (tab separated)", "SELECT * from serdeins") - createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") - createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 8fb5e050a237..2209fc2f30a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils @@ -60,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } - + test("Spark-4077: timestamp query for null value") { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( @@ -70,12 +71,12 @@ class HiveTableScanSuite extends HiveComparisonTest { FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' """.stripMargin) - val location = + val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() - === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index ab0e0443c7fa..197e9bfb02c4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,8 +35,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { val nullVal = "null" baseTypes.init.foreach { i => - createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1") - createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1") + createQueryTest(s"case when then $i else $nullVal end ", + s"SELECT case when true then $i else $nullVal end FROM src limit 1") + createQueryTest(s"case when then $nullVal else $i end ", + s"SELECT case when true then $nullVal else $i end FROM src limit 1") } test("[SPARK-2210] boolean cast on boolean value should be removed") { @@ -57,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala new file mode 100644 index 000000000000..9c10ffe1113d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.{DataInput, DataOutput} +import java.util.{ArrayList, Arrays, Properties} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.io.Writable +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHive + +import org.apache.spark.util.Utils + +case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) + +// Case classes for the custom UDF's. +case class IntegerCaseClass(i: Int) +case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) +case class StringCaseClass(s: String) +case class ListStringCaseClass(l: Seq[String]) + +/** + * A test suite for Hive custom UDFs. + */ +class HiveUDFSuite extends QueryTest { + + import TestHive.{udf, sql} + import TestHive.implicits._ + + test("spark sql udf test that returns a struct") { + udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + assert(sql( + """ + |SELECT getStruct(1).f1, + | getStruct(1).f2, + | getStruct(1).f3, + | getStruct(1).f4, + | getStruct(1).f5 FROM src LIMIT 1 + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) + } + + test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { + checkAnswer( + sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), + Row(8) + ) + } + + test("hive struct udf") { + sql( + """ + |CREATE EXTERNAL TABLE hiveUDFTestTable ( + | pair STRUCT + |) + |PARTITIONED BY (partition STRING) + |ROW FORMAT SERDE '%s' + |STORED AS SEQUENCEFILE + """. + stripMargin.format(classOf[PairSerDe].getName)) + + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile + sql(s""" + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') + LOCATION '$location'""") + + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + } + + test("Max/Min on named_struct") { + def testOrderInStruct(): Unit = { + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + } + val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) + TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) + testOrderInStruct() + TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) + testOrderInStruct() + TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + } + + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + + test("SPARK-2693 udaf aggregates test") { + checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) + + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) + } + + test("Generic UDAF aggregates") { + checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) + + checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) + } + + test("UDFIntegerToString") { + val testData = TestHive.sparkContext.parallelize( + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() + testData.registerTempTable("integerTable") + + val udfName = classOf[UDFIntegerToString].getName + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") + checkAnswer( + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), + Seq(Row("1"), Row("2"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + + TestHive.reset() + } + + test("UDFToListString") { + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToListString(s) FROM inputTable") + } + assert(errMsg.getMessage contains "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") + TestHive.reset() + } + + test("UDFToListInt") { + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToListInt(s) FROM inputTable") + } + assert(errMsg.getMessage contains "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") + TestHive.reset() + } + + test("UDFToStringIntMap") { + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + + s"AS '${classOf[UDFToStringIntMap].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToStringIntMap(s) FROM inputTable") + } + assert(errMsg.getMessage contains "Map type in java is unsupported because " + + "JVM type erasure makes spark fail to catch key and value types in Map<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") + TestHive.reset() + } + + test("UDFToIntIntMap") { + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + + s"AS '${classOf[UDFToIntIntMap].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToIntIntMap(s) FROM inputTable") + } + assert(errMsg.getMessage contains "Map type in java is unsupported because " + + "JVM type erasure makes spark fail to catch key and value types in Map<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") + TestHive.reset() + } + + test("UDFListListInt") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() + testData.registerTempTable("listListIntTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + checkAnswer( + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), + Seq(Row(0), Row(2), Row(13))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + + TestHive.reset() + } + + test("UDFListString") { + val testData = TestHive.sparkContext.parallelize( + ListStringCaseClass(Seq("a", "b", "c")) :: + ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() + testData.registerTempTable("listStringTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + checkAnswer( + sql("SELECT testUDFListString(l) FROM listStringTable"), + Seq(Row("a,b,c"), Row("d,e"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + + TestHive.reset() + } + + test("UDFStringString") { + val testData = TestHive.sparkContext.parallelize( + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() + testData.registerTempTable("stringTable") + + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + checkAnswer( + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), + Seq(Row("hello world"), Row("hello goodbye"))) + + checkAnswer( + sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"), + Seq(Row(" hello world"), Row(" hello goodbye"))) + + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") + + TestHive.reset() + } + + test("UDFTwoListList") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: + Nil).toDF() + testData.registerTempTable("TwoListTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + checkAnswer( + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), + Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + + TestHive.reset() + } +} + +class TestPair(x: Int, y: Int) extends Writable with Serializable { + def this() = this(0, 0) + var entry: (Int, Int) = (x, y) + + override def write(output: DataOutput): Unit = { + output.writeInt(entry._1) + output.writeInt(entry._2) + } + + override def readFields(input: DataInput): Unit = { + val x = input.readInt() + val y = input.readInt() + entry = (x, y) + } +} + +class PairSerDe extends AbstractSerDe { + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getObjectInspector: ObjectInspector = { + ObjectInspectorFactory + .getStandardStructObjectInspector( + Arrays.asList("pair"), + Arrays.asList(ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + )) + } + + override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] + + override def getSerDeStats: SerDeStats = null + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null + + override def deserialize(value: Writable): AnyRef = { + val pair = value.asInstanceOf[TestPair] + + val row = new ArrayList[ArrayList[AnyRef]] + row.add(new ArrayList[AnyRef](2)) + row.get(0).add(Integer.valueOf(pair.entry._1)) + row.get(0).add(Integer.valueOf(pair.entry._2)) + + row + } +} + +class PairUDF extends GenericUDF { + override def initialize(p1: Array[ObjectInspector]): ObjectInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) + ) + + override def evaluate(args: Array[DeferredObject]): AnyRef = { + Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) + } + + override def getDisplayString(p1: Array[String]): String = "" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala deleted file mode 100644 index dd0df1a9f632..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import java.io.{DataInput, DataOutput} -import java.util -import java.util.Properties - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} -import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHive - -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ - -case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) - -// Case classes for the custom UDF's. -case class IntegerCaseClass(i: Int) -case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) -case class StringCaseClass(s: String) -case class ListStringCaseClass(l: Seq[String]) - -/** - * A test suite for Hive custom UDFs. - */ -class HiveUdfSuite extends QueryTest { - import TestHive._ - - test("spark sql udf test that returns a struct") { - udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) - assert(sql( - """ - |SELECT getStruct(1).f1, - | getStruct(1).f2, - | getStruct(1).f3, - | getStruct(1).f4, - | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).head() === Row(1, 2, 3, 4, 5)) - } - - test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { - checkAnswer( - sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), - Row(8) - ) - } - - test("hive struct udf") { - sql( - """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( - | pair STRUCT - |) - |PARTITIONED BY (partition STRING) - |ROW FORMAT SERDE '%s' - |STORED AS SEQUENCEFILE - """. - stripMargin.format(classOf[PairSerDe].getName)) - - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile - sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') - LOCATION '$location'""") - - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") - } - - test("SPARK-2693 udaf aggregates test") { - checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src").collect().toSeq) - - checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), - sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) - } - - test("Generic UDAF aggregates") { - checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - - checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), - sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) - } - - test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( - IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) - testData.registerTempTable("integerTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") - checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), - Seq(Row("1"), Row("2"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - - TestHive.reset() - } - - test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( - ListListIntCaseClass(Nil) :: - ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) - testData.registerTempTable("listListIntTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") - checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), - Seq(Row(0), Row(2), Row(13))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - - TestHive.reset() - } - - test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( - ListStringCaseClass(Seq("a", "b", "c")) :: - ListStringCaseClass(Seq("d", "e")) :: Nil) - testData.registerTempTable("listStringTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") - checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), - Seq(Row("a,b,c"), Row("d,e"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - - TestHive.reset() - } - - test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( - StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) - testData.registerTempTable("stringTable") - - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") - checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), - Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") - - TestHive.reset() - } - - test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( - ListListIntCaseClass(Nil) :: - ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: - Nil) - testData.registerTempTable("TwoListTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") - checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), - Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - - TestHive.reset() - } -} - -class TestPair(x: Int, y: Int) extends Writable with Serializable { - def this() = this(0, 0) - var entry: (Int, Int) = (x, y) - - override def write(output: DataOutput): Unit = { - output.writeInt(entry._1) - output.writeInt(entry._2) - } - - override def readFields(input: DataInput): Unit = { - val x = input.readInt() - val y = input.readInt() - entry = (x, y) - } -} - -class PairSerDe extends AbstractSerDe { - override def initialize(p1: Configuration, p2: Properties): Unit = {} - - override def getObjectInspector: ObjectInspector = { - ObjectInspectorFactory - .getStandardStructObjectInspector( - Seq("pair"), - Seq(ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector)) - )) - } - - override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] - - override def getSerDeStats: SerDeStats = null - - override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null - - override def deserialize(value: Writable): AnyRef = { - val pair = value.asInstanceOf[TestPair] - - val row = new util.ArrayList[util.ArrayList[AnyRef]] - row.add(new util.ArrayList[AnyRef](2)) - row(0).add(Integer.valueOf(pair.entry._1)) - row(0).add(Integer.valueOf(pair.entry._2)) - - row - } -} - -class PairUdf extends GenericUDF { - override def initialize(p1: Array[ObjectInspector]): ObjectInspector = - ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) - ) - - override def evaluate(args: Array[DeferredObject]): AnyRef = { - println("Type = %s".format(args(0).getClass.getName)) - Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) - } - - override def getDisplayString(p1: Array[String]): String = "" -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8474d850c9c6..210d56674541 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.TestHive -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * A set of test cases that validate partition and column pruning. */ @@ -82,16 +81,16 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { Seq.empty) createPruningTest("Column pruning - non-trivial top project with aliases", - "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", - Seq("double"), + "SELECT c1 * 2 AS dbl FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("dbl"), Seq("key"), Seq.empty) // Partition pruning tests createPruningTest("Partition pruning - non-partitioned, non-trivial project", - "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL", - Seq("double"), + "SELECT key * 2 AS dbl FROM src WHERE value IS NOT NULL", + Seq("dbl"), Seq("key", "value"), Seq.empty) @@ -143,27 +142,28 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { sql: String, expectedOutputColumns: Seq[String], expectedScannedColumns: Seq[String], - expectedPartValues: Seq[Seq[String]]) = { + expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { - val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan + val plan = new TestHive.QueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) + val partValues = if (relation.table.isPartitioned) { + p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) + } else { + Seq.empty + } (columnNames, partValues) }.head assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") - assert( - actualPartValues.length === expectedPartValues.length, - "Partition value count mismatches") + val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted + val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted - for ((actual, expected) <- actualPartValues.zip(expectedPartValues)) { - assert(actual sameElements expected, "Partition values mismatch") - } + assert(actualPartitions === expectedPartitions, "Partitions selected do not match") } // Creates a query test to compare query results generated by Hive and Catalyst. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4efe0c5e0cd4..1ff1d9a2934c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,22 +17,243 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest +import java.sql.{Date, Timestamp} -import org.apache.spark.sql.Row +import scala.collection.JavaConverters._ + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} +import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) case class Nested3(f3: Int) +case class NestedArray2(b: Seq[Int]) +case class NestedArray1(a: NestedArray2) + +case class Order( + id: Int, + make: String, + `type`: String, + price: Int, + pdate: String, + customer: String, + city: String, + state: String, + month: Int) + +case class WindowData( + month: Int, + area: String, + product: Int) +/** A SQL Dialect for testing purpose, and it can not be nested type */ +class MyDialect extends DefaultParserDialect + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest { +class SQLQuerySuite extends QueryTest with SQLTestUtils { + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext + + test("UDTF") { + sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + + test("SPARK-6835: udtf in lateral view") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + } + + test("SPARK-6851: Self-joined converted parquet tables") { + val orders = Seq( + Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(3, "Swift", "MTB", 285, "2015-01-17", "John S", "Redwood City", "CA", 20151), + Order(4, "Atlas", "Hybrid", 303, "2015-01-23", "Jones S", "San Mateo", "CA", 20151), + Order(7, "Next", "MTB", 356, "2015-01-04", "Jane D", "Daly City", "CA", 20151), + Order(10, "Next", "YFlikr", 187, "2015-01-09", "John D", "Fremont", "CA", 20151), + Order(11, "Swift", "YFlikr", 187, "2015-01-23", "John D", "Hayward", "CA", 20151), + Order(2, "Next", "Hybrid", 324, "2015-02-03", "Jane D", "Daly City", "CA", 20152), + Order(5, "Next", "Street", 187, "2015-02-08", "John D", "Fremont", "CA", 20152), + Order(6, "Atlas", "Street", 154, "2015-02-09", "John D", "Pacifica", "CA", 20152), + Order(8, "Swift", "Hybrid", 485, "2015-02-19", "John S", "Redwood City", "CA", 20152), + Order(9, "Atlas", "Split", 303, "2015-02-28", "Jones S", "San Mateo", "CA", 20152)) + + val orderUpdates = Seq( + Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) + + orders.toDF.registerTempTable("orders1") + orderUpdates.toDF.registerTempTable("orderupdates1") + + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } + + test("show functions") { + val allFunctions = + (FunctionRegistry.builtin.listFunction().toSet[String] ++ + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted + checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + checkAnswer(sql("SHOW functions abs"), Row("abs")) + checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) + checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) + checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + // this probably will failed if we add more function with `sha` prefixing. + checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + } + + test("describe functions") { + // The Spark SQL built-in functions + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql')", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + + checkExistence(sql("describe functioN `~`"), true, + "Function: ~", + "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", + "Usage: ~ n - Bitwise not") + } + + test("SPARK-5371: union with null and sum") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + val query = sql( + """ + |SELECT + | MIN(c1), + | MIN(c2) + |FROM ( + | SELECT + | SUM(c1) c1, + | NULL c2 + | FROM table1 + | UNION ALL + | SELECT + | NULL c1, + | SUM(c2) c2 + | FROM table1 + |) a + """.stripMargin) + checkAnswer(query, Row(1, 1) :: Nil) + } + + test("CTAS with WITH clause") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + sql( + """ + |CREATE TABLE with_table1 AS + |WITH T AS ( + | SELECT * + | FROM table1 + |) + |SELECT * + |FROM T + """.stripMargin) + val query = sql("SELECT * FROM with_table1") + checkAnswer(query, Row(1, 1) :: Nil) + } + + test("explode nested Field") { + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") + checkAnswer( + sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), @@ -40,6 +261,106 @@ class SQLQuerySuite extends QueryTest { ) } + test("CTAS without serde") { + def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + relation match { + case LogicalRelation(r: ParquetRelation) => + if (!isDataSourceParquet) { + fail( + s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + + s"${ParquetRelation.getClass.getCanonicalName}.") + } + + case r: MetastoreRelation => + if (isDataSourceParquet) { + fail( + s"${ParquetRelation.getClass.getCanonicalName} is expected, but found " + + s"${classOf[MetastoreRelation].getCanonicalName}.") + } + } + } + + val originalConf = convertCTAS + + setConf(HiveContext.CONVERT_CTAS, true) + + try { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. + message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE IF EXISTS ctas1") + } + } + + test("SQL Dialect Switching") { + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(getSQLDialect().getClass === classOf[MyDialect]) + assert(sql("SELECT 1").collect() === Array(Row(1))) + + // set the dialect back to the DefaultSQLDialect + sql("SET spark.sql.dialect=sql") + assert(getSQLDialect().getClass === classOf[DefaultParserDialect]) + sql("SET spark.sql.dialect=hiveql") + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + + // set invalid dialect + sql("SET spark.sql.dialect.abc=MyTestClass") + sql("SET spark.sql.dialect=abc") + intercept[Exception] { + sql("SELECT 1") + } + // test if the dialect set back to HiveQLDialect + getSQLDialect().getClass === classOf[HiveQLDialect] + + sql("SET spark.sql.dialect=MyTestClass") + intercept[DialectException] { + sql("SELECT 1") + } + // test if the dialect set back to HiveQLDialect + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( @@ -86,7 +407,7 @@ class SQLQuerySuite extends QueryTest { SELECT key, value FROM src ORDER BY key, value""").collect().toSeq) - intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { + intercept[AnalysisException] { sql( """CREATE TABLE ctas4 AS | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() @@ -100,8 +421,50 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" + "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) + + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + } + + // use the Hive SerDe for parquet tables + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + } + } + + test("specifying the column list for CTAS") { + Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1") + + sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select key, value from mytable1").collect()) + sql("DROP TABLE gen__tmp") + + sql("create table gen__tmp(a double, b double) as select key, value from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select cast(key as double), cast(value as double) from mytable1").collect()) + sql("DROP TABLE gen__tmp") + + sql("drop table mytable1") } test("command substitution") { @@ -141,7 +504,8 @@ class SQLQuerySuite extends QueryTest { } test("double nested data") { - sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + .toDF().registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) @@ -151,7 +515,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM nested").collect().toSeq) - intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + intercept[AnalysisException] { sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() } } @@ -159,17 +523,17 @@ class SQLQuerySuite extends QueryTest { test("test CTAS") { checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) checkAnswer( - sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } test("SPARK-4825 save join to table") { - val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") - testData.insertInto("test1") + testData.write.mode(SaveMode.Append).insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") - testData.insertInto("test2") - testData.insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), @@ -233,6 +597,12 @@ class SQLQuerySuite extends QueryTest { } } + test("SPARK-4699 HiveContext should be case insensitive by default") { + checkAnswer( + sql("SELECT KEY FROM Src ORDER BY value"), + sql("SELECT key FROM src ORDER BY value").collect().toSeq) + } + test("SPARK-5284 Insert into Hive throws NPE when a inner complex type field has a null value") { val schema = StructType( StructField("s", @@ -244,7 +614,7 @@ class SQLQuerySuite extends QueryTest { val rowRdd = sparkContext.parallelize(row :: Nil) - applySchema(rowRdd, schema).registerTempTable("testTable") + TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -270,16 +640,537 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4296 Grouping field with Hive UDF as sub expression") { val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") } + + test("resolve udtf in projection #1") { + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + read.json(rdd).registerTempTable("data") + val df = sql("SELECT explode(a) AS val FROM data") + val col = df("val") + } + + test("resolve udtf in projection #2") { + val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) + read.json(rdd).registerTempTable("data") + checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) + checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) + intercept[AnalysisException] { + sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") + } + + intercept[AnalysisException] { + sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") + } + } + + // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive + test("TGF with non-TGF in projection") { + val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) + read.json(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), + Row("1", "1", "1", "1") :: Nil) + } + + test("logical.Project should not be resolved if it contains aggregates or generators") { + // This test is used to test the fix of SPARK-5875. + // The original issue was that Project's resolved will be true when it contains + // AggregateExpressions or Generators. However, in this case, the Project + // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of + // PreInsertionCasts will actually start to work before ImplicitGenerate and then + // generates an invalid query plan. + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + read.json(rdd).registerTempTable("data") + val originalConf = convertCTAS + setConf(HiveContext.CONVERT_CTAS, false) + + try { + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + } + + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) + + sql("DROP TABLE explodeTest") + dropTempTable("data") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + } + } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") + .select($"d1".cast(DecimalType(10, 5)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } + + test("test script transform for stdout") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) + } + + test("test script transform for stderr") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(0 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) + } + + test("test script transform data type") { + val data = (1 to 5).map { i => (i, i) } + data.toDF("key", "value").registerTempTable("test") + checkAnswer( + sql("""FROM + |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t + |SELECT thing1 + 1 + """.stripMargin), (2 to 6).map(i => Row(i))) + } + + test("window function: udaf with aggregate expressin") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product), sum(sum(product)) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 5, 11), + ("a", 6, 11), + ("b", 7, 15), + ("b", 8, 15), + ("c", 9, 19), + ("c", 10, 19) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, sum(product) - 1, sum(sum(product)) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 4, 11), + ("a", 5, 11), + ("b", 6, 15), + ("b", 7, 15), + ("c", 8, 19), + ("c", 9, 19) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, sum(product), sum(product) / sum(sum(product)) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 5, 5d/11), + ("a", 6, 6d/11), + ("b", 7, 7d/15), + ("b", 8, 8d/15), + ("c", 10, 10d/19), + ("c", 9, 9d/19) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, sum(product), sum(product) / sum(sum(product) - 1) over (partition by area) + |from windowData group by month, area + """.stripMargin), + Seq( + ("a", 5, 5d/9), + ("a", 6, 6d/9), + ("b", 7, 7d/13), + ("b", 8, 8d/13), + ("c", 10, 10d/17), + ("c", 9, 9d/17) + ).map(i => Row(i._1, i._2, i._3))) + } + + test("window function: partition and order expressions") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select month, area, product, sum(product + 1) over (partition by 1 order by 2) + |from windowData + """.stripMargin), + Seq( + (1, "a", 5, 51), + (2, "a", 6, 51), + (3, "b", 7, 51), + (4, "b", 8, 51), + (5, "c", 9, 51), + (6, "c", 10, 51) + ).map(i => Row(i._1, i._2, i._3, i._4))) + + checkAnswer( + sql( + """ + |select month, area, product, sum(product) + |over (partition by month % 2 order by 10 - product) + |from windowData + """.stripMargin), + Seq( + (1, "a", 5, 21), + (2, "a", 6, 24), + (3, "b", 7, 16), + (4, "b", 8, 18), + (5, "c", 9, 9), + (6, "c", 10, 10) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("window function: expressions in arguments of a window functions") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select month, area, month % 2, + |lag(product, 1 + 1, product) over (partition by month % 2 order by area) + |from windowData + """.stripMargin), + Seq( + (1, "a", 1, 5), + (2, "a", 0, 6), + (3, "b", 1, 7), + (4, "b", 0, 8), + (5, "c", 1, 5), + (6, "c", 0, 6) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.registerTempTable("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: + Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + dropTempTable("nums") + } + + test("test case key when") { + (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") + checkAnswer( + sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), + Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil) + } + + test("SPARK-7595: Window will cause resolve failed with self join") { + sql("SELECT * FROM src") // Force loading of src table. + + checkAnswer(sql( + """ + |with + | v1 as (select key, count(value) over (partition by key) cnt_val from src), + | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) + | select * from v2 order by key limit 1 + """.stripMargin), Row(0, 3)) + } + + test("SPARK-7269 Check analysis failed in case in-sensitive") { + Seq(1, 2, 3).map { i => + (i.toString, i.toString) + }.toDF("key", "value").registerTempTable("df_analysis") + sql("SELECT kEy from df_analysis group by key").collect() + sql("SELECT kEy+3 from df_analysis group by key+3").collect() + sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() + sql("SELECT 2 from df_analysis A group by key+1").collect() + intercept[AnalysisException] { + sql("SELECT kEy+1 from df_analysis group by key+3") + } + intercept[AnalysisException] { + sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") + } + } + + test("Cast STRING to BIGINT") { + checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) + } + + // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest + test("udf_java_method") { + checkAnswer(sql( + """ + |SELECT java_method("java.lang.String", "valueOf", 1), + | java_method("java.lang.String", "isEmpty"), + | java_method("java.lang.Math", "max", 2, 3), + | java_method("java.lang.Math", "min", 2, 3), + | java_method("java.lang.Math", "round", 2.5), + | java_method("java.lang.Math", "exp", 1.0), + | java_method("java.lang.Math", "floor", 1.9) + |FROM src tablesample (1 rows) + """.stripMargin), + Row( + "1", + "true", + java.lang.Math.max(2, 3).toString, + java.lang.Math.min(2, 3).toString, + java.lang.Math.round(2.5).toString, + java.lang.Math.exp(1.0).toString, + java.lang.Math.floor(1.9).toString)) + } + + test("dynamic partition value test") { + try { + sql("set hive.exec.dynamic.partition.mode=nonstrict") + // date + sql("drop table if exists dynparttest1") + sql("create table dynparttest1 (value int) partitioned by (pdate date)") + sql( + """ + |insert into table dynparttest1 partition(pdate) + | select count(*), cast('2015-05-21' as date) as pdate from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest1"), + Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) + + // decimal + sql("drop table if exists dynparttest2") + sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + sql( + """ + |insert into table dynparttest2 partition(pdec) + | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest2"), + Seq(Row(500, new java.math.BigDecimal("100.1")))) + } finally { + sql("drop table if exists dynparttest1") + sql("drop table if exists dynparttest2") + sql("set hive.exec.dynamic.partition.mode=strict") + } + } + + test("Call add jar in a different thread (SPARK-8306)") { + @volatile var error: Option[Throwable] = None + val thread = new Thread { + override def run() { + // To make sure this test works, this jar should not be loaded in another place. + TestHive.sql( + s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + try { + TestHive.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + } catch { + case throwable: Throwable => + error = Some(throwable) + } + } + } + thread.start() + thread.join() + error match { + case Some(throwable) => + fail("CREATE TEMPORARY FUNCTION should not fail.", throwable) + case None => // OK + } + } + + test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { + checkAnswer( + sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + Row(false)) + } + + test("SPARK-6785: HiveQuerySuite - Date cast") { + // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST + checkAnswer( + sql( + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 + """.stripMargin), + Row( + Date.valueOf("1969-12-31"), + String.valueOf("1969-12-31"), + Timestamp.valueOf("1969-12-31 16:00:00"), + String.valueOf("1969-12-31 16:00:00"), + Timestamp.valueOf("1970-01-01 00:00:00"))) + + } + + test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { + val df = + TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + df.toDF("id", "datef").registerTempTable("test_SPARK8588") + checkAnswer( + TestHive.sql( + """ + |select id, concat(year(datef)) + |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') + """.stripMargin), + Row(1, "2014") :: Row(2, "2015") :: Nil + ) + TestHive.dropTempTable("test_SPARK8588") + } + + test("SPARK-9371: fix the support for special chars in column names for hive context") { + TestHive.read.json(TestHive.sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } + + test("Convert hive interval term into Literal of CalendarIntervalType") { + checkAnswer(sql("select interval '10-9' year to month"), + Row(CalendarInterval.fromString("interval 10 years 9 months"))) + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + + "32 seconds 99 milliseconds 899 microseconds"))) + checkAnswer(sql("select interval '30' year"), + Row(CalendarInterval.fromString("interval 30 years"))) + checkAnswer(sql("select interval '25' month"), + Row(CalendarInterval.fromString("interval 25 months"))) + checkAnswer(sql("select interval '-100' day"), + Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) + checkAnswer(sql("select interval '40' hour"), + Row(CalendarInterval.fromString("interval 1 days 16 hours"))) + checkAnswer(sql("select interval '80' minute"), + Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) + checkAnswer(sql("select interval '299.889987299' second"), + Row(CalendarInterval.fromString( + "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) + } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala new file mode 100644 index 000000000000..9aca40f15ac1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types.StringType + +class ScriptTransformationSuite extends SparkPlanTest { + + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext + + private val noSerdeIOSchema = HiveScriptIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + schemaLess = false + ) + + private val serdeIOSchema = noSerdeIOSchema.copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) + + test("cat without SerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("cat with LazySimpleSerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("script transformation should not swallow errors from upstream operators (no serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } + + test("script transformation should not swallow errors from upstream operators (with serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } +} + +private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. + throw new IllegalArgumentException("intentional exception") + } + } + override def output: Seq[Attribute] = child.output +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..deec0048d24b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.sources.HadoopFsRelationTest +import org.apache.spark.sql.types._ + +class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write + .orc(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.options(Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala new file mode 100644 index 000000000000..a46ca9a2c970 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.BeforeAndAfterAll + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +// The data where the partitioning key exists only in the directory structure. +case class OrcParData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal + + def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) + } + + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.mode("overwrite").orc(path.getCanonicalPath) + } + + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally TestHive.dropTempTable(tableName) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read.orc(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read.orc(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .orc(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .orc(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } +} + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala new file mode 100644 index 000000000000..744d46293814 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +case class AllDataTypesWithNonPrimitiveType( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapValueContainsNull: Map[Int, Option[Long]], + data: (Seq[Int], (Int, String))) + +case class BinaryData(binaryData: Array[Byte]) + +case class Contact(name: String, phone: String) + +case class Person(name: String, age: Int, contacts: Seq[Contact]) + +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + test("Read/write All Types") { + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) + } + + withOrcFile(data) { file => + checkAnswer( + sqlContext.read.orc(file), + data.toDF().collect()) + } + } + + test("Read/write binary data") { + withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => + val bytes = read.orc(file).head().getAs[Array[Byte]](0) + assert(new String(bytes, "utf8") === "test") + } + } + + test("Read/write all types with non-primitive type") { + val data = (0 to 255).map { i => + AllDataTypesWithNonPrimitiveType( + s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + 0 until i, + (0 until i).map(Option(_).filter(_ % 3 == 0)), + (0 until i).map(i => i -> i.toLong).toMap, + (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None), + (0 until i, (i, s"$i"))) + } + + withOrcFile(data) { file => + checkAnswer( + read.orc(file), + data.toDF().collect()) + } + } + + test("Creating case class RDD table") { + val data = (1 to 100).map(i => (i, s"val_$i")) + sparkContext.parallelize(data).toDF().registerTempTable("t") + withTempTable("t") { + checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) + } + } + + test("Simple selection form ORC table") { + val data = (1 to 10).map { i => + Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") }) + } + + withOrcTable(data, "t") { + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = leaf-0 + assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = (not leaf-0) + assertResult(10) { + sql("SELECT name, contacts FROM t where age > 5") + .flatMap(_.getAs[Seq[_]]("contacts")) + .count() + } + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // leaf-1 = (LESS_THAN age 8) + // expr = (and (not leaf-0) leaf-1) + { + val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + assert(df.count() === 2) + assertResult(4) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + + // ppd: + // leaf-0 = (LESS_THAN age 2) + // leaf-1 = (LESS_THAN_EQUALS age 8) + // expr = (or leaf-0 (not leaf-1)) + { + val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + assert(df.count() === 3) + assertResult(6) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + } + } + + test("save and load case class RDD with `None`s as orc") { + val data = ( + None: Option[Int], + None: Option[Long], + None: Option[Float], + None: Option[Double], + None: Option[Boolean] + ) :: Nil + + withOrcFile(data) { file => + checkAnswer( + read.orc(file), + Row(Seq.fill(5)(null): _*)) + } + } + + // We only support zlib in Hive 0.12.0 now + test("Default compression options for writing to an ORC file") { + withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => + assertResult(CompressionKind.ZLIB) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { + val data = (1 to 100).map(i => (i, s"val_$i")) + val conf = sparkContext.hadoopConfiguration + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") + withOrcFile(data) { file => + assertResult(CompressionKind.SNAPPY) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") + withOrcFile(data) { file => + assertResult(CompressionKind.NONE) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") + withOrcFile(data) { file => + assertResult(CompressionKind.LZO) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + } + + test("simple select queries") { + withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer( + sql("SELECT `_1` FROM t where t.`_1` > 5"), + (6 until 10).map(Row.apply(_))) + + checkAnswer( + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withOrcTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("columns only referenced by pushed down filters should remain") { + withOrcTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in orc") { + withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + List(Row("same", "run_5", 100))) + } + } + + test("SPARK-8501: Avoids discovery schema from empty ORC files") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempTable("empty", "single") { + sqlContext.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.registerTempTable("empty") + + // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val errorMessage = intercept[AnalysisException] { + sqlContext.read.orc(path) + }.getMessage + + assert(errorMessage.contains("Failed to discover schema from ORC files")) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.registerTempTable("single") + + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = sqlContext.read.orc(path) + assert(df.schema === singleRowDF.schema.asNullable) + checkAnswer(df, singleRowDF) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala new file mode 100644 index 000000000000..80c38084f293 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, Row} + +case class OrcData(intField: Int, stringField: String) + +abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { + var orcTableDir: File = null + var orcTableAsDir: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + + orcTableAsDir = File.createTempFile("orctests", "sparksql") + orcTableAsDir.delete() + orcTableAsDir.mkdir() + + // Hack: to prepare orc data files using hive external tables + orcTableDir = File.createTempFile("orctests", "sparksql") + orcTableDir.delete() + orcTableDir.mkdir() + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + sparkContext + .makeRDD(1 to 10) + .map(i => OrcData(i, s"part-$i")) + .toDF() + .registerTempTable(s"orc_temp_table") + + sql( + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.getCanonicalPath}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + } + + override def afterAll(): Unit = { + orcTableDir.delete() + orcTableAsDir.delete() + } + + test("create temporary orc table") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) + } + + test("create temporary orc table as") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) + } + + test("appending insert") { + sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + Seq.fill(2)(Row(i, s"part-$i")) + }) + } + + test("overwrite insert") { + sql( + """INSERT OVERWRITE TABLE normal_orc_as_source + |SELECT * FROM orc_temp_table WHERE intField > 5 + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM normal_orc_as_source"), + (6 to 10).map(i => Row(i, s"part-$i"))) + } + + test("write null values") { + sql("DROP TABLE IF EXISTS orcNullValues") + + val df = sql( + """ + |SELECT + | CAST(null as TINYINT), + | CAST(null as SMALLINT), + | CAST(null as INT), + | CAST(null as BIGINT), + | CAST(null as FLOAT), + | CAST(null as DOUBLE), + | CAST(null as DECIMAL(7,2)), + | CAST(null as TIMESTAMP), + | CAST(null as DATE), + | CAST(null as STRING), + | CAST(null as VARCHAR(10)) + |FROM orc_temp_table limit 1 + """.stripMargin) + + df.write.format("orc").saveAsTable("orcNullValues") + + checkAnswer( + sql("SELECT * FROM orcNullValues"), + Row.fromSeq(Seq.fill(11)(null))) + + sql("DROP TABLE IF EXISTS orcNullValues") + } +} + +class OrcSourceSuite extends OrcSuite { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala new file mode 100644 index 000000000000..f7ba20ff41d8 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.test.SQLTestUtils + +private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => + protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive + protected val sqlContext = _sqlContext + import sqlContext.implicits._ + import sqlContext.sparkContext + + /** + * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withOrcFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Orc file will be deleted after `f` returns. + */ + protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + withOrcFile(data)(path => f(sqlContext.read.orc(path))) + } + + /** + * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Orc file will be dropped/deleted after `f` returns. + */ + protected def withOrcTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withOrcDataFrame(data) { df => + sqlContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala new file mode 100644 index 000000000000..34d3434569f5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -0,0 +1,886 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) +// The data that also includes the partitioning key +case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) + +case class StructContainer(intStructField: Int, stringStructField: String) + +case class ParquetDataWithComplexTypes( + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +case class ParquetDataWithKeyAndComplexTypes( + p: Int, + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +/** + * A suite to test the automatic conversion of metastore tables with parquet data to use the + * built in parquet support. + */ +class ParquetMetastoreSuite extends ParquetPartitioningTest { + override def beforeAll(): Unit = { + super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + sql(s""" + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' + """) + + sql(s""" + create external table partitioned_parquet_with_key + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDirWithKey.getCanonicalPath}' + """) + + sql(s""" + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(normalTableDir, "normal").getCanonicalPath}' + """) + + sql(s""" + CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes + ( + intField INT, + stringField STRING, + structField STRUCT, + arrayField ARRAY + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}' + """) + + sql(s""" + CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes + ( + intField INT, + stringField STRING, + structField STRUCT, + arrayField ARRAY + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' + """) + + sql( + """ + |create table test_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") + } + + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + read.json(rdd1).registerTempTable("jt") + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) + read.json(rdd2).registerTempTable("jt_array") + + setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) + } + + override def afterAll(): Unit = { + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) + } + + test(s"conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: PhysicalRDD => true + }.nonEmpty) + } + + test("scan an empty parquet table") { + checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) + } + + test("scan an empty parquet table with upper case") { + checkAnswer(sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) + } + + test("insert into an empty parquet table") { + dropTables("test_insert_parquet") + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // Insert into am empty table. + sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), + Row(6, "str6") :: Row(7, "str7") :: Nil + ) + // Insert overwrite. + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + Row(3, "str3") :: Row(4, "str4") :: Nil + ) + dropTables("test_insert_parquet") + + // Create it again. + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + // Insert overwrite an empty table. + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + Row(3, "str3") :: Row(4, "str4") :: Nil + ) + // Insert into the table. + sql("insert into table test_insert_parquet select a, b from jt") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet"), + (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) + ) + dropTables("test_insert_parquet") + } + + test("scan a parquet table created through a CTAS statement") { + withTable("test_parquet_ctas") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) + + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(_: ParquetRelation) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation].getCanonicalName }") + } + } + } + + test("MetastoreRelation in InsertIntoTable will be converted") { + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"However, found a ${o.toString} ") + } + + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) + } + } + + test("MetastoreRelation in InsertIntoHiveTable will be converted") { + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } + + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) + } + } + + test("SPARK-6450 regression test") { + withTable("ms_convert") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed + + assertResult(2) { + analyzed.collect { + case r@LogicalRelation(_: ParquetRelation) => r + }.size + } + } + } + + def collectParquetRelation(df: DataFrame): ParquetRelation = { + val plan = df.queryExecution.analyzed + plan.collectFirst { + case LogicalRelation(r: ParquetRelation) => r + }.getOrElse { + fail(s"Expecting a ParquetRelation2, but got:\n$plan") + } + } + + test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { + withTable("nonPartitioned") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + } + } + + test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { + withTable("partitioned") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + } + } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifier) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifier) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + + // Create a partitioned table. + sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (`date` string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (`date`='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (`date`='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + + // Make sure we can cache the partitioned table. + table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifier) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") + } +} + +/** + * A suite of tests for the Parquet support through the data sources API. + */ +class ParquetSourceSuite extends ParquetPartitioningTest { + override def beforeAll(): Unit = { + super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet") + + sql( s""" + create temporary table partitioned_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDir.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table partitioned_parquet_with_key + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKey.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table normal_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + ) + """) + + sql( s""" + CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' + ) + """) + + sql( s""" + CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithComplexTypes.getCanonicalPath}' + ) + """) + } + + test("SPARK-6016 make sure to use the latest footers") { + sql("drop table if exists spark_6016_fix") + + // Create a DataFrame with two partitions. So, the created table will have two parquet files. + val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") + checkAnswer( + sql("select * from spark_6016_fix"), + (1 to 10).map(i => Row(i)) + ) + + // Create a DataFrame with four partitions. So, the created table will have four parquet files. + val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") + // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, + // since the new table has four parquet files, we are trying to read new footers from two files + // and then merge metadata in footers of these four (two outdated ones and two latest one), + // which will cause an error. + checkAnswer( + sql("select * from spark_6016_fix"), + (1 to 10).map(i => Row(i)) + ) + + sql("drop table spark_6016_fix") + } + + test("SPARK-8811: compatibility with array of struct in Hive") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("array_of_struct") { + val conf = Seq( + HiveContext.CONVERT_METASTORE_PARQUET.key -> "false", + SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") + + withSQLConf(conf: _*) { + sql( + s"""CREATE TABLE array_of_struct + |STORED AS PARQUET LOCATION '$path' + |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + """.stripMargin) + + checkAnswer( + sqlContext.read.parquet(path), + Row("1st", "2nd", Seq(Row("val_a", "val_b")))) + } + } + } + } + + test("values in arrays and maps stored in parquet are always nullable") { + val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") + val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) + val arrayType1 = ArrayType(IntegerType, containsNull = false) + val expectedSchema1 = + StructType( + StructField("m", mapType1, nullable = true) :: + StructField("a", arrayType1, nullable = true) :: Nil) + assert(df.schema === expectedSchema1) + + withTable("alwaysNullable") { + df.write.format("parquet").saveAsTable("alwaysNullable") + + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = true) :: Nil) + + assert(table("alwaysNullable").schema === expectedSchema2) + + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + } + } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + val tempDir = Utils.createTempDir() + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[Throwable](df2.write.parquet(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.write.parquet(filePath2) + val df4 = read.parquet(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } +} + +/** + * A collection of tests for parquet data with various forms of partitioning. + */ +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext + + var partitionedTableDir: File = null + var normalTableDir: File = null + var partitionedTableDirWithKey: File = null + var partitionedTableDirWithComplexTypes: File = null + var partitionedTableDirWithKeyAndComplexTypes: File = null + + override def beforeAll(): Unit = { + partitionedTableDir = Utils.createTempDir() + normalTableDir = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .toDF() + .write.parquet(partDir.getCanonicalPath) + } + + sparkContext + .makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-1")) + .toDF() + .write.parquet(new File(normalTableDir, "normal").getCanonicalPath) + + partitionedTableDirWithKey = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKey, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetDataWithKey(p, i, s"part-$p")) + .toDF() + .write.parquet(partDir.getCanonicalPath) + } + + partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithKeyAndComplexTypes( + p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().write.parquet(partDir.getCanonicalPath) + } + + partitionedTableDirWithComplexTypes = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().write.parquet(partDir.getCanonicalPath) + } + } + + override protected def afterAll(): Unit = { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } + + /** + * Drop named tables if they exist + * @param tableNames tables to drop + */ + def dropTables(tableNames: String*): Unit = { + tableNames.foreach { name => + sql(s"DROP TABLE IF EXISTS $name") + } + } + + Seq( + "partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes").foreach { table => + + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) + } + + test(s"pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + Row(10)) + } + + test(s"non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } + + test(s"sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } + + Seq( + "partitioned_parquet_with_key_and_complextypes", + "partitioned_parquet_with_complextypes").foreach { table => + + test(s"SPARK-5775 read struct from $table") { + checkAnswer( + sql( + s""" + |SELECT p, structField.intStructField, structField.stringStructField + |FROM $table WHERE p = 1 + """.stripMargin), + (1 to 10).map(i => Row(1, i, f"${i}_string"))) + } + + // Re-enable this after SPARK-5508 is fixed + ignore(s"SPARK-5775 read array from $table") { + checkAnswer( + sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), + (1 to 10).map(i => Row(1 to i, 1))) + } + } + + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + Row(10)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala deleted file mode 100644 index 79fd99d9f89f..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ /dev/null @@ -1,271 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import org.apache.spark.sql.catalyst.expressions.Row -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive._ - -// The data where the partitioning key exists only in the directory structure. -case class ParquetData(intField: Int, stringField: String) -// The data that also includes the partitioning key -case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) - - -/** - * A suite to test the automatic conversion of metastore tables with parquet data to use the - * built in parquet support. - */ -class ParquetMetastoreSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { - super.beforeAll() - - sql(s""" - create external table partitioned_parquet - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDir.getCanonicalPath}' - """) - - sql(s""" - create external table partitioned_parquet_with_key - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDirWithKey.getCanonicalPath}' - """) - - sql(s""" - create external table normal_parquet - ( - intField INT, - stringField STRING - ) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' - """) - - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") - } - - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") - } - - setConf("spark.sql.hive.convertMetastoreParquet", "true") - } - - override def afterAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "false") - } - - test("conversion is working") { - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: HiveTableScan => true - }.isEmpty) - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true - }.nonEmpty) - } -} - -/** - * A suite of tests for the Parquet support through the data sources API. - */ -class ParquetSourceSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { - super.beforeAll() - - sql( s""" - create temporary table partitioned_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDir.getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table partitioned_parquet_with_key - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDirWithKey.getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table normal_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' - ) - """) - } -} - -/** - * A collection of tests for parquet data with various forms of partitioning. - */ -abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { - var partitionedTableDir: File = null - var partitionedTableDirWithKey: File = null - - override def beforeAll(): Unit = { - partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithKey.delete() - partitionedTableDirWithKey.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithKey, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetDataWithKey(p, i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - } - - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) - } - - test(s"pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - Row(10)) - } - - test(s"non-existant partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } - } - - test("non-part select(*)") { - checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), - Row(10)) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala new file mode 100644 index 000000000000..b4640b161628 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils + + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext + + // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + // Here we coalesce partition number to 1 to ensure that only a single task is issued. This + // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. + val df = sqlContext.range(0, 10).coalesce(1) + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..8ca3a1708519 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.math.BigDecimal + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } + + test("SPARK-10196: save decimal type to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("decimal", DecimalType(7, 2)) + + val data = + Row(new BigDecimal("10.02")) :: + Row(new BigDecimal("20000.99")) :: + Row(new BigDecimal("10000")) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..cb4cedddbfdd --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + + +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "parquet" + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write.parquet(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[AnalysisException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. + range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..e8975e5f5cd0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala new file mode 100644 index 000000000000..e8141923a9b5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.text.NumberFormat +import java.util.UUID + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Row, SQLContext} + +/** + * A simple example [[HadoopFsRelationProvider]]. + */ +class SimpleTextSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new SimpleTextRelation(paths, schema, partitionColumns, parameters)(sqlContext) + } +} + +class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] { + val numberFormat = NumberFormat.getInstance() + + numberFormat.setMinimumIntegerDigits(5) + numberFormat.setGroupingUsed(false) + + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + val name = FileOutputFormat.getOutputName(context) + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") + } +} + +class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { + private val recordWriter: RecordWriter[NullWritable, Text] = + new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) + + override def write(row: Row): Unit = { + val serialized = row.toSeq.map(_.toString).mkString(",") + recordWriter.write(null, new Text(serialized)) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} + +/** + * A simple example [[HadoopFsRelation]], used for testing purposes. Data are stored as comma + * separated string lines. When scanning data, schema must be explicitly provided via data source + * option `"dataSchema"`. + */ +class SimpleTextRelation( + override val paths: Array[String], + val maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient val sqlContext: SQLContext) + extends HadoopFsRelation { + + import sqlContext.sparkContext + + override val dataSchema: StructType = + maybeDataSchema.getOrElse(DataType.fromJson(parameters("dataSchema")).asInstanceOf[StructType]) + + override def equals(other: Any): Boolean = other match { + case that: SimpleTextRelation => + this.paths.sameElements(that.paths) && + this.maybeDataSchema == that.maybeDataSchema && + this.dataSchema == that.dataSchema && + this.partitionColumns == that.partitionColumns + + case _ => false + } + + override def hashCode(): Int = + Objects.hashCode(paths, maybeDataSchema, dataSchema, partitionColumns) + + override def buildScan(inputStatuses: Array[FileStatus]): RDD[Row] = { + val fields = dataSchema.map(_.dataType) + + sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => + Row(record.split(",").zip(fields).map { case (value, dataType) => + // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) + val catalystValue = Cast(Literal(value), dataType).eval() + // Here we're converting Catalyst values to Scala values to test `needsConversion` + CatalystTypeConverters.convertToScala(catalystValue, dataType) + }: _*) + } + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) + } + } +} + +/** + * A simple example [[HadoopFsRelationProvider]]. + */ +class CommitFailureTestSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) + } +} + +class CommitFailureTestRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient sqlContext: SQLContext) + extends SimpleTextRelation( + paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) { + override def close(): Unit = { + super.close() + sys.error("Intentional task commitment failure for testing purpose.") + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala new file mode 100644 index 000000000000..7966b43596e7 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -0,0 +1,629 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.ParquetOutputCommitter + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + + +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext + import sqlContext.implicits._ + + val dataSourceName: String + + val dataSchema = + StructType( + Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false))) + + lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") + + lazy val partitionedTestDF1 = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") + + lazy val partitionedTestDF2 = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") + + lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + + def checkQueries(df: DataFrame): Unit = { + // Selects everything + checkAnswer( + df, + for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) + + // Simple filtering and partition pruning + checkAnswer( + df.filter('a > 1 && 'p1 === 2), + for (i <- 2 to 3; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", 2, p2)) + + // Simple projection and filtering + checkAnswer( + df.filter('a > 1).select('b, 'a + 1), + for (i <- 2 to 3; _ <- 1 to 2; _ <- Seq("foo", "bar")) yield Row(s"val_$i", i + 1)) + + // Simple projection and partition pruning + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + + // Project many copies of columns with different types (reproduction for SPARK-7858) + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) + yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) + + // Self-join + df.registerTempTable("t") + withTempTable("t") { + checkAnswer( + sql( + """SELECT l.a, r.b, l.p1, r.p2 + |FROM t l JOIN t r + |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 + """.stripMargin), + for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) + } + } + + test("save()/load() - non-partitioned table - Overwrite") { + withTempPath { file => + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + + checkAnswer( + sqlContext.read.format(dataSourceName) + .option("path", file.getCanonicalPath) + .option("dataSchema", dataSchema.json) + .load(), + testDF.collect()) + } + } + + test("save()/load() - non-partitioned table - Append") { + withTempPath { file => + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) + + checkAnswer( + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath).orderBy("a"), + testDF.unionAll(testDF).orderBy("a").collect()) + } + } + + test("save()/load() - non-partitioned table - ErrorIfExists") { + withTempDir { file => + intercept[AnalysisException] { + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) + } + } + } + + test("save()/load() - non-partitioned table - Ignore") { + withTempDir { file => + testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) + + val path = new Path(file.getCanonicalPath) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.listStatus(path).isEmpty) + } + } + + test("save()/load() - partitioned table - simple queries") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkQueries( + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath)) + } + } + + test("save()/load() - partitioned table - Overwrite") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkAnswer( + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), + partitionedTestDF.collect()) + } + } + + test("save()/load() - partitioned table - Append") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkAnswer( + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), + partitionedTestDF.unionAll(partitionedTestDF).collect()) + } + } + + test("save()/load() - partitioned table - Append - new partition values") { + withTempPath { file => + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkAnswer( + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), + partitionedTestDF.collect()) + } + } + + test("save()/load() - partitioned table - ErrorIfExists") { + withTempDir { file => + intercept[AnalysisException] { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + } + } + } + + test("save()/load() - partitioned table - Ignore") { + withTempDir { file => + partitionedTestDF.write + .format(dataSourceName).mode(SaveMode.Ignore).save(file.getCanonicalPath) + + val path = new Path(file.getCanonicalPath) + val fs = path.getFileSystem(SparkHadoopUtil.get.conf) + assert(fs.listStatus(path).isEmpty) + } + } + + test("saveAsTable()/load() - non-partitioned table - Overwrite") { + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), testDF.collect()) + } + } + + test("saveAsTable()/load() - non-partitioned table - Append") { + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite).saveAsTable("t") + testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + } + } + + test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + intercept[AnalysisException] { + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") + } + } + } + + test("saveAsTable()/load() - non-partitioned table - Ignore") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") + assert(sqlContext.table("t").collect().isEmpty) + } + } + + test("saveAsTable()/load() - partitioned table - simple queries") { + partitionedTestDF.write.format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") + + withTable("t") { + checkQueries(sqlContext.table("t")) + } + } + + test("saveAsTable()/load() - partitioned table - Overwrite") { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append") { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append - new partition values") { + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + // Using only a subset of all partition columns + intercept[Throwable] { + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1") + .saveAsTable("t") + } + } + + test("saveAsTable()/load() - partitioned table - ErrorIfExists") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + intercept[AnalysisException] { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + } + } + } + + test("saveAsTable()/load() - partitioned table - Ignore") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Ignore) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + assert(sqlContext.table("t").collect().isEmpty) + } + } + + test("Hadoop style globbing") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + val df = sqlContext.read + .format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(s"${file.getCanonicalPath}/p1=*/p2=???") + + val expectedPaths = Set( + s"${file.getCanonicalFile}/p1=1/p2=foo", + s"${file.getCanonicalFile}/p1=2/p2=foo", + s"${file.getCanonicalFile}/p1=1/p2=bar", + s"${file.getCanonicalFile}/p1=2/p2=bar" + ).map { p => + val path = new Path(p) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString + } + + val actualPaths = df.queryExecution.analyzed.collectFirst { + case LogicalRelation(relation: HadoopFsRelation) => + relation.paths.toSet + }.getOrElse { + fail("Expect an FSBasedRelation, but none could be found") + } + + assert(actualPaths === expectedPaths) + checkAnswer(df, partitionedTestDF.collect()) + } + } + + // HadoopFsRelation.discoverPartitions() called by refresh(), which will ignore + // the given partition data type. + ignore("Partition column type casting") { + withTempPath { file => + val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2) + + input + .write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("ps", "p2") + .saveAsTable("t") + + withTempTable("t") { + checkAnswer(sqlContext.table("t"), input.collect()) + } + } + } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) + } + } + + // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores + // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or + // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this + // requirement. We probably want to move this test case to spark-integration-tests or spark-perf + // later. + test("SPARK-8406: Avoids name collision while writing files") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext + .range(10000) + .repartition(250) + .write + .mode(SaveMode.Overwrite) + .format(dataSourceName) + .save(path) + + assertResult(10000) { + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) + .load(path) + .count() + } + } + } + + test("SPARK-8578 specified custom output committer will not be used to append data") { + val clonedConf = new Configuration(configuration) + try { + val df = sqlContext.range(1, 10).toDF("i") + withTempPath { dir => + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there data already exists, + // this append should succeed because we will use the output committer associated + // with file format and AlwaysFailOutputCommitter will not be used. + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext.read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df.unionAll(df)) + + // This will fail because AlwaysFailOutputCommitter is used when we do append. + intercept[Exception] { + df.write.mode("overwrite").format(dataSourceName).save(dir.getCanonicalPath) + } + } + withTempPath { dir => + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there is no existing data, + // this append will fail because AlwaysFailOutputCommitter is used when we do append + // and there is no existing data. + intercept[Exception] { + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + } + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { + val df = Seq( + (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), + (2, "v2", Array(4, 5, 6), Map("k2" -> "v2"), Tuple2(2, "5")), + (3, "v3", Array(7, 8, 9), Map("k3" -> "v3"), Tuple2(3, "6"))).toDF("a", "b", "c", "d", "e") + withTempDir { file => + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").save(file.getCanonicalPath) + } + } + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") + } + } + + test("SPARK-9899 Disable customized output committer when speculation is on") { + val clonedConf = new Configuration(configuration) + val speculationEnabled = + sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + + try { + withTempPath { dir => + // Enables task speculation + sqlContext.sparkContext.conf.set("spark.speculation", "true") + + // Uses a customized output committer which always fails + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + + // Code below shouldn't throw since customized output committer should be disabled. + val df = sqlContext.range(10).coalesce(1) + df.write.format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) + } + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailParquetOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } +} diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala deleted file mode 100644 index 254919e8f6fd..000000000000 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.{ArrayList => JArrayList, Properties} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions - -import org.apache.hadoop.{io => hadoopIo} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.stats.StatsSetupConst -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat - -import org.apache.spark.sql.types.{Decimal, DecimalType} - -case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { - // for Serialization - def this() = this(null) - - import org.apache.spark.util.Utils._ - def createFunction[UDFType <: AnyRef](): UDFType = { - getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - } -} - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[hive] object HiveShim { - val version = "0.12.0" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) - } - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, - getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, - getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, - getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, - getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, - getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, - getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, - getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, - getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, - getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, - getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, - getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DECIMAL, - getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.VOID, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[String] - - def processResults(results: JArrayList[String]) = results - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd(0), conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - new HiveDecimal(bd) - } - - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - ColumnProjectionUtils.appendReadColumnIDs(conf, ids) - ColumnProjectionUtils.appendReadColumnNames(conf, names) - } - - def getExternalTmpPath(context: Context, uri: URI) = { - context.getExternalTmpFileURI(uri) - } - - def getDataLocationPath(p: Partition) = p.getPartitionPath - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) - - def compatibilityBlackList = Seq( - "decimal_.*", - "udf7", - "drop_partitions_filter2", - "show_.*", - "serde_regex", - "udf_to_date", - "udaf_collect_set", - "udf_concat" - ) - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) - } - - def decimalMetastoreString(decimalType: DecimalType): String = "decimal" - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = - TypeInfoFactory.decimalTypeInfo - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - DecimalType.Unlimited - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - } - } - - def prepareWritable(w: Writable): Writable = { - w - } -} - -class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) - extends FileSinkDesc(dir, tableInfo, compressed) { -} diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala deleted file mode 100644 index 45ca59ae56a3..000000000000 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ /dev/null @@ -1,446 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.util.{ArrayList => JArrayList} -import java.util.Properties -import java.rmi.server.UID - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable -import org.apache.hadoop.{io => hadoopIo} - -import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType} - - -/** - * This class provides the UDF creation and also the UDF instance serialization and - * de-serialization cross process boundary. - * - * Detail discussion can be found at https://github.com/apache/spark/pull/3640 - * - * @param functionClassName UDF class name - */ -case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable { - // for Serialization - def this() = this(null) - - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo - import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF - - @transient - private val methodDeSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "deserializeObjectByKryo", - classOf[Kryo], - classOf[InputStream], - classOf[Class[_]]) - method.setAccessible(true) - - method - } - - @transient - private val methodSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "serializeObjectByKryo", - classOf[Kryo], - classOf[Object], - classOf[OutputStream]) - method.setAccessible(true) - - method - } - - def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz) - .asInstanceOf[UDFType] - } - - def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out) - } - - private var instance: AnyRef = null - - def writeExternal(out: java.io.ObjectOutput) { - // output the function name - out.writeUTF(functionClassName) - - // Write a flag if instance is null or not - out.writeBoolean(instance != null) - if (instance != null) { - // Some of the UDF are serializable, but some others are not - // Hive Utilities can handle both cases - val baos = new java.io.ByteArrayOutputStream() - serializePlan(instance, baos) - val functionInBytes = baos.toByteArray - - // output the function bytes - out.writeInt(functionInBytes.length) - out.write(functionInBytes, 0, functionInBytes.length) - } - } - - def readExternal(in: java.io.ObjectInput) { - // read the function name - functionClassName = in.readUTF() - - if (in.readBoolean()) { - // if the instance is not null - // read the function in bytes - val functionInBytesLength = in.readInt() - val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) - - // deserialize the function object via Hive Utilities - instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), - getContextOrSparkClassLoader.loadClass(functionClassName)) - } - } - - def createFunction[UDFType <: AnyRef](): UDFType = { - if (instance != null) { - instance.asInstanceOf[UDFType] - } else { - val func = getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - if (!func.isInstanceOf[UDF]) { - // We cache the function if it's no the Simple UDF, - // as we always have to create new instance for Simple UDF - instance = func - } - func - } - } -} - -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[hive] object HiveShim { - val version = "0.13.1" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(inputFormatClass, outputFormatClass, properties) - } - - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.voidTypeInfo, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - } - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - } - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - } - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - } - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[Object] - - def processResults(results: JArrayList[Object]) = { - results.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { - context.runSqlHive("CREATE DATABASE default") - context.runSqlHive("USE default") - } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd, conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - HiveDecimal.create(bd) - } - - /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug - */ - private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { - val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") - val result: StringBuilder = new StringBuilder(old) - var first: Boolean = old.isEmpty - - for (col <- cols) { - if (first) { - first = false - } else { - result.append(',') - } - result.append(col) - } - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) - } - - /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty - */ - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { - ColumnProjectionUtils.appendReadColumns(conf, ids) - } - if (names != null && names.size > 0) { - appendReadColumnNames(conf, names) - } - } - - def getExternalTmpPath(context: Context, path: Path) = { - context.getExternalTmpPath(path.toUri) - } - - def getDataLocationPath(p: Partition) = p.getDataLocation - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl) - - def compatibilityBlackList = Seq() - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation())) - } - - /* - * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ - implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) - f.setCompressCodec(w.compressCodec) - f.setCompressType(w.compressType) - f.setTableInfo(w.tableInfo) - f.setDestTableId(w.destTableId) - f - } - - // Precision and scale to pass for unlimited decimals; these are the same as the precision and - // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) - private val UNLIMITED_DECIMAL_PRECISION = 38 - private val UNLIMITED_DECIMAL_SCALE = 18 - - def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { - case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" - case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" - } - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { - case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) - } - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] - DecimalType(info.precision(), info.scale()) - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, - hdoi.precision(), hdoi.scale()) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) - } - } - - /* - * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that - * is needed to initialize before serialization. - */ - def prepareWritable(w: Writable): Writable = { - w match { - case w: AvroGenericRecordWritable => - w.setRecordReaderID(new UID()) - case _ => - } - w - } -} - -/* - * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. - * Fix it through wrapper. - */ -class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) - extends Serializable with Logging { - var compressCodec: String = _ - var compressType: String = _ - var destTableId: Int = _ - - def setCompressed(compressed: Boolean) { - this.compressed = compressed - } - - def getDirName = dir - - def setDestTableId(destTableId: Int) { - this.destTableId = destTableId - } - - def setTableInfo(tableInfo: TableDesc) { - this.tableInfo = tableInfo - } - - def setCompressCodec(intermediateCompressorCodec: String) { - compressCodec = intermediateCompressorCodec - } - - def setCompressType(intermediateCompressType: String) { - compressType = intermediateCompressType - } -} diff --git a/streaming/pom.xml b/streaming/pom.xml index 1e92ba686a57..697895e72fe5 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + @@ -82,6 +89,11 @@ junit test + + org.seleniumhq.selenium + selenium-java + test + com.novocode junit-interface @@ -92,34 +104,6 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - test-jar-on-test-compile - test-compile - - test-jar - - - - - org.apache.maven.plugins maven-shade-plugin @@ -128,13 +112,5 @@ - - - ../python - - pyspark/streaming/*.py - - - diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java new file mode 100644 index 000000000000..3738fc1a235c --- /dev/null +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +/** + * :: DeveloperApi :: + * + * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming + * to save the received data (by receivers) and associated metadata to a reliable storage, so that + * they can be recovered after driver failures. See the Spark documentation for more information + * on how to plug in your own custom implementation of a write ahead log. + */ +@org.apache.spark.annotation.DeveloperApi +public abstract class WriteAheadLog { + /** + * Write the record to the log and return a record handle, which contains all the information + * necessary to read back the written record. The time is used to the index the record, + * such that it can be cleaned later. Note that implementations of this abstract class must + * ensure that the written data is durable and readable (using the record handle) by the + * time this function returns. + */ + abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + + /** + * Read a written record based on the given record handle. + */ + abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + + /** + * Read and return an iterator of all the records that have been written but not yet cleaned up. + */ + abstract public Iterator readAll(); + + /** + * Clean all the records that are older than the threshold time. It can wait for + * the completion of the deletion. + */ + abstract public void clean(long threshTime, boolean waitForCompletion); + + /** + * Close this log and release any resources. + */ + abstract public void close(); +} diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java new file mode 100644 index 000000000000..662889e779fb --- /dev/null +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util; + +/** + * :: DeveloperApi :: + * + * This abstract class represents a handle that refers to a record written in a + * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. + * It must contain all the information necessary for the record to be read and returned by + * an implemenation of the WriteAheadLog class. + * + * @see org.apache.spark.streaming.util.WriteAheadLog + */ +@org.apache.spark.annotation.DeveloperApi +public abstract class WriteAheadLogRecordHandle implements java.io.Serializable { +} diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css new file mode 100644 index 000000000000..ec12616b58d8 --- /dev/null +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +.graph { + font: 10px sans-serif; +} + +.axis path, .axis line { + fill: none; + stroke: gray; + shape-rendering: crispEdges; +} + +.axis text { + fill: gray; +} + +.tooltip-inner { + max-width: 500px !important; /* Make sure we only have one line tooltip */ +} + +.line { + fill: none; + stroke: #0088cc; + stroke-width: 1.5px; +} + +.bar rect { + fill: #0088cc; + shape-rendering: crispEdges; +} + +.bar rect:hover { + fill: #00c2ff; +} + +.timeline { + width: 500px; +} + +.histogram { + width: auto; +} + +span.expand-input-rate { + cursor: pointer; +} + +tr.batch-table-cell-highlight > td { + background-color: #D6FFE4 !important; +} diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js new file mode 100644 index 000000000000..4886b68eeaf7 --- /dev/null +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +// timeFormat: StreamingPage.scala will generate a global "timeFormat" dictionary to store the time +// and its formatted string. Because we cannot specify a timezone in JavaScript, to make sure the +// server and client use the same timezone, we use the "timeFormat" dictionary to format all time +// values used in the graphs. + +// A global margin left for all timeline graphs. It will be set in "registerTimeline". This will be +// used to align all timeline graphs. +var maxMarginLeftForTimeline = 0; + +// The max X values for all histograms. It will be set in "registerHistogram". +var maxXForHistogram = 0; + +var histogramBinCount = 10; +var yValueFormat = d3.format(",.2f"); + +var unitLabelYOffset = -10; + +// Show a tooltip "text" for "node" +function showBootstrapTooltip(node, text) { + $(node).tooltip({title: text, trigger: "manual", container: "body"}); + $(node).tooltip("show"); +} + +// Hide the tooltip for "node" +function hideBootstrapTooltip(node) { + $(node).tooltip("destroy"); +} + +// Register a timeline graph. All timeline graphs should be register before calling any +// "drawTimeline" so that we can determine the max margin left for all timeline graphs. +function registerTimeline(minY, maxY) { + var numOfChars = yValueFormat(maxY).length; + // A least width for "maxY" in the graph + var pxForMaxY = numOfChars * 8 + 10; + // Make sure we have enough space to show the ticks in the y axis of timeline + maxMarginLeftForTimeline = pxForMaxY > maxMarginLeftForTimeline? pxForMaxY : maxMarginLeftForTimeline; +} + +// Register a histogram graph. All histogram graphs should be register before calling any +// "drawHistogram" so that we can determine the max X value for histograms. +function registerHistogram(values, minY, maxY) { + var data = d3.layout.histogram().range([minY, maxY]).bins(histogramBinCount)(values); + // d.x is the y values while d.y is the x values + var maxX = d3.max(data, function(d) { return d.y; }); + maxXForHistogram = maxX > maxXForHistogram ? maxX : maxXForHistogram; +} + +// Draw a line between (x1, y1) and (x2, y2) +function drawLine(svg, xFunc, yFunc, x1, y1, x2, y2) { + var line = d3.svg.line() + .x(function(d) { return xFunc(d.x); }) + .y(function(d) { return yFunc(d.y); }); + var data = [{x: x1, y: y1}, {x: x2, y: y2}]; + svg.append("path") + .datum(data) + .style("stroke-dasharray", ("6, 6")) + .style("stroke", "lightblue") + .attr("class", "line") + .attr("d", line); +} + +/** + * @param id the `id` used in the html `div` tag + * @param data the data for the timeline graph + * @param minX the min value of X axis + * @param maxX the max value of X axis + * @param minY the min value of Y axis + * @param maxY the max value of Y axis + * @param unitY the unit of Y axis + * @param batchInterval if "batchInterval" is specified, we will draw a line for "batchInterval" in the graph + */ +function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { + // Hide the right border of " + + + + } + + protected def baseRow(batch: BatchUIData): Seq[Node] = { + val batchTime = batch.batchTime.milliseconds + val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) + val eventCount = batch.numRecords + val schedulingDelay = batch.schedulingDelay + val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val processingTime = batch.processingDelay + val formattedProcessingTime = processingTime.map(SparkUIUtils.formatDuration).getOrElse("-") + val batchTimeId = s"batch-$batchTime" + + + + + + } + + private def batchTable: Seq[Node] = { +
    {info.userName} + {jobLink} + {info.groupId}{formatDate(info.startTimestamp)}{if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)}{formatDurationOption(Some(info.totalTime))}{info.statement}{info.state}
    {errorSummary}{details}
    {session.userName} {session.ip} {session.sessionId} {formatDate(session.startTimestamp)} {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} {formatDurationOption(Some(session.totalTime))} {session.totalExecution.toString}
    {d}
    {info.userName} + {jobLink} + {info.groupId}{formatDate(info.startTimestamp)}{formatDate(info.finishTimestamp)}{formatDurationOption(Some(info.totalTime))}{info.statement}{info.state}
    {errorSummary}{details}
    {d}
    ". We cannot use "css" directly, or "sorttable.js" will override them. + d3.select(d3.select(id).node().parentNode) + .style("padding", "8px 0 8px 8px") + .style("border-right", "0px solid white"); + + var margin = {top: 20, right: 27, bottom: 30, left: maxMarginLeftForTimeline}; + var width = 500 - margin.left - margin.right; + var height = 150 - margin.top - margin.bottom; + + var x = d3.scale.linear().domain([minX, maxX]).range([0, width]); + var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); + + var xAxis = d3.svg.axis().scale(x).orient("bottom").tickFormat(function(d) { + var formattedDate = timeFormat[d]; + var dotIndex = formattedDate.indexOf('.'); + if (dotIndex >= 0) { + // Remove milliseconds + return formattedDate.substring(0, dotIndex); + } else { + return formattedDate; + } + }); + var formatYValue = d3.format(",.2f"); + var yAxis = d3.svg.axis().scale(y).orient("left").ticks(5).tickFormat(formatYValue); + + var line = d3.svg.line() + .x(function(d) { return x(d.x); }) + .y(function(d) { return y(d.y); }); + + var svg = d3.select(id).append("svg") + .attr("width", width + margin.left + margin.right) + .attr("height", height + margin.top + margin.bottom) + .append("g") + .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); + + // Only show the first and last time in the graph + xAxis.tickValues(x.domain()); + + svg.append("g") + .attr("class", "x axis") + .attr("transform", "translate(0," + height + ")") + .call(xAxis) + + svg.append("g") + .attr("class", "y axis") + .call(yAxis) + .append("text") + .attr("transform", "translate(0," + unitLabelYOffset + ")") + .text(unitY); + + + if (batchInterval && batchInterval <= maxY) { + drawLine(svg, x, y, minX, batchInterval, maxX, batchInterval); + } + + svg.append("path") + .datum(data) + .attr("class", "line") + .attr("d", line); + + // If the user click one point in the graphs, jump to the batch row and highlight it. And + // recovery the batch row after 3 seconds if necessary. + // We need to remember the last clicked batch so that we can recovery it. + var lastClickedBatch = null; + var lastTimeout = null; + + // Add points to the line. However, we make it invisible at first. But when the user moves mouse + // over a point, it will be displayed with its detail. + svg.selectAll(".point") + .data(data) + .enter().append("circle") + .attr("stroke", "white") // white and opacity = 0 make it invisible + .attr("fill", "white") + .attr("opacity", "0") + .style("cursor", "pointer") + .attr("cx", function(d) { return x(d.x); }) + .attr("cy", function(d) { return y(d.y); }) + .attr("r", function(d) { return 3; }) + .on('mouseover', function(d) { + var tip = formatYValue(d.y) + " " + unitY + " at " + timeFormat[d.x]; + showBootstrapTooltip(d3.select(this).node(), tip); + // show the point + d3.select(this) + .attr("stroke", "steelblue") + .attr("fill", "steelblue") + .attr("opacity", "1"); + }) + .on('mouseout', function() { + hideBootstrapTooltip(d3.select(this).node()); + // hide the point + d3.select(this) + .attr("stroke", "white") + .attr("fill", "white") + .attr("opacity", "0"); + }) + .on("click", function(d) { + if (lastTimeout != null) { + window.clearTimeout(lastTimeout); + } + if (lastClickedBatch != null) { + clearBatchRow(lastClickedBatch); + lastClickedBatch = null; + } + lastClickedBatch = d.x; + highlightBatchRow(lastClickedBatch) + lastTimeout = window.setTimeout(function () { + lastTimeout = null; + if (lastClickedBatch != null) { + clearBatchRow(lastClickedBatch); + lastClickedBatch = null; + } + }, 3000); // Clean up after 3 seconds + + var batchSelector = $("#batch-" + d.x); + var topOffset = batchSelector.offset().top - 15; + if (topOffset < 0) { + topOffset = 0; + } + $('html,body').animate({scrollTop: topOffset}, 200); + }); +} + +/** + * @param id the `id` used in the html `div` tag + * @param values the data for the histogram graph + * @param minY the min value of Y axis + * @param maxY the max value of Y axis + * @param unitY the unit of Y axis + * @param batchInterval if "batchInterval" is specified, we will draw a line for "batchInterval" in the graph + */ +function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { + // Hide the left border of "". We cannot use "css" directly, or "sorttable.js" will override them. + d3.select(d3.select(id).node().parentNode) + .style("padding", "8px 8px 8px 0") + .style("border-left", "0px solid white"); + + var margin = {top: 20, right: 30, bottom: 30, left: 10}; + var width = 350 - margin.left - margin.right; + var height = 150 - margin.top - margin.bottom; + + var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width - 50]); + var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); + + var xAxis = d3.svg.axis().scale(x).orient("top").ticks(5); + var yAxis = d3.svg.axis().scale(y).orient("left").ticks(0).tickFormat(function(d) { return ""; }); + + var data = d3.layout.histogram().range([minY, maxY]).bins(histogramBinCount)(values); + + var svg = d3.select(id).append("svg") + .attr("width", width + margin.left + margin.right) + .attr("height", height + margin.top + margin.bottom) + .append("g") + .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); + + if (batchInterval && batchInterval <= maxY) { + drawLine(svg, x, y, 0, batchInterval, maxXForHistogram, batchInterval); + } + + svg.append("g") + .attr("class", "x axis") + .call(xAxis) + .append("text") + .attr("transform", "translate(" + (margin.left + width - 45) + ", " + unitLabelYOffset + ")") + .text("#batches"); + + svg.append("g") + .attr("class", "y axis") + .call(yAxis) + + var bar = svg.selectAll(".bar") + .data(data) + .enter() + .append("g") + .attr("transform", function(d) { return "translate(0," + (y(d.x) - height + y(d.dx)) + ")";}) + .attr("class", "bar").append("rect") + .attr("width", function(d) { return x(d.y); }) + .attr("height", function(d) { return height - y(d.dx); }) + .on('mouseover', function(d) { + var percent = yValueFormat(d.y * 100.0 / values.length) + "%"; + var tip = d.y + " batches (" + percent + ") between " + yValueFormat(d.x) + " and " + yValueFormat(d.x + d.dx) + " " + unitY; + showBootstrapTooltip(d3.select(this).node(), tip); + }) + .on('mouseout', function() { + hideBootstrapTooltip(d3.select(this).node()); + }); + + if (batchInterval && batchInterval <= maxY) { + // Add the "stable" text to the graph below the batch interval line. + var stableXOffset = x(maxXForHistogram) - 20; + var stableYOffset = y(batchInterval) + 15; + svg.append("text") + .style("fill", "lightblue") + .attr("class", "stable-text") + .attr("text-anchor", "middle") + .attr("transform", "translate(" + stableXOffset + "," + stableYOffset + ")") + .text("stable") + .on('mouseover', function(d) { + var tip = "Processing Time <= Batch Interval (" + yValueFormat(batchInterval) +" " + unitY +")"; + showBootstrapTooltip(d3.select(this).node(), tip); + }) + .on('mouseout', function() { + hideBootstrapTooltip(d3.select(this).node()); + }); + } +} + +$(function() { + var status = window.localStorage && window.localStorage.getItem("show-streams-detail") == "true"; + + $("span.expand-input-rate").click(function() { + status = !status; + $("#inputs-table").toggle('collapsed'); + // Toggle the class of the arrow between open and closed + $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); + if (window.localStorage) { + window.localStorage.setItem("show-streams-detail", "" + status); + } + }); + + if (status) { + $("#inputs-table").toggle('collapsed'); + // Toggle the class of the arrow between open and closed + $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); + } +}); + +function highlightBatchRow(batch) { + $("#batch-" + batch).parent().addClass("batch-table-cell-highlight"); +} + +function clearBatchRow(batch) { + $("#batch-" + batch).parent().removeClass("batch-table-cell-highlight"); +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b780282bdac3..cd5d960369c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -25,8 +25,9 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -43,10 +44,27 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConfPairs = ssc.conf.getAll - def sparkConf = { - new SparkConf(false).setAll(sparkConfPairs) + def createSparkConf(): SparkConf = { + + // Reload properties for the checkpoint application since user wants to set a reload property + // or spark had changed its value and user wants to set it back. + val propertiesToReload = List( + "spark.driver.host", + "spark.driver.port", + "spark.master", + "spark.yarn.keytab", + "spark.yarn.principal") + + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") + val newReloadConf = new SparkConf(loadDefaults = true) + propertiesToReload.foreach { prop => + newReloadConf.getOption(prop).foreach { value => + newSparkConf.set(prop, value) + } + } + newSparkConf } def validate() { @@ -64,17 +82,18 @@ object Checkpoint extends Logging { val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r /** Get the checkpoint file for the given checkpoint time */ - def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) } /** Get the checkpoint backup file for the given checkpoint time */ - def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -82,6 +101,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) + val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf)) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -97,6 +117,44 @@ object Checkpoint extends Logging { Seq.empty } } + + /** Serialize the checkpoint, or throw any exception that occurs */ + def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = { + val compressionCodec = CompressionCodec.createCodec(conf) + val bos = new ByteArrayOutputStream() + val zos = compressionCodec.compressedOutputStream(bos) + val oos = new ObjectOutputStream(zos) + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } + bos.toByteArray + } + + /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */ + def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = { + val compressionCodec = CompressionCodec.createCodec(conf) + var ois: ObjectInputStreamWithLoader = null + Utils.tryWithSafeFinally { + + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(inputStream) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + cp.validate() + cp + } { + if (ois != null) { + ois.close() + } + } + } } @@ -116,7 +174,10 @@ class CheckpointWriter( private var stopped = false private var fs_ : FileSystem = _ - class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { + class CheckpointWriteHandler( + checkpointTime: Time, + bytes: Array[Byte], + clearCheckpointDataLater: Boolean) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() @@ -131,15 +192,22 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file - fs.delete(tempFile, true) // just in case it exists + if (fs.exists(tempFile)) { + fs.delete(tempFile, true) // just in case it exists + } val fos = fs.create(tempFile) - fos.write(bytes) - fos.close() + Utils.tryWithSafeFinally { + fos.write(bytes) + } { + fos.close() + } // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - fs.delete(backupFile, true) // just in case it exists + if (fs.exists(backupFile)){ + fs.delete(backupFile, true) // just in case it exists + } if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } @@ -151,8 +219,8 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) - if (allCheckpointFiles.size > 4) { + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) + if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) fs.delete(file, true) @@ -163,7 +231,7 @@ class CheckpointWriter( val finishTime = System.currentTimeMillis() logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") - jobGenerator.onCheckpointCompletion(checkpointTime) + jobGenerator.onCheckpointCompletion(checkpointTime, clearCheckpointDataLater) return } catch { case ioe: IOException => @@ -177,15 +245,11 @@ class CheckpointWriter( } } - def write(checkpoint: Checkpoint) { - val bos = new ByteArrayOutputStream() - val zos = compressionCodec.compressedOutputStream(bos) - val oos = new ObjectOutputStream(zos) - oos.writeObject(checkpoint) - oos.close() - bos.close() + def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) { try { - executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + val bytes = Checkpoint.serialize(checkpoint, conf) + executor.execute(new CheckpointWriteHandler( + checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => @@ -222,13 +286,33 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = - { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None. + */ + def read(checkpointDir: String): Option[Checkpoint] = { + read(checkpointDir, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = true) + } + + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None (if ignoreReadError = true), + * or throw exception (if ignoreReadError = false). + */ + def read( + checkpointDir: String, + conf: SparkConf, + hadoopConf: Configuration, + ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) - def fs = checkpointPath.getFileSystem(hadoopConf) + + // TODO(rxin): Why is this a def?! + def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse if (checkpointFiles.isEmpty) { return None } @@ -240,18 +324,7 @@ object CheckpointReader extends Logging { logInfo("Attempting to load checkpoint from file " + file) try { val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() + val cp = Checkpoint.deserialize(fis, conf) logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) return Some(cp) @@ -262,7 +335,10 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) + if (!ignoreReadError) { + throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + } + None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 0e285d6088ec..40789c66f399 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -45,7 +45,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { startTime = time outputStreams.foreach(_.initialize(zeroTime)) outputStreams.foreach(_.remember(rememberDuration)) - outputStreams.foreach(_.validate) + outputStreams.foreach(_.validateAtStart) inputStreams.par.foreach(_.start()) } } @@ -100,16 +100,20 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray } - def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray } - def getReceiverInputStreams() = this.synchronized { + def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized { inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]]) .map(_.asInstanceOf[ReceiverInputDStream[_]]) .toArray } + def getInputStreamName(streamId: Int): Option[String] = synchronized { + inputStreams.find(_.id == streamId).map(_.name) + } + def generateJobs(time: Time): Seq[Job] = { logDebug("Generating jobs for time " + time) val jobs = this.synchronized { @@ -153,10 +157,10 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def validate() { this.synchronized { - assert(batchDuration != null, "Batch duration has not been set") + require(batchDuration != null, "Batch duration has not been set") // assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + // " is very low") - assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + require(getOutputStreams().size > 0, "No output operations registered, so nothing to execute") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index a0d8fb5ab93e..3249bb348981 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -55,7 +55,6 @@ case class Duration (private val millis: Long) { def div(that: Duration): Double = this / that - def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -71,7 +70,7 @@ case class Duration (private val millis: Long) { def milliseconds: Long = millis - def prettyPrint = Utils.msDurationToString(millis) + def prettyPrint: String = Utils.msDurationToString(millis) } @@ -80,7 +79,7 @@ case class Duration (private val millis: Long) { * a given number of milliseconds. */ object Milliseconds { - def apply(milliseconds: Long) = new Duration(milliseconds) + def apply(milliseconds: Long): Duration = new Duration(milliseconds) } /** @@ -88,7 +87,7 @@ object Milliseconds { * a given number of seconds. */ object Seconds { - def apply(seconds: Long) = new Duration(seconds * 1000) + def apply(seconds: Long): Duration = new Duration(seconds * 1000) } /** @@ -96,7 +95,7 @@ object Seconds { * a given number of minutes. */ object Minutes { - def apply(minutes: Long) = new Duration(minutes * 60000) + def apply(minutes: Long): Duration = new Duration(minutes * 60000) } // Java-friendlier versions of the objects above. @@ -107,16 +106,16 @@ object Durations { /** * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. */ - def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. */ - def seconds(seconds: Long) = Seconds(seconds) + def seconds(seconds: Long): Duration = Seconds(seconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. */ - def minutes(minutes: Long) = Minutes(minutes) + def minutes(minutes: Long): Duration = Minutes(minutes) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index ad4f3fdd14ad..3f5be785e1b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) { this.endTime < that.endTime } - def <= (that: Interval) = (this < that || this == that) + def <= (that: Interval): Boolean = (this < that || this == that) - def > (that: Interval) = !(this <= that) + def > (that: Interval): Boolean = !(this <= that) - def >= (that: Interval) = !(this < that) + def >= (that: Interval): Boolean = !(this < that) - override def toString = "[" + beginTime + ", " + endTime + "]" + override def toString: String = "[" + beginTime + ", " + endTime + "]" } private[streaming] object Interval { - def currentInterval(duration: Duration): Interval = { + def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) val intervalBegin = time.floor(duration) new Interval(intervalBegin, intervalBegin + duration) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 8ef078713784..b496d1f341a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -17,26 +17,34 @@ package org.apache.spark.streaming -import java.io.InputStream -import java.util.concurrent.atomic.AtomicInteger +import java.io.{InputStream, NotSerializableException} +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map import scala.collection.mutable.Queue import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + import org.apache.spark._ -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.input.FixedLengthBinaryInputFormat +import org.apache.spark.rdd.{RDD, RDDOperationScope} +import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} -import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} +import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -103,7 +111,20 @@ class StreamingContext private[streaming] ( * Recreate a StreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(path, new Configuration) + def this(path: String) = this(path, SparkHadoopUtil.get.conf) + + /** + * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. + * @param path Path to the directory that was specified as the checkpoint directory + * @param sparkContext Existing SparkContext + */ + def this(path: String, sparkContext: SparkContext) = { + this( + sparkContext, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + null) + } + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + @@ -113,10 +134,12 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { - new SparkContext(cp_.sparkConf) - } else { + if (sc_ != null) { sc_ + } else if (isCheckpointPresent) { + SparkContext.getOrCreate(cp_.createSparkConf()) + } else { + throw new SparkException("Cannot create StreamingContext without a SparkContext") } } @@ -127,7 +150,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = SparkEnv.get + private[streaming] val env = sc.env private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -135,14 +158,14 @@ class StreamingContext private[streaming] ( cp_.graph.restoreCheckpointData() cp_.graph } else { - assert(batchDur_ != null, "Batch duration for streaming context cannot be null") + require(batchDur_ != null, "Batch duration for StreamingContext cannot be null") val newGraph = new DStreamGraph() newGraph.setBatchDuration(batchDur_) newGraph } } - private val nextReceiverInputStreamId = new AtomicInteger(0) + private val nextInputStreamId = new AtomicInteger(0) private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { @@ -170,23 +193,21 @@ class StreamingContext private[streaming] ( None } - /** Register streaming source to metrics system */ + /* Initializing a streamingSource to register metrics */ private val streamingSource = new StreamingSource(this) - SparkEnv.get.metricsSystem.registerSource(streamingSource) - /** Enumeration to identify current state of the StreamingContext */ - private[streaming] object StreamingContextState extends Enumeration { - type CheckpointState = Value - val Initialized, Started, Stopped = Value - } + private var state: StreamingContextState = INITIALIZED - import StreamingContextState._ - private[streaming] var state = Initialized + private val startSite = new AtomicReference[CallSite](null) + + private var shutdownHookRef: AnyRef = _ + + conf.getOption("spark.streaming.checkpoint.directory").foreach(checkpoint) /** * Return the associated Spark context */ - def sparkContext = sc + def sparkContext: SparkContext = sc /** * Set each DStreams in this context to remember RDDs it generated in the last given duration. @@ -218,21 +239,46 @@ class StreamingContext private[streaming] ( } } + private[streaming] def isCheckpointingEnabled: Boolean = { + checkpointDir != null + } + private[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } - private[streaming] def getNewReceiverStreamId() = nextReceiverInputStreamId.getAndIncrement() + private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() + + /** + * Execute a block of code in a scope such that all new DStreams created in this body will + * be part of the same scope. For more detail, see the comments in `doCompute`. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[streaming] def withScope[U](body: => U): U = sparkContext.withScope(body) + + /** + * Execute a block of code in a scope such that all new DStreams created in this body will + * be part of the same scope. For more detail, see the comments in `doCompute`. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[streaming] def withNamedScope[U](name: String)(body: => U): U = { + RDDOperationScope.withScope(sc, name, allowNesting = false, ignoreParent = false)(body) + } /** * Create an input stream with any arbitrary user implemented receiver. * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver + * + * @deprecated As of 1.0.0", replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") - def networkStream[T: ClassTag]( - receiver: Receiver[T]): ReceiverInputDStream[T] = { - receiverStream(receiver) + def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { + withNamedScope("network stream") { + receiverStream(receiver) + } } /** @@ -240,9 +286,10 @@ class StreamingContext private[streaming] ( * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver */ - def receiverStream[T: ClassTag]( - receiver: Receiver[T]): ReceiverInputDStream[T] = { - new PluggableInputDStream[T](this, receiver) + def receiverStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { + withNamedScope("receiver stream") { + new PluggableInputDStream[T](this, receiver) + } } /** @@ -262,7 +309,7 @@ class StreamingContext private[streaming] ( name: String, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy - ): ReceiverInputDStream[T] = { + ): ReceiverInputDStream[T] = withNamedScope("actor stream") { receiverStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) } @@ -279,7 +326,7 @@ class StreamingContext private[streaming] ( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[String] = { + ): ReceiverInputDStream[String] = withNamedScope("socket text stream") { socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) } @@ -317,7 +364,7 @@ class StreamingContext private[streaming] ( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[T] = { + ): ReceiverInputDStream[T] = withNamedScope("raw socket stream") { new RawInputDStream[T](this, hostname, port, storageLevel) } @@ -359,6 +406,30 @@ class StreamingContext private[streaming] ( new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) } + /** + * Create a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @param conf Hadoop configuration + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassTag, + V: ClassTag, + F <: NewInputFormat[K, V]: ClassTag + ] (directory: String, + filter: Path => Boolean, + newFilesOnly: Boolean, + conf: Configuration): InputDStream[(K, V)] = { + new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly, Option(conf)) + } + /** * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value @@ -367,13 +438,49 @@ class StreamingContext private[streaming] ( * file system. File names starting with . are ignored. * @param directory HDFS directory to monitor for new file */ - def textFileStream(directory: String): DStream[String] = { + def textFileStream(directory: String): DStream[String] = withNamedScope("text file stream") { fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } + /** + * :: Experimental :: + * + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as flat binary files, assuming a fixed length per record, + * generating one byte array per record. Files must be written to the monitored directory + * by "moving" them from another location within the same file system. File names + * starting with . are ignored. + * + * '''Note:''' We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. + * + * @param directory HDFS directory to monitor for new file + * @param recordLength length of each record in bytes + */ + @Experimental + def binaryRecordsStream( + directory: String, + recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { + val conf = sc_.hadoopConfiguration + conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) + val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( + directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) + val data = br.map { case (k, v) => + val bytes = v.getBytes + require(bytes.length == recordLength, "Byte array does not have correct length. " + + s"${bytes.length} did not equal recordLength: $recordLength") + bytes + } + data + } + /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -388,6 +495,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. @@ -405,7 +516,7 @@ class StreamingContext private[streaming] ( /** * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ - def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = { + def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = withScope { new UnionDStream[T](streams.toArray) } @@ -416,7 +527,7 @@ class StreamingContext private[streaming] ( def transform[T: ClassTag]( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] - ): DStream[T] = { + ): DStream[T] = withScope { new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } @@ -431,29 +542,80 @@ class StreamingContext private[streaming] ( assert(graph != null, "Graph is null") graph.validate() - assert( - checkpointDir == null || checkpointDuration != null, + require( + !isCheckpointingEnabled || checkpointDuration != null, "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) + + // Verify whether the DStream checkpoint is serializable + if (isCheckpointingEnabled) { + val checkpoint = new Checkpoint(this, Time.apply(0)) + try { + Checkpoint.serialize(checkpoint, conf) + } catch { + case e: NotSerializableException => + throw new NotSerializableException( + "DStream checkpointing has been enabled but the DStreams with their functions " + + "are not serializable\n" + + SerializationDebugger.improveException(checkpoint, e).getMessage() + ) + } + } + } + + /** + * :: DeveloperApi :: + * + * Return the current state of the context. The context can be in three possible states - + * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + */ + @DeveloperApi + def getState(): StreamingContextState = synchronized { + state } /** * Start the execution of the streams. * - * @throws SparkException if the context has already been started or stopped. + * @throws IllegalStateException if the StreamingContext is already stopped. */ def start(): Unit = synchronized { - if (state == Started) { - throw new SparkException("StreamingContext has already been started") - } - if (state == Stopped) { - throw new SparkException("StreamingContext has already been stopped") + state match { + case INITIALIZED => + startSite.set(DStream.getCreationSite()) + sparkContext.setCallSite(startSite.get) + StreamingContext.ACTIVATION_LOCK.synchronized { + StreamingContext.assertNoOtherContextIsActive() + try { + validate() + scheduler.start() + state = StreamingContextState.ACTIVE + } catch { + case NonFatal(e) => + logError("Error starting the context, marking it as stopped", e) + scheduler.stop(false) + state = StreamingContextState.STOPPED + throw e + } + StreamingContext.setActiveContext(this) + } + shutdownHookRef = ShutdownHookManager.addShutdownHook( + StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + // Registering Streaming Metrics at the start of the StreamingContext + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) + uiTab.foreach(_.attach()) + logInfo("StreamingContext started") + case ACTIVE => + logWarning("StreamingContext has already been started") + case STOPPED => + throw new IllegalStateException("StreamingContext has already been stopped") } - validate() - sparkContext.setCallSite(DStream.getCreationSite()) - scheduler.start() - state = Started } /** @@ -468,20 +630,39 @@ class StreamingContext private[streaming] ( * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ + @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { waiter.waitForStopOrError(timeout) } + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * + * @param timeout time to wait in milliseconds + * @return `true` if it's stopped; or throw the reported error during the execution; or `false` + * if the waiting time elapsed before returning from the method. + */ + def awaitTerminationOrTimeout(timeout: Long): Boolean = { + waiter.waitForStopOrError(timeout) + } + /** * Stop the execution of the streams immediately (does not wait for all received data - * to be processed). + * to be processed). By default, if `stopSparkContext` is not specified, the underlying + * SparkContext will also be stopped. This implicit behavior can be configured using the + * SparkConf configuration spark.streaming.stopSparkContextByDefault. * - * @param stopSparkContext if true, stops the associated SparkContext. The underlying SparkContext + * @param stopSparkContext If true, stops the associated SparkContext. The underlying SparkContext * will be stopped regardless of whether this StreamingContext has been * started. */ - def stop(stopSparkContext: Boolean = true): Unit = synchronized { + def stop( + stopSparkContext: Boolean = conf.getBoolean("spark.streaming.stopSparkContextByDefault", true) + ): Unit = synchronized { stop(stopSparkContext, false) } @@ -496,20 +677,38 @@ class StreamingContext private[streaming] ( * received data to be completed */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - state match { - case Initialized => logWarning("StreamingContext has not been started yet") - case Stopped => logWarning("StreamingContext has already been stopped") - case Started => - scheduler.stop(stopGracefully) - logInfo("StreamingContext stopped successfully") - waiter.notifyStop() + try { + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + case STOPPED => + logWarning("StreamingContext has already been stopped") + case ACTIVE => + scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) + uiTab.foreach(_.detach()) + StreamingContext.setActiveContext(null) + waiter.notifyStop() + if (shutdownHookRef != null) { + ShutdownHookManager.removeShutdownHook(shutdownHookRef) + } + logInfo("StreamingContext stopped successfully") + } + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). + if (stopSparkContext) sc.stop() + } finally { + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state = STOPPED } - // Even if the streaming context has not been started, we still need to stop the SparkContext. - // Even if we have already stopped, we still need to attempt to stop the SparkContext because - // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). - if (stopSparkContext) sc.stop() - // The state should always be Stopped after calling `stop()`, even if we haven't started yet: - state = Stopped + } + + private def stopOnShutdown(): Unit = { + val stopGracefully = conf.getBoolean("spark.streaming.stopGracefullyOnShutdown", false) + logInfo(s"Invoking stop(stopGracefully=$stopGracefully) from shutdown hook") + // Do not stop SparkContext, let its own shutdown hook stop it + stop(stopSparkContext = false, stopGracefully = stopGracefully) } } @@ -520,15 +719,99 @@ class StreamingContext private[streaming] ( object StreamingContext extends Logging { - private[streaming] val DEFAULT_CLEANER_TTL = 3600 + /** + * Lock that guards activation of a StreamingContext as well as access to the singleton active + * StreamingContext in getActiveOrCreate(). + */ + private val ACTIVATION_LOCK = new Object() + + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + + private val activeContext = new AtomicReference[StreamingContext](null) + + private def assertNoOtherContextIsActive(): Unit = { + ACTIVATION_LOCK.synchronized { + if (activeContext.get() != null) { + throw new IllegalStateException( + "Only one StreamingContext may be started in this JVM. " + + "Currently running StreamingContext was started at" + + activeContext.get.startSite.get.longForm) + } + } + } + + private def setActiveContext(ssc: StreamingContext): Unit = { + ACTIVATION_LOCK.synchronized { + activeContext.set(ssc) + } + } + + /** + * :: Experimental :: + * + * Get the currently active context, if there is one. Active means started but not stopped. + */ + @Experimental + def getActive(): Option[StreamingContext] = { + ACTIVATION_LOCK.synchronized { + Option(activeContext.get()) + } + } + /** + * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. + * This is kept here only for backward compatibility. + */ @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) + : PairDStreamFunctions[K, V] = { DStream.toPairDStreamFunctions(stream)(kt, vt, ord) } + /** + * :: Experimental :: + * + * Either return the "active" StreamingContext (that is, started but not stopped), or create a + * new StreamingContext that is + * @param creatingFunc Function to create a new StreamingContext + */ + @Experimental + def getActiveOrCreate(creatingFunc: () => StreamingContext): StreamingContext = { + ACTIVATION_LOCK.synchronized { + getActive().getOrElse { creatingFunc() } + } + } + + /** + * :: Experimental :: + * + * Either get the currently active StreamingContext (that is, started but not stopped), + * OR recreate a StreamingContext from checkpoint data in the given path. If checkpoint data + * does not exist in the provided, then create a new StreamingContext by calling the provided + * `creatingFunc`. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext + * @param hadoopConf Optional Hadoop configuration if necessary for reading from the + * file system + * @param createOnError Optional, whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + @Experimental + def getActiveOrCreate( + checkpointPath: String, + creatingFunc: () => StreamingContext, + hadoopConf: Configuration = SparkHadoopUtil.get.conf, + createOnError: Boolean = false + ): StreamingContext = { + ACTIVATION_LOCK.synchronized { + getActive().getOrElse { getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError) } + } + } + /** * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be @@ -546,19 +829,11 @@ object StreamingContext extends Logging { def getOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = try { - CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) - } catch { - case e: Exception => - if (createOnError) { - None - } else { - throw e - } - } + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), hadoopConf, createOnError) checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } @@ -581,7 +856,7 @@ object StreamingContext extends Logging { ): SparkContext = { val conf = SparkContext.updatedConf( new SparkConf(), master, appName, sparkHome, jars, environment) - createNewSparkContext(conf) + new SparkContext(conf) } private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java new file mode 100644 index 000000000000..d7b639383ee3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * :: DeveloperApi :: + * + * Represents the state of a StreamingContext. + */ +@DeveloperApi +public enum StreamingContextState { + /** + * The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + */ + INITIALIZED, + + /** + * The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + */ + ACTIVE, + + /** + * The context has been stopped and cannot be used any more. + */ + STOPPED +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala index 42c49678d24f..92cfd7d40338 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala @@ -63,6 +63,11 @@ case class Time(private val millis: Long) { new Time((this.millis / t) * t) } + def floor(that: Duration, zeroTime: Time): Time = { + val t = that.milliseconds + new Time(((this.millis - zeroTime.milliseconds) / t) * t + zeroTime.milliseconds) + } + def isMultipleOf(that: Duration): Boolean = (this.millis % that.milliseconds == 0) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 505e4431e435..01cdcb057404 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -36,7 +36,7 @@ import org.apache.spark.streaming.dstream.DStream * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. */ class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T]) - extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { + extends AbstractJavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index c382a12f4d09..edfa474677f1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,11 +17,10 @@ package org.apache.spark.streaming.api.java -import java.util import java.lang.{Long => JLong} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -34,6 +33,15 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ import org.apache.spark.streaming.dstream.DStream +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaDStreamLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[streaming] +abstract class AbstractJavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], + R <: JavaRDDLike[T, R]] extends JavaDStreamLike[T, This, R] + trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] extends Serializable { implicit val classTag: ClassTag[T] @@ -136,8 +144,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * an array. */ def glom(): JavaDStream[JList[T]] = - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - + new JavaDStream(dstream.glom().map(_.toSeq.asJava)) /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */ @@ -160,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -170,7 +177,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -181,7 +188,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -192,7 +201,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -269,7 +280,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ def slice(fromTime: Time, toTime: Time): JList[R] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) + dstream.slice(fromTime, toTime).map(wrapRDD).asJava } /** @@ -278,7 +289,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction[R, Void]) { foreachRDD(foreachFunc) } @@ -289,7 +300,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction2[R, Time, Void]) { foreachRDD(foreachFunc) } @@ -406,8 +417,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T implicit val cmv2: ClassTag[V2] = fakeClassTag implicit val cmw: ClassTag[W] = fakeClassTag - def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = + def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = { transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd + } dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index de124cf40eff..e2aec6c2f63e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.api.java import java.lang.{Long => JLong, Iterable => JIterable} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -45,7 +45,7 @@ import org.apache.spark.streaming.dstream.DStream class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifest: ClassTag[K], implicit val vManifest: ClassTag[V]) - extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { + extends AbstractJavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) @@ -116,14 +116,14 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey().mapValues(asJavaIterable _) + dstream.groupByKey().mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(numPartitions).mapValues(asJavaIterable _) + dstream.groupByKey(numPartitions).mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. @@ -132,7 +132,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(partitioner).mapValues(asJavaIterable _) + dstream.groupByKey(partitioner).mapValues(_.asJava) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are @@ -197,7 +197,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration).mapValues(_.asJava) } /** @@ -212,7 +212,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(_.asJava) } /** @@ -227,9 +227,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(asJavaIterable _) + : JavaPairDStream[K, JIterable[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(_.asJava) } /** @@ -247,9 +246,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ):JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(asJavaIterable _) + ): JavaPairDStream[K, JIterable[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(_.asJava) } /** @@ -262,7 +260,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def reduceByKeyAndWindow(reduceFunc: JFunction2[V, V, V], windowDuration: Duration) - :JavaPairDStream[K, V] = { + : JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } @@ -281,7 +279,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration - ):JavaPairDStream[K, V] = { + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } @@ -431,7 +429,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values + val list: JList[V] = values.asJava val scalaState: Optional[S] = JavaUtils.optionToOptional(state) val result: Optional[S] = in.apply(list, scalaState) result.isPresent match { @@ -526,7 +524,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) @@ -539,7 +537,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream).mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -551,8 +549,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( numPartitions: Int ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, numPartitions) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, numPartitions).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -564,8 +561,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( partitioner: Partitioner ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, partitioner).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -726,7 +722,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { + def saveAsHadoopFiles(prefix: String, suffix: String) { dstream.saveAsHadoopFiles(prefix, suffix) } @@ -734,12 +730,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles( + def saveAsHadoopFiles[F <: OutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]]) { + outputFormatClass: Class[F]) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } @@ -747,12 +743,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles( + def saveAsHadoopFiles[F <: OutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], + outputFormatClass: Class[F], conf: JobConf) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } @@ -761,7 +757,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { + def saveAsNewAPIHadoopFiles(prefix: String, suffix: String) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix) } @@ -769,12 +765,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles( + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { + outputFormatClass: Class[F]) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } @@ -782,13 +778,13 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles( + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = new Configuration) { + outputFormatClass: Class[F], + conf: Configuration = dstream.context.sparkContext.hadoopConfiguration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 9a2254bcdc1f..13f371f29603 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} @@ -29,15 +29,18 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function0 => JFunction0} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver +import org.apache.hadoop.conf.Configuration /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -112,7 +115,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) + this(new StreamingContext( + master, + appName, + batchDuration, + sparkHome, + jars, + environment.asScala)) /** * Create a JavaStreamingContext using an existing JavaSparkContext. @@ -134,7 +143,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Recreate a JavaStreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(new StreamingContext(path, new Configuration)) + def this(path: String) = this(new StreamingContext(path, SparkHadoopUtil.get.conf)) /** * Re-creates a JavaStreamingContext from a checkpoint file. @@ -146,6 +155,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) + /** + * @deprecated As of 0.9.0, replaced by `sparkContext` + */ @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext @@ -177,7 +189,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create an input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given + * a TCP socket and the receive bytes it interpreted as object using the given * converter. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data @@ -191,7 +203,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).iterator().asScala implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -209,6 +221,24 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.textFileStream(directory) } + /** + * :: Experimental :: + * + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as flat binary files with fixed record lengths, + * yielding byte arrays + * + * '''Note:''' We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. + * + * @param directory HDFS directory to monitor for new files + * @param recordLength The length at which to split the records + */ + @Experimental + def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { + ssc.binaryRecordsStream(directory, recordLength) + } + /** * Create an input stream from network source hostname:port, where data is received * as serialized blocks (serialized using the Spark's serializer) that can be directly @@ -294,10 +324,41 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } + /** + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @param conf Hadoop configuration + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]]( + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F], + filter: JFunction[Path, JBoolean], + newFilesOnly: Boolean, + conf: Configuration): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() + ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) + } + /** * Create an input stream with any arbitrary user implemented actor receiver. * @param props Props object defining creation of the actor @@ -365,7 +426,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @tparam T Type of objects in the RDD */ @@ -373,7 +438,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue) } @@ -381,7 +446,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -393,7 +462,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime) } @@ -401,7 +470,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty @@ -414,7 +487,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) } @@ -433,7 +506,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = { - val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[T]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[T] = first.classTag ssc.union(dstreams)(cm) } @@ -445,7 +518,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { first: JavaPairDStream[K, V], rest: JList[JavaPairDStream[K, V]] ): JavaPairDStream[K, V] = { - val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[(K, V)] = first.classTag implicit val kcm: ClassTag[K] = first.kManifest implicit val vcm: ClassTag[V] = first.vManifest @@ -467,12 +540,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ): JavaDStream[T] = { implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** @@ -492,12 +564,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** @@ -527,6 +598,28 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.addStreamingListener(streamingListener) } + /** + * :: DeveloperApi :: + * + * Return the current state of the context. The context can be in three possible states - + *
      + *
    • + * StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + *
    • + *
    • + * StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + *
    • + *
    • + * StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + *
    • + *
    + */ + def getState(): StreamingContextState = { + ssc.getState() + } + /** * Start the execution of the streams. */ @@ -546,11 +639,25 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ + @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { ssc.awaitTermination(timeout) } + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * + * @param timeout time to wait in milliseconds + * @return `true` if it's stopped; or throw the reported error during the execution; or `false` + * if the waiting time elapsed before returning from the method. + */ + def awaitTerminationOrTimeout(timeout: Long): Boolean = { + ssc.awaitTerminationOrTimeout(timeout) + } + /** * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ @@ -562,7 +669,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = ssc.stop(stopSparkContext) /** * Stop the execution of the streams. @@ -570,7 +677,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param stopGracefully Stop gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { ssc.stop(stopSparkContext, stopGracefully) } @@ -591,7 +698,9 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ + @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -612,7 +721,9 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -636,7 +747,9 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -649,6 +762,72 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext] + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 7053f47ec69a..2c373640d2fd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,14 +20,13 @@ package org.apache.spark.streaming.api.python import java.io.{ObjectInputStream, ObjectOutputStream} import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} -import scala.collection.JavaConversions._ + import scala.collection.JavaConverters._ import scala.language.existentials import py4j.GatewayServer import org.apache.spark.api.java._ -import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} @@ -109,7 +108,7 @@ private[python] object PythonTransformFunctionSerializer { } def serialize(func: PythonTransformFunction): Array[Byte] = { - assert(serializer != null, "Serializer has not been registered!") + require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") @@ -119,7 +118,7 @@ private[python] object PythonTransformFunctionSerializer { } def deserialize(bytes: Array[Byte]): PythonTransformFunction = { - assert(serializer != null, "Serializer has not been registered!") + require(serializer != null, "Serializer has not been registered!") serializer.loads(bytes) } } @@ -161,7 +160,7 @@ private[python] object PythonDStream { */ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] - rdds.forall(queue.add(_)) + rdds.asScala.foreach(queue.add) queue } } @@ -176,11 +175,11 @@ private[python] abstract class PythonDStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -212,7 +211,7 @@ private[python] class PythonTransformed2DStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent, parent2) + override def dependencies: List[DStream[_]] = List(parent, parent2) override def slideDuration: Duration = parent.slideDuration @@ -223,7 +222,7 @@ private[python] class PythonTransformed2DStream( func(Some(rdd1), Some(rdd2), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -260,12 +259,15 @@ private[python] class PythonReducedWindowedDStream( extends PythonDStream(parent, preduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) - override val mustCheckpoint = true - val invReduceFunc = new TransformFunction(pinvReduceFunc) + override val mustCheckpoint: Boolean = true + + val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index b874f561c12e..1da0b0a54df0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -25,12 +25,13 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD} +import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD, RDDOperationScope} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job +import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} /** @@ -60,6 +61,8 @@ abstract class DStream[T: ClassTag] ( @transient private[streaming] var ssc: StreamingContext ) extends Serializable with Logging { + validateAtInit() + // ======================================================================= // Methods that should be implemented by subclasses of DStream // ======================================================================= @@ -71,7 +74,7 @@ abstract class DStream[T: ClassTag] ( def dependencies: List[DStream[_]] /** Method that generates a RDD for the given time */ - def compute (validTime: Time): Option[RDD[T]] + def compute(validTime: Time): Option[RDD[T]] // ======================================================================= // Methods and fields available on all DStreams @@ -104,11 +107,49 @@ abstract class DStream[T: ClassTag] ( private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ - def context = ssc + def context: StreamingContext = ssc /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() + /** + * The base scope associated with the operation that created this DStream. + * + * This is the medium through which we pass the DStream operation name (e.g. updatedStateByKey) + * to the RDDs created by this DStream. Note that we never use this scope directly in RDDs. + * Instead, we instantiate a new scope during each call to `compute` based on this one. + * + * This is not defined if the DStream is created outside of one of the public DStream operations. + */ + protected[streaming] val baseScope: Option[String] = { + Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) + } + + /** + * Make a scope that groups RDDs created in the same DStream operation in the same batch. + * + * Each DStream produces many scopes and each scope may be shared by other DStreams created + * in the same operation. Separate calls to the same DStream operation create separate scopes. + * For instance, `dstream.map(...).map(...)` creates two separate scopes per batch. + */ + private def makeScope(time: Time): Option[RDDOperationScope] = { + baseScope.map { bsJson => + val formattedBatchTime = UIUtils.formatBatchTime( + time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + val bs = RDDOperationScope.fromJson(bsJson) + val baseName = bs.name // e.g. countByWindow, "kafka stream [0]" + val scopeName = + if (baseName.length > 10) { + // If the operation name is too long, wrap the line + s"$baseName\n@ $formattedBatchTime" + } else { + s"$baseName @ $formattedBatchTime" + } + val scopeId = s"${bs.id}_${time.milliseconds}" + new RDDOperationScope(scopeName, id = scopeId) + } + } + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { @@ -171,43 +212,57 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.initialize(zeroTime)) } - private[streaming] def validate() { - assert(rememberDuration != null, "Remember duration is set to null") + private def validateAtInit(): Unit = { + ssc.getState() match { + case StreamingContextState.INITIALIZED => + // good to go + case StreamingContextState.ACTIVE => + throw new IllegalStateException( + "Adding new inputs, transformations, and output operations after " + + "starting a context is not supported") + case StreamingContextState.STOPPED => + throw new IllegalStateException( + "Adding new inputs, transformations, and output operations after " + + "stopping a context is not supported") + } + } + + private[streaming] def validateAtStart() { + require(rememberDuration != null, "Remember duration is set to null") - assert( + require( !mustCheckpoint || checkpointDuration != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) - assert( + require( checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, - "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + - " or SparkContext.checkpoint() to set the checkpoint directory." + "The checkpoint directory has not been set. Please set it by StreamingContext.checkpoint()." ) - assert( + require( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + "Please set it to at least " + slideDuration + "." ) - assert( + require( checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple " + slideDuration + "." + "Please set it to a multiple of " + slideDuration + "." ) - assert( + require( checkpointDuration == null || storageLevel != StorageLevel.NONE, "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) - assert( + require( checkpointDuration == null || rememberDuration > checkpointDuration, "The remember duration for " + this.getClass.getSimpleName + " has been set to " + rememberDuration + " which is not more than the checkpoint interval (" + @@ -216,7 +271,7 @@ abstract class DStream[T: ClassTag] ( val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - assert( + require( metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + @@ -226,7 +281,7 @@ abstract class DStream[T: ClassTag] ( math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." ) - dependencies.foreach(_.validate()) + dependencies.foreach(_.validateAtStart()) logInfo("Slide time = " + slideDuration) logInfo("Storage level = " + storageLevel) @@ -278,28 +333,23 @@ abstract class DStream[T: ClassTag] ( * Get the RDD corresponding to the given time; either retrieve it from cache * or compute-and-cache it. */ - private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = { // If RDD was already generated, then retrieve it from HashMap, // or else compute the RDD generatedRDDs.get(time).orElse { // Compute the RDD if time is valid (e.g. correct time in a sliding window) // of RDD generation, else generate nothing. if (isTimeValid(time)) { - // Set the thread-local property for call sites to this DStream's creation site - // such that RDDs generated by compute gets that as their creation site. - // Note that this `getOrCompute` may get called from another DStream which may have - // set its own call site. So we store its call site in a temporary variable, - // set this DStream's creation site, generate RDDs and then restore the previous call site. - val prevCallSite = ssc.sparkContext.getCallSite() - ssc.sparkContext.setCallSite(creationSite) - // Disable checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. We need to have this call here because - // compute() might cause Spark jobs to be launched. - val rddOption = PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - compute(time) + + val rddOption = createRDDWithLocalProperties(time) { + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. We need to have this call here because + // compute() might cause Spark jobs to be launched. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + compute(time) + } } - ssc.sparkContext.setCallSite(prevCallSite) rddOption.foreach { case newRDD => // Register the generated RDD for caching and checkpointing @@ -320,6 +370,41 @@ abstract class DStream[T: ClassTag] ( } } + /** + * Wrap a body of code such that the call site and operation scope + * information are passed to the RDDs created in this body properly. + */ + protected def createRDDWithLocalProperties[U](time: Time)(body: => U): U = { + val scopeKey = SparkContext.RDD_SCOPE_KEY + val scopeNoOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY + // Pass this DStream's operation scope and creation site information to RDDs through + // thread-local properties in our SparkContext. Since this method may be called from another + // DStream, we need to temporarily store any old scope and creation site information to + // restore them later after setting our own. + val prevCallSite = ssc.sparkContext.getCallSite() + val prevScope = ssc.sparkContext.getLocalProperty(scopeKey) + val prevScopeNoOverride = ssc.sparkContext.getLocalProperty(scopeNoOverrideKey) + + try { + ssc.sparkContext.setCallSite(creationSite) + // Use the DStream's base scope for this RDD so we can (1) preserve the higher level + // DStream operation name, and (2) share this scope with other DStreams created in the + // same operation. Disallow nesting so that low-level Spark primitives do not show up. + // TODO: merge callsites with scopes so we can just reuse the code there + makeScope(time).foreach { s => + ssc.sparkContext.setLocalProperty(scopeKey, s.toJson) + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + } + + body + } finally { + // Restore any state that was modified before returning + ssc.sparkContext.setCallSite(prevCallSite) + ssc.sparkContext.setLocalProperty(scopeKey, prevScope) + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, prevScopeNoOverride) + } + } + /** * Generate a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job @@ -439,7 +524,7 @@ abstract class DStream[T: ClassTag] ( // ======================================================================= /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[U: ClassTag](mapFunc: T => U): DStream[U] = { + def map[U: ClassTag](mapFunc: T => U): DStream[U] = ssc.withScope { new MappedDStream(this, context.sparkContext.clean(mapFunc)) } @@ -447,26 +532,31 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = { + def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope { new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } /** Return a new DStream containing only the elements that satisfy a predicate. */ - def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + def filter(filterFunc: T => Boolean): DStream[T] = ssc.withScope { + new FilteredDStream(this, context.sparkContext.clean(filterFunc)) + } /** * Return a new DStream in which each RDD is generated by applying glom() to each RDD of * this DStream. Applying glom() to an RDD coalesces all elements within each partition into * an array. */ - def glom(): DStream[Array[T]] = new GlommedDStream(this) - + def glom(): DStream[Array[T]] = ssc.withScope { + new GlommedDStream(this) + } /** * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the * returned DStream has exactly numPartitions partitions. */ - def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions)) + def repartition(numPartitions: Int): DStream[T] = ssc.withScope { + this.transform(_.repartition(numPartitions)) + } /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs @@ -476,7 +566,7 @@ abstract class DStream[T: ClassTag] ( def mapPartitions[U: ClassTag]( mapPartFunc: Iterator[T] => Iterator[U], preservePartitioning: Boolean = false - ): DStream[U] = { + ): DStream[U] = ssc.withScope { new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) } @@ -484,14 +574,15 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD has a single element generated by reducing each RDD * of this DStream. */ - def reduce(reduceFunc: (T, T) => T): DStream[T] = + def reduce(reduceFunc: (T, T) => T): DStream[T] = ssc.withScope { this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + } /** * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = { + def count(): DStream[Long] = ssc.withScope { this.map(_ => (null, 1L)) .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) .reduceByKey(_ + _) @@ -505,24 +596,29 @@ abstract class DStream[T: ClassTag] ( * `numPartitions` not specified). */ def countByValue(numPartitions: Int = ssc.sc.defaultParallelism)(implicit ord: Ordering[T] = null) - : DStream[(T, Long)] = + : DStream[(T, Long)] = ssc.withScope { this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: RDD[T] => Unit): Unit = { + def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { this.foreachRDD(foreachFunc) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = { + def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { this.foreachRDD(foreachFunc) } @@ -530,17 +626,18 @@ abstract class DStream[T: ClassTag] ( * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: RDD[T] => Unit) { - this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { + val cleanedF = context.sparkContext.clean(foreachFunc, false) + this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r)) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register() } @@ -549,23 +646,24 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean - transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false)) + val cleanedF = context.sparkContext.clean(transformFunc, false) + transform((r: RDD[T], t: Time) => cleanedF(r)) } /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) - val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { + val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) cleanedF(rdds.head.asInstanceOf[RDD[T]], time) } @@ -578,9 +676,9 @@ abstract class DStream[T: ClassTag] ( */ def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V] - ): DStream[V] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + ): DStream[V] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = ssc.sparkContext.clean(transformFunc, false) transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2)) @@ -592,9 +690,9 @@ abstract class DStream[T: ClassTag] ( */ def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V] - ): DStream[V] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + ): DStream[V] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = ssc.sparkContext.clean(transformFunc, false) val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { @@ -610,7 +708,7 @@ abstract class DStream[T: ClassTag] ( * Print the first ten elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ - def print() { + def print(): Unit = ssc.withScope { print(10) } @@ -618,15 +716,19 @@ abstract class DStream[T: ClassTag] ( * Print the first num elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ - def print(num: Int) { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val firstNum = rdd.take(num + 1) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - firstNum.take(num).foreach(println) - if (firstNum.size > num) println("...") - println() + def print(num: Int): Unit = ssc.withScope { + def foreachFunc: (RDD[T], Time) => Unit = { + (rdd: RDD[T], time: Time) => { + val firstNum = rdd.take(num + 1) + // scalastyle:off println + println("-------------------------------------------") + println("Time: " + time) + println("-------------------------------------------") + firstNum.take(num).foreach(println) + if (firstNum.length > num) println("...") + println() + // scalastyle:on println + } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } @@ -648,7 +750,7 @@ abstract class DStream[T: ClassTag] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { + def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = ssc.withScope { new WindowedDStream(this, windowDuration, slideDuration) } @@ -666,7 +768,7 @@ abstract class DStream[T: ClassTag] ( reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration - ): DStream[T] = { + ): DStream[T] = ssc.withScope { this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) } @@ -691,7 +793,7 @@ abstract class DStream[T: ClassTag] ( invReduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration - ): DStream[T] = { + ): DStream[T] = ssc.withScope { this.map(x => (1, x)) .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) .map(_._2) @@ -707,7 +809,9 @@ abstract class DStream[T: ClassTag] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + def countByWindow( + windowDuration: Duration, + slideDuration: Duration): DStream[Long] = ssc.withScope { this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } @@ -728,8 +832,7 @@ abstract class DStream[T: ClassTag] ( slideDuration: Duration, numPartitions: Int = ssc.sc.defaultParallelism) (implicit ord: Ordering[T] = null) - : DStream[(T, Long)] = - { + : DStream[(T, Long)] = ssc.withScope { this.map(x => (x, 1L)).reduceByKeyAndWindow( (x: Long, y: Long) => x + y, (x: Long, y: Long) => x - y, @@ -744,32 +847,40 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same slideDuration as this DStream. */ - def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) + def union(that: DStream[T]): DStream[T] = ssc.withScope { + new UnionDStream[T](Array(this, that)) + } /** * Return all the RDDs defined by the Interval object (both end times included) */ - def slice(interval: Interval): Seq[RDD[T]] = { + def slice(interval: Interval): Seq[RDD[T]] = ssc.withScope { slice(interval.beginTime, interval.endTime) } /** * Return all the RDDs between 'fromTime' to 'toTime' (both included) */ - def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { + def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = ssc.withScope { if (!isInitialized) { throw new SparkException(this + " has not been initialized") } - if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) { - logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") + + val alignedToTime = if ((toTime - zeroTime).isMultipleOf(slideDuration)) { + toTime + } else { + logWarning("toTime (" + toTime + ") is not a multiple of slideDuration (" + + slideDuration + ")") + toTime.floor(slideDuration, zeroTime) } - if (!(toTime - zeroTime).isMultipleOf(slideDuration)) { - logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") + + val alignedFromTime = if ((fromTime - zeroTime).isMultipleOf(slideDuration)) { + fromTime + } else { + logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + + slideDuration + ")") + fromTime.floor(slideDuration, zeroTime) } - val alignedToTime = toTime.floor(slideDuration) - val alignedFromTime = fromTime.floor(slideDuration) logInfo("Slicing from " + fromTime + " to " + toTime + " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") @@ -784,7 +895,7 @@ abstract class DStream[T: ClassTag] ( * The file name at each batch interval is generated based on `prefix` and * `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsObjectFiles(prefix: String, suffix: String = "") { + def saveAsObjectFiles(prefix: String, suffix: String = ""): Unit = ssc.withScope { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) @@ -797,7 +908,7 @@ abstract class DStream[T: ClassTag] ( * of elements. The file name at each batch interval is generated based on * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsTextFiles(prefix: String, suffix: String = "") { + def saveAsTextFiles(prefix: String, suffix: String = ""): Unit = ssc.withScope { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) @@ -837,7 +948,7 @@ object DStream { /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { - def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined val isSparkClass = doesMatch(SPARK_CLASS_REGEX) val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 0dc72790fbdb..39fd21342813 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } } - override def toString() = { + override def toString: String = { "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index e7c5639a6349..c358f5b5bd70 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -18,17 +18,18 @@ package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ -import org.apache.spark.util.{TimeStampedHashMap, Utils} +import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -62,19 +63,34 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * the streaming app. * - If a file is to be visible in the directory listings, it must be visible within a certain * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.MIN_REMEMBER_DURATION`). Otherwise, the file will never be + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be * selected as the mod time will be less than the ignore threshold when it becomes visible. * - Once a file is visible, the mod time cannot change. If it does due to appends, then the * processing semantics are undefined. */ private[streaming] -class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : ClassTag]( +class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, - newFilesOnly: Boolean = true) + newFilesOnly: Boolean = true, + conf: Option[Configuration] = None) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { + private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) + + /** + * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. + * + * Files with mod times older than this "window" of remembering will be ignored. So if new + * files are visible within this window, then the file will get selected in the next batch. + */ + private val minRememberDurationS = { + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.fileStream.minRememberDuration", + ssc.conf.get("spark.streaming.minRememberDuration", "60s"))) + } + // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock @@ -83,14 +99,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Initial ignore threshold based on which old, existing files in the directory (at the time of // starting the streaming application) will be ignored or considered - private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.currentTime() else 0L + private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.getTimeMillis() else 0L /* * Make sure that the information of files selected in the last few batches are remembered. * This would allow us to filter away not-too-old files which have already been recently * selected and processed. */ - private val numBatchesToRemember = FileInputDStream.calculateNumBatchesToRemember(slideDuration) + private val numBatchesToRemember = FileInputDStream + .calculateNumBatchesToRemember(slideDuration, minRememberDurationS) private val durationToRemember = slideDuration * numBatchesToRemember remember(durationToRemember) @@ -130,7 +147,14 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) batchTimeToSelectedFiles += ((validTime, newFiles)) recentlySelectedFiles ++= newFiles - Some(filesToRDD(newFiles)) + val rdds = Some(filesToRDD(newFiles)) + // Copy newFiles to immutable.List to prevent from being modified by the user + val metadata = Map( + "files" -> newFiles.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> newFiles.mkString("\n")) + val inputInfo = StreamInputInfo(id, 0, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + rdds } /** Clear the old time-to-files mappings along with old RDDs */ @@ -156,7 +180,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): Array[String] = { try { - lastNewFileFindingTime = clock.currentTime() + lastNewFileFindingTime = clock.getTimeMillis() // Calculate ignore threshold val modTimeIgnoreThreshold = math.max( @@ -169,7 +193,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) } val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) - val timeTaken = clock.currentTime() - lastNewFileFindingTime + val timeTaken = clock.getTimeMillis() - lastNewFileFindingTime logInfo("Finding new files took " + timeTaken + " ms") logDebug("# cached file times = " + fileToModTime.size) if (timeTaken > slideDuration.milliseconds) { @@ -236,15 +260,23 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file =>{ - val rdd = context.sparkContext.newAPIHadoopFile[K, V, F](file) + val fileRDDs = files.map { file => + val rdd = serializableConfOpt.map(_.value) match { + case Some(config) => context.sparkContext.newAPIHadoopFile( + file, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + config) + case None => context.sparkContext.newAPIHadoopFile[K, V, F](file) + } if (rdd.partitions.size == 0) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. " + "Refer to the streaming programming guide for more details.") } rdd - }) + } new UnionRDD(context.sparkContext, fileRDDs) } @@ -271,7 +303,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() @@ -285,7 +317,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] + private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() @@ -307,7 +339,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } } - override def toString() = { + override def toString: String = { "[\n" + hadoopFiles.size + " file sets\n" + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } @@ -317,20 +349,14 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private[streaming] object FileInputDStream { - /** - * Minimum duration of remembering the information of selected files. Files with mod times - * older than this "window" of remembering will be ignored. So if new files are visible - * within this window, then the file will get selected in the next batch. - */ - private val MIN_REMEMBER_DURATION = Minutes(1) - def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") /** * Calculate the number of last batches to remember, such that all the files selected in - * at least last MIN_REMEMBER_DURATION duration can be remembered. + * at least last minRememberDurationS duration can be remembered. */ - def calculateNumBatchesToRemember(batchDuration: Duration): Int = { - math.ceil(MIN_REMEMBER_DURATION.milliseconds.toDouble / batchDuration.milliseconds).toInt + def calculateNumBatchesToRemember(batchDuration: Duration, + minRememberDurationS: Duration): Int = { + math.ceil(minRememberDurationS.milliseconds.toDouble / batchDuration.milliseconds).toInt } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index c81534ae584e..fcd5216f101a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag]( filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 658623455498..9d09a3baf37c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( flatMapValueFunc: V => TraversableOnce[U] ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index c7bb2833eabb..475ea2d2d4f3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag]( flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 1361c30395b5..c109ceccc698 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] ( foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration @@ -37,7 +37,7 @@ class ForEachDStream[T: ClassTag] ( override def generateJob(time: Time): Option[Job] = { parent.getOrCompute(time) match { case Some(rdd) => - val jobFunc = () => { + val jobFunc = () => createRDDWithLocalProperties(time) { ssc.sparkContext.setCallSite(creationSite) foreachFunc(rdd, time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index a9bb51f05404..dbb295fe54f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -25,7 +25,7 @@ private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index aa1993f0580a..a6c4cd220e42 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,10 +17,15 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Time, Duration, StreamingContext} - import scala.reflect.ClassTag +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.Utils + /** * This is the abstract base class for all input streams. This class provides methods * start() and stop() which is called by Spark Streaming system to start and stop receiving data. @@ -41,6 +46,38 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) ssc.graph.addInputStream(this) + /** This is an unique identifier for the input stream. */ + val id = ssc.getNewInputStreamId() + + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: Option[RateController] = None + + /** A human-readable name of this InputDStream */ + private[streaming] def name: String = { + // e.g. FlumePollingDStream -> "Flume polling stream" + val newName = Utils.getFormattedClassName(this) + .replaceAll("InputDStream", "Stream") + .split("(?=[A-Z])") + .filter(_.nonEmpty) + .mkString(" ") + .toLowerCase + .capitalize + s"$newName [$id]" + } + + /** + * The base scope associated with the operation that created this DStream. + * + * For InputDStreams, we use the name of this DStream as the scope name. + * If an outer scope is given, we assume that it includes an alternative name for this stream. + */ + protected[streaming] override val baseScope: Option[String] = { + val scopeName = Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) + .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } + .getOrElse(name.toLowerCase) + Some(new RDDOperationScope(scopeName).toJson) + } + /** * Checks whether the 'time' is valid wrt slideDuration for generating RDD. * Additionally it also ensures valid times are in strictly increasing order. @@ -61,7 +98,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) } } - override def dependencies = List() + override def dependencies: List[DStream[_]] = List() override def slideDuration: Duration = { if (ssc == null) throw new Exception("ssc is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 3d8ee29df1e8..5994bc1e23f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag]( preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 7aea1f945d9d..954d2eb4a7b0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( mapValueFunc: V => U ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index 02704a8d1c2e..fa14b2e897c3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] ( mapFunc: T => U ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 8a5857163244..71bec96d46c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,20 +24,23 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ -class PairDStreamFunctions[K, V](self: DStream[(K,V)]) +class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) extends Serializable { private[streaming] def ssc = self.ssc + private[streaming] def sparkContext = self.context.sparkContext + private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) } @@ -46,7 +49,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ - def groupByKey(): DStream[(K, Iterable[V])] = { + def groupByKey(): DStream[(K, Iterable[V])] = ssc.withScope { groupByKey(defaultPartitioner()) } @@ -54,7 +57,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = { + def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = ssc.withScope { groupByKey(defaultPartitioner(numPartitions)) } @@ -62,7 +65,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` on each RDD. The supplied * org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ - def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = { + def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = ssc.withScope { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) @@ -75,7 +78,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. */ - def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { + def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner()) } @@ -84,7 +87,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { + def reduceByKey( + reduceFunc: (V, V) => V, + numPartitions: Int): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } @@ -93,9 +98,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control * the partitioning of each RDD. */ - def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) + def reduceByKey( + reduceFunc: (V, V) => V, + partitioner: Partitioner): DStream[(K, V)] = ssc.withScope { + combineByKey((v: V) => v, reduceFunc, reduceFunc, partitioner) } /** @@ -104,12 +110,20 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * org.apache.spark.rdd.PairRDDFunctions in the Spark core documentation for more information. */ def combineByKey[C: ClassTag]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner, - mapSideCombine: Boolean = true): DStream[(K, C)] = { - new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true): DStream[(K, C)] = ssc.withScope { + val cleanedCreateCombiner = sparkContext.clean(createCombiner) + val cleanedMergeValue = sparkContext.clean(mergeValue) + val cleanedMergeCombiner = sparkContext.clean(mergeCombiner) + new ShuffledDStream[K, V, C]( + self, + cleanedCreateCombiner, + cleanedMergeValue, + cleanedMergeCombiner, + partitioner, mapSideCombine) } @@ -121,7 +135,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ - def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = { + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) } @@ -136,8 +150,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * DStream's batching interval */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : DStream[(K, Iterable[V])] = - { + : DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } @@ -157,7 +170,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, Iterable[V])] = { + ): DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -176,7 +189,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, Iterable[V])] = { + ): DStream[(K, Iterable[V])] = ssc.withScope { val createCombiner = (v: Iterable[V]) => new ArrayBuffer[V] ++= v val mergeValue = (buf: ArrayBuffer[V], v: Iterable[V]) => buf ++= v val mergeCombiner = (buf1: ArrayBuffer[V], buf2: ArrayBuffer[V]) => buf1 ++= buf2 @@ -198,7 +211,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } @@ -217,7 +230,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) reduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } @@ -238,7 +251,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -260,11 +273,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - self.reduceByKey(cleanedReduceFunc, partitioner) + ): DStream[(K, V)] = ssc.withScope { + self.reduceByKey(reduceFunc, partitioner) .window(windowDuration, slideDuration) - .reduceByKey(cleanedReduceFunc, partitioner) + .reduceByKey(reduceFunc, partitioner) } /** @@ -294,8 +306,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) slideDuration: Duration = self.slideDuration, numPartitions: Int = ssc.sc.defaultParallelism, filterFunc: ((K, V)) => Boolean = null - ): DStream[(K, V)] = { - + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions), filterFunc @@ -328,7 +339,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) slideDuration: Duration, partitioner: Partitioner, filterFunc: ((K, V)) => Boolean - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) @@ -349,7 +360,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) */ def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { updateStateByKey(updateFunc, defaultPartitioner()) } @@ -365,7 +376,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } @@ -382,9 +393,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { + val cleanedUpdateF = sparkContext.clean(updateFunc) val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s))) } updateStateByKey(newUpdateFunc, partitioner, true) } @@ -406,7 +418,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None) } @@ -425,9 +437,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner, initialRDD: RDD[(K, S)] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { + val cleanedUpdateF = sparkContext.clean(updateFunc) val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s))) } updateStateByKey(newUpdateFunc, partitioner, true, initialRDD) } @@ -451,7 +464,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) partitioner: Partitioner, rememberPartitioner: Boolean, initialRDD: RDD[(K, S)] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, Some(initialRDD)) } @@ -460,8 +473,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying a map function to the value of each key-value pairs in * 'this' DStream without changing the key. */ - def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuedDStream[K, V, U](self, mapValuesFunc) + def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = ssc.withScope { + new MapValuedDStream[K, V, U](self, sparkContext.clean(mapValuesFunc)) } /** @@ -470,8 +483,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) */ def flatMapValues[U: ClassTag]( flatMapValuesFunc: V => TraversableOnce[U] - ): DStream[(K, U)] = { - new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) + ): DStream[(K, U)] = ssc.withScope { + new FlatMapValuedDStream[K, V, U](self, sparkContext.clean(flatMapValuesFunc)) } /** @@ -479,7 +492,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Hash partitioning is used to generate the RDDs with Spark's default number * of partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { cogroup(other, defaultPartitioner()) } @@ -487,8 +501,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)], numPartitions: Int) - : DStream[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { cogroup(other, defaultPartitioner(numPartitions)) } @@ -499,7 +514,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def cogroup[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Iterable[V], Iterable[W]))] = { + ): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.cogroup(rdd2, partitioner) @@ -510,7 +525,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. */ - def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = { + def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = ssc.withScope { join[W](other, defaultPartitioner()) } @@ -518,7 +533,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ - def join[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (V, W))] = { + def join[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int): DStream[(K, (V, W))] = ssc.withScope { join[W](other, defaultPartitioner(numPartitions)) } @@ -529,7 +546,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def join[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (V, W))] = { + ): DStream[(K, (V, W))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.join(rdd2, partitioner) @@ -541,7 +558,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def leftOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = { + def leftOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = ssc.withScope { leftOuterJoin[W](other, defaultPartitioner()) } @@ -553,7 +571,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def leftOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (V, Option[W]))] = { + ): DStream[(K, (V, Option[W]))] = ssc.withScope { leftOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -565,7 +583,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def leftOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (V, Option[W]))] = { + ): DStream[(K, (V, Option[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.leftOuterJoin(rdd2, partitioner) @@ -577,7 +595,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def rightOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = { + def rightOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = ssc.withScope { rightOuterJoin[W](other, defaultPartitioner()) } @@ -589,7 +608,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def rightOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (Option[V], W))] = { + ): DStream[(K, (Option[V], W))] = ssc.withScope { rightOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -601,7 +620,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def rightOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Option[V], W))] = { + ): DStream[(K, (Option[V], W))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.rightOuterJoin(rdd2, partitioner) @@ -613,7 +632,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def fullOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { fullOuterJoin[W](other, defaultPartitioner()) } @@ -625,7 +645,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def fullOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (Option[V], Option[W]))] = { + ): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { fullOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -637,7 +657,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def fullOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Option[V], Option[W]))] = { + ): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.fullOuterJoin(rdd2, partitioner) @@ -651,7 +671,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def saveAsHadoopFiles[F <: OutputFormat[K, V]]( prefix: String, suffix: String - )(implicit fm: ClassTag[F]) { + )(implicit fm: ClassTag[F]): Unit = ssc.withScope { saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -667,9 +687,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) - ) { + ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) @@ -684,7 +704,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, suffix: String - )(implicit fm: ClassTag[F]) { + )(implicit fm: ClassTag[F]): Unit = ssc.withScope { saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -700,9 +720,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], conf: Configuration = ssc.sparkContext.hadoopConfiguration - ) { + ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableConfiguration(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index ed7da6dc1315..a2f5d82a79bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.streaming.{Time, StreamingContext} +import java.io.{NotSerializableException, ObjectOutputStream} + +import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.streaming.{Time, StreamingContext} + private[streaming] class QueueInputDStream[T: ClassTag]( @transient ssc: StreamingContext, @@ -36,6 +37,10 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def writeObject(oos: ObjectOutputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing") + } + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 8be04314c428..6c139f32da31 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -20,11 +20,13 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.streaming._ +import org.apache.spark.storage.BlockId import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD -import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} -import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, RateController, StreamInputInfo} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.streaming.{StreamingContext, Time} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -39,8 +41,16 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** This is an unique identifier for the receiver input stream. */ - val id = ssc.getNewReceiverStreamId() + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration))) + } else { + None + } + } /** * Gets the receiver object that will be sent to the worker nodes @@ -67,31 +77,74 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } else { // Otherwise, ask the tracker for all the blocks that have been allocated to this stream // for this batch - val blockInfos = - ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty) - val blockStoreResults = blockInfos.map { _.blockStoreResult } - val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray - - // Check whether all the results are of the same type - val resultTypes = blockStoreResults.map { _.getClass }.distinct - if (resultTypes.size > 1) { - logWarning("Multiple result types in block information, WAL information will be ignored.") - } + val receiverTracker = ssc.scheduler.receiverTracker + val blockInfos = receiverTracker.getBlocksOfBatch(validTime).getOrElse(id, Seq.empty) - // If all the results are of type WriteAheadLogBasedStoreResult, then create - // WriteAheadLogBackedBlockRDD else create simple BlockRDD. - if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { - val logSegments = blockStoreResults.map { - _.asInstanceOf[WriteAheadLogBasedStoreResult].segment - }.toArray - // Since storeInBlockManager = false, the storage level does not matter. - new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, - blockIds, logSegments, storeInBlockManager = false, StorageLevel.MEMORY_ONLY_SER) - } else { - new BlockRDD[T](ssc.sc, blockIds) - } + // Register the input blocks information into InputInfoTracker + val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + + // Create the BlockRDD + createBlockRDD(validTime, blockInfos) } } Some(blockRDD) } + + private[streaming] def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = { + + if (blockInfos.nonEmpty) { + val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Are WAL record handles present with all the blocks + val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + + if (areWALRecordHandlesPresent) { + // If all the blocks have WAL record handle, then create a WALBackedBlockRDD + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) + } else { + // Else, create a BlockRDD. However, if there are some blocks with WAL info but not + // others then that is unexpected and log a warning accordingly. + if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + logError("Some blocks do not have Write Ahead Log information; " + + "this is unexpected and data may not be recoverable after driver failures") + } else { + logWarning("Some blocks have Write Ahead Log information; this is unexpected") + } + } + val validBlockIds = blockIds.filter { id => + ssc.sparkContext.env.blockManager.master.contains(id) + } + if (validBlockIds.size != blockIds.size) { + logWarning("Some blocks could not be recovered as they were not found in memory. " + + "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "for more details.") + } + new BlockRDD[T](ssc.sc, validBlockIds) + } + } else { + // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD + // according to the configuration + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, Array.empty, Array.empty, Array.empty) + } else { + new BlockRDD[T](ssc.sc, Array.empty) + } + } + } + + /** + * A RateController that sends the new rate to receivers, via the receiver tracker. + */ + private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) + } } + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index c0a5af0b65cc..6a583bf2a362 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -38,29 +38,29 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { + ) extends DStream[(K, V)](parent.ssc) { - assert(_windowDuration.isMultipleOf(parent.slideDuration), + require(_windowDuration.isMultipleOf(parent.slideDuration), "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) - assert(_slideDuration.isMultipleOf(parent.slideDuration), + require(_slideDuration.isMultipleOf(parent.slideDuration), "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) // Reduce each batch of data using reduceByKey which will be further reduced by window // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + private val reducedStream = parent.reduceByKey(reduceFunc, partitioner) // Persist RDDs to memory by default as these RDDs are going to be reused. super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration - override def dependencies = List(reducedStream) + override def dependencies: List[DStream[_]] = List(reducedStream) override def slideDuration: Duration = _slideDuration @@ -68,7 +68,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( override def parentRememberDuration: Duration = rememberDuration + windowDuration - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + override def persist(storageLevel: StorageLevel): DStream[(K, V)] = { super.persist(storageLevel) reducedStream.persist(storageLevel) this @@ -118,7 +118,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Get the RDD of the reduced value of the previous window val previousWindowRDD = - getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K, V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 880a89bc3689..e0ffd5d86b43 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -25,19 +25,19 @@ import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( - parent: DStream[(K,V)], + parent: DStream[(K, V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true - ) extends DStream[(K,C)] (parent.ssc) { + ) extends DStream[(K, C)] (parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - override def compute(validTime: Time): Option[RDD[(K,C)]] = { + override def compute(validTime: Time): Option[RDD[(K, C)]] = { parent.getOrCompute(validTime) match { case Some(rdd) => Some(rdd.combineByKey[C]( createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 8b72bcf20653..5ce5b7aae6e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import scala.util.control.NonFatal + import org.apache.spark.streaming.StreamingContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.NextIterator @@ -74,13 +76,17 @@ class SocketReceiver[T: ClassTag]( while(!isStopped && iterator.hasNext) { store(iterator.next) } - logInfo("Stopped receiving") - restart("Retrying connecting to " + host + ":" + port) + if (!isStopped()) { + restart("Socket data stream had no more data") + } else { + logInfo("Stopped receiving") + } } catch { case e: java.net.ConnectException => restart("Error connecting to " + host + ":" + port, e) - case t: Throwable => - restart("Error receiving data", t) + case NonFatal(e) => + logWarning("Error receiving data", e) + restart("Error receiving data", e) } finally { if (socket != null) { socket.close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index ebb04dd35b9a..621d6dff788f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration @@ -51,7 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { val i = iterator.map(t => { val itr = t._2._2.iterator - val headOption = if(itr.hasNext) Some(itr.next) else None + val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) }) updateFuncLocal(i) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 71b61856e23c..5d46ca0715ff 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] ( require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index abbc40befa95..9405dbaa1232 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { + parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) + } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 775b6bfd065c..4efba039f895 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -44,9 +44,9 @@ class WindowedDStream[T: ClassTag]( // Persist parent level by default, as those RDDs are going to be obviously reused. parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = _slideDuration @@ -68,7 +68,7 @@ class WindowedDStream[T: ClassTag]( new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) } else { logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc,rddsInWindow) + new UnionRDD(ssc.sc, rddsInWindow) } Some(windowRDD) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index dd1e96334952..e081ffe46f50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -16,41 +16,58 @@ */ package org.apache.spark.streaming.rdd -import scala.reflect.ClassTag +import java.io.File +import java.nio.ByteBuffer +import java.util.UUID -import org.apache.hadoop.conf.Configuration +import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, WriteAheadLogRandomReader} +import org.apache.spark.streaming.util._ +import org.apache.spark.util.SerializableConfiguration /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. * It contains information about the id of the blocks having this partition's data and - * the segment of the write ahead log that backs the partition. + * the corresponding record handle in the write ahead log that backs the partition. * @param index index of the partition * @param blockId id of the block having the partition data - * @param segment segment of the write ahead log having the partition data + * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark + * executors). If not, then block lookups by the block ids will be skipped. + * By default, this is an empty array signifying true for all the blocks. + * @param walRecordHandle Handle of the record in a write ahead log having the partition data */ private[streaming] class WriteAheadLogBackedBlockRDDPartition( val index: Int, val blockId: BlockId, - val segment: WriteAheadLogFileSegment) - extends Partition + val isBlockIdValid: Boolean, + val walRecordHandle: WriteAheadLogRecordHandle + ) extends Partition /** * This class represents a special case of the BlockRDD where the data blocks in - * the block manager are also backed by segments in write ahead logs. For reading + * the block manager are also backed by data in write ahead logs. For reading * the data, this RDD first looks up the blocks by their ids in the block manager. - * If it does not find them, it looks up the corresponding file segment. + * If it does not find them, it looks up the WAL using the corresponding record handle. + * The lookup of the blocks from the block manager can be skipped by setting the corresponding + * element in isBlockIdValid to false. This is a performance optimization which does not affect + * correctness, and it can be used in situations where it is known that the block + * does not exist in the Spark executors (e.g. after a failed driver is restarted). + * * * @param sc SparkContext * @param blockIds Ids of the blocks that contains this RDD's data - * @param segments Segments in write ahead logs that contain this RDD's data - * @param storeInBlockManager Whether to store in the block manager after reading from the segment + * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data + * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark + * executors). If not, then block lookups by the block ids will be skipped. + * By default, this is an empty array signifying true for all the blocks. + * @param storeInBlockManager Whether to store a block in the block manager + * after reading it from the WAL * @param storageLevel storage level to store when storing in block manager * (applicable when storeInBlockManager = true) */ @@ -58,30 +75,39 @@ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, @transient blockIds: Array[BlockId], - @transient segments: Array[WriteAheadLogFileSegment], - storeInBlockManager: Boolean, - storageLevel: StorageLevel) + @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], + @transient isBlockIdValid: Array[Boolean] = Array.empty, + storeInBlockManager: Boolean = false, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) extends BlockRDD[T](sc, blockIds) { require( - blockIds.length == segments.length, - s"Number of block ids (${blockIds.length}) must be " + - s"the same as number of segments (${segments.length}})!") + blockIds.length == walRecordHandles.length, + s"Number of block Ids (${blockIds.length}) must be " + + s" same as number of WAL record handles (${walRecordHandles.length})") + + require( + isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, + s"Number of elements in isBlockIdValid (${isBlockIdValid.length}) must be " + + s" same as number of block Ids (${blockIds.length})") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration - private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + private val broadcastedHadoopConf = new SerializableConfiguration(hadoopConfig) + + override def isValid(): Boolean = true override def getPartitions: Array[Partition] = { assertValid() - Array.tabulate(blockIds.size) { i => - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i)) + Array.tabulate(blockIds.length) { i => + val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) + new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), isValid, walRecordHandles(i)) } } /** * Gets the partition data by getting the corresponding block from the block manager. - * If the block does not exist, then the data is read from the corresponding segment + * If the block does not exist, then the data is read from the corresponding record * in write ahead log files. */ override def compute(split: Partition, context: TaskContext): Iterator[T] = { @@ -90,35 +116,87 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( val blockManager = SparkEnv.get.blockManager val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockId = partition.blockId - blockManager.get(blockId) match { - case Some(block) => // Data is in Block Manager - val iterator = block.data.asInstanceOf[Iterator[T]] - logDebug(s"Read partition data of $this from block manager, block $blockId") - iterator - case None => // Data not found in Block Manager, grab it from write ahead log file - val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf) - val dataRead = reader.read(partition.segment) - reader.close() - logInfo(s"Read partition data of $this from write ahead log, segment ${partition.segment}") - if (storeInBlockManager) { - blockManager.putBytes(blockId, dataRead, storageLevel) - logDebug(s"Stored partition data of $this into block manager with level $storageLevel") - dataRead.rewind() + + def getBlockFromBlockManager(): Option[Iterator[T]] = { + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]]) + } + + def getBlockFromWriteAheadLog(): Iterator[T] = { + var dataRead: ByteBuffer = null + var writeAheadLog: WriteAheadLog = null + try { + // The WriteAheadLogUtils.createLog*** method needs a directory to create a + // WriteAheadLog object as the default FileBasedWriteAheadLog needs a directory for + // writing log data. However, the directory is not needed if data needs to be read, hence + // a dummy path is provided to satisfy the method parameter requirements. + // FileBasedWriteAheadLog will not create any file or directory at that path. + // FileBasedWriteAheadLog will not create any file or directory at that path. Also, + // this dummy directory should not already exist otherwise the WAL will try to recover + // past events from the directory and throw errors. + val nonExistentDirectory = new File( + System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).getAbsolutePath + writeAheadLog = WriteAheadLogUtils.createLogForReceiver( + SparkEnv.get.conf, nonExistentDirectory, hadoopConf) + dataRead = writeAheadLog.read(partition.walRecordHandle) + } catch { + case NonFatal(e) => + throw new SparkException( + s"Could not read data from write ahead log record ${partition.walRecordHandle}", e) + } finally { + if (writeAheadLog != null) { + writeAheadLog.close() + writeAheadLog = null } - blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + } + if (dataRead == null) { + throw new SparkException( + s"Could not read data from write ahead log record ${partition.walRecordHandle}, " + + s"read returned null") + } + logInfo(s"Read partition data of $this from write ahead log, record handle " + + partition.walRecordHandle) + if (storeInBlockManager) { + blockManager.putBytes(blockId, dataRead, storageLevel) + logDebug(s"Stored partition data of $this into block manager with level $storageLevel") + dataRead.rewind() + } + blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + } + + if (partition.isBlockIdValid) { + getBlockFromBlockManager().getOrElse { getBlockFromWriteAheadLog() } + } else { + getBlockFromWriteAheadLog() } } /** * Get the preferred location of the partition. This returns the locations of the block - * if it is present in the block manager, else it returns the location of the - * corresponding segment in HDFS. + * if it is present in the block manager, else if FileBasedWriteAheadLogSegment is used, + * it returns the location of the corresponding file segment in HDFS . */ override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] - val blockLocations = getBlockIdLocations().get(partition.blockId) - def segmentLocations = HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) - blockLocations.getOrElse(segmentLocations) + val blockLocations = if (partition.isBlockIdValid) { + getBlockIdLocations().get(partition.blockId) + } else { + None + } + + blockLocations.getOrElse { + partition.walRecordHandle match { + case fileSegment: FileBasedWriteAheadLogSegment => + try { + HdfsUtils.getFileSegmentLocations( + fileSegment.path, fileSegment.offset, fileSegment.length, hadoopConfig) + } catch { + case NonFatal(e) => + logError("Error getting WAL file segment locations", e) + Seq.empty + } + case _ => + Seq.empty + } + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index a7d63bd4f2db..7ec74016a1c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.receiver +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ @@ -25,10 +26,10 @@ import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} + import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.StorageLevel -import java.nio.ByteBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -143,19 +144,19 @@ private[streaming] class ActorReceiver[T: ClassTag]( receiverSupervisorStrategy: SupervisorStrategy ) extends Receiver[T](storageLevel) with Logging { - protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), "Supervisor" + streamId) class Supervisor extends Actor { override val supervisorStrategy = receiverSupervisorStrategy - val worker = context.actorOf(props, name) + private val worker = context.actorOf(props, name) logInfo("Started receiver worker at:" + worker.path) - val n: AtomicInteger = new AtomicInteger(0) - val hiccups: AtomicInteger = new AtomicInteger(0) + private val n: AtomicInteger = new AtomicInteger(0) + private val hiccups: AtomicInteger = new AtomicInteger(0) - def receive = { + override def receive: PartialFunction[Any, Unit] = { case IteratorData(iterator) => logDebug("received iterator") @@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag]( } } - def onStart() = { - supervisor - logInfo("Supervision tree for receivers initialized at:" + supervisor.path) - + def onStart(): Unit = { + actorSupervisor + logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) } - def onStop() = { - supervisor ! PoisonPill + def onStop(): Unit = { + actorSupervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 79263a718397..421d60ae359f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,9 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, SystemClock} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -68,74 +69,179 @@ private[streaming] trait BlockGeneratorListener { * named blocks at regular intervals. This class starts two threads, * one to periodically start a new batch and prepare the previous batch of as a block, * the other to push the blocks into the block manager. + * + * Note: Do not create BlockGenerator instances directly inside receivers. Use + * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it. */ private[streaming] class BlockGenerator( listener: BlockGeneratorListener, receiverId: Int, - conf: SparkConf + conf: SparkConf, + clock: Clock = new SystemClock() ) extends RateLimiter(conf) with Logging { private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) - private val clock = new SystemClock() - private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200) + /** + * The BlockGenerator can be in 5 possible states, in the order as follows. + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + */ + private object GeneratorState extends Enumeration { + type GeneratorState = Value + val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value + } + import GeneratorState._ + + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") + require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") + private val blockIntervalTimer = - new RecurringTimer(clock, blockInterval, updateCurrentBuffer, "BlockGenerator") + new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator") private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10) private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize) private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @volatile private var currentBuffer = new ArrayBuffer[Any] - @volatile private var stopped = false + @volatile private var state = Initialized /** Start block generating and pushing threads. */ - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Started BlockGenerator") + def start(): Unit = synchronized { + if (state == Initialized) { + state = Active + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Started BlockGenerator") + } else { + throw new SparkException( + s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]") + } } - /** Stop all threads. */ - def stop() { + /** + * Stop everything in the right order such that all the data added is pushed out correctly. + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. + */ + def stop(): Unit = { + // Set the state to stop adding data + synchronized { + if (state == Active) { + state = StoppedAddingData + } else { + logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]") + return + } + } + + // Stop generating blocks and set the state for block pushing thread to start draining the queue logInfo("Stopping BlockGenerator") blockIntervalTimer.stop(interruptTimer = false) - stopped = true - logInfo("Waiting for block pushing thread") + synchronized { state = StoppedGeneratingBlocks } + + // Wait for the queue to drain and mark generated as stopped + logInfo("Waiting for block pushing thread to terminate") blockPushingThread.join() + synchronized { state = StoppedAll } logInfo("Stopped BlockGenerator") } /** - * Push a single data item into the buffer. All received data items - * will be periodically pushed into BlockManager. + * Push a single data item into the buffer. */ - def addData (data: Any): Unit = synchronized { - waitToPush() - currentBuffer += data + def addData(data: Any): Unit = { + if (state == Active) { + waitToPush() + synchronized { + if (state == Active) { + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push a single data item into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. + * `BlockGeneratorListener.onAddData` callback will be called. */ - def addDataWithCallback(data: Any, metadata: Any) = synchronized { - waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + def addDataWithCallback(data: Any, metadata: Any): Unit = { + if (state == Active) { + waitToPush() + synchronized { + if (state == Active) { + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } + /** + * Push multiple data items into the buffer. After buffering the data, the + * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items + * are atomically added to the buffer, and are hence guaranteed to be present in a single block. + */ + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = { + if (state == Active) { + // Unroll iterator into a temp buffer, and wait for pushing in the process + val tempBuffer = new ArrayBuffer[Any] + dataIterator.foreach { data => + waitToPush() + tempBuffer += data + } + synchronized { + if (state == Active) { + currentBuffer ++= tempBuffer + listener.onAddData(tempBuffer, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + + def isActive(): Boolean = state == Active + + def isStopped(): Boolean = state == StoppedAll + /** Change the buffer to which single records are added to. */ - private def updateCurrentBuffer(time: Long): Unit = synchronized { + private def updateCurrentBuffer(time: Long): Unit = { try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[Any] - if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer) - listener.onGenerateBlock(blockId) + var newBlock: Block = null + synchronized { + if (currentBuffer.nonEmpty) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[Any] + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) + listener.onGenerateBlock(blockId) + newBlock = new Block(blockId, newBlockBuffer) + } + } + + if (newBlock != null) { blocksForPushing.put(newBlock) // put is blocking when queue is full - logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) } } catch { case ie: InterruptedException => @@ -148,18 +254,25 @@ private[streaming] class BlockGenerator( /** Keep pushing blocks to the BlockManager. */ private def keepPushingBlocks() { logInfo("Started block pushing thread") + + def areBlocksBeingGenerated: Boolean = synchronized { + state != StoppedGeneratingBlocks + } + try { - while(!stopped) { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + // While blocks are being generated, keep polling for to-be-pushed blocks and push them. + while (areBlocksBeingGenerated) { + Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => } } - // Push out the blocks that are still left + + // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks. logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") while (!blocksForPushing.isEmpty) { - logDebug("Getting block ") val block = blocksForPushing.take() + logDebug(s"Pushing block $block") pushBlock(block) logInfo("Blocks left to push " + blocksForPushing.size()) } @@ -176,7 +289,7 @@ private[streaming] class BlockGenerator( logError(message, t) listener.onError(message, t) } - + private def pushBlock(block: Block) { listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index e4f6ba626ebb..bca1fbc8fda2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.receiver +import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} + import org.apache.spark.{Logging, SparkConf} -import java.util.concurrent.TimeUnit._ /** Provides waitToPush() method to limit the rate at which receivers consume data. * @@ -33,37 +34,31 @@ import java.util.concurrent.TimeUnit._ */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private var lastSyncTime = System.nanoTime - private var messagesWrittenSinceSync = 0L - private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + // treated as an upper limit + private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue) + private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble) def waitToPush() { - if( desiredRate <= 0 ) { - return - } - val now = System.nanoTime - val elapsedNanosecs = math.max(now - lastSyncTime, 1) - val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs - if (rate < desiredRate) { - // It's okay to write; just update some variables and return - messagesWrittenSinceSync += 1 - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - messagesWrittenSinceSync = 1 - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. - val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate - val elapsedTimeInMillis = elapsedNanosecs / 1000000 - val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis - if (sleepTimeInMillis > 0) { - logTrace("Natural rate is " + rate + " per second but desired rate is " + - desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") - Thread.sleep(sleepTimeInMillis) + rateLimiter.acquire() + } + + /** + * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}. + */ + def getCurrentLimit: Long = rateLimiter.getRate.toLong + + /** + * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by + * {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that. + * + * @param newRate A new rate in events per second. It has no effect if it's 0 or negative. + */ + private[receiver] def updateRate(newRate: Long): Unit = + if (newRate > 0) { + if (maxRateLimit > 0) { + rateLimiter.setRate(newRate.min(maxRateLimit)) + } else { + rateLimiter.setRate(newRate) } - waitToPush() } - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index f7a8ebee8a54..5f6c5b024085 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,22 +17,25 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.{existentials, postfixOps} -import WriteAheadLogBasedBlockHandler._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ -import org.apache.spark.streaming.util.{Clock, SystemClock, WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.Utils +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} +import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { - def blockId: StreamBlockId // Any implementation of this trait will store a block id + // Any implementation of this trait will store a block id + def blockId: StreamBlockId + // Any implementation of this trait will have to return the number of records + def numRecords: Option[Long] } /** Trait that represents a class that handles the storage of blocks received by receiver */ @@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler { * that stores the metadata related to storage of blocks using * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] */ -private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) +private[streaming] case class BlockManagerBasedStoreResult( + blockId: StreamBlockId, numRecords: Option[Long]) extends ReceivedBlockStoreResult @@ -62,13 +66,22 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI private[streaming] class BlockManagerBasedBlockHandler( blockManager: BlockManager, storageLevel: StorageLevel) extends ReceivedBlockHandler with Logging { - + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + var numRecords = None: Option[Long] + val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + numRecords = Some(arrayBuffer.size.toLong) + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, + tellMaster = true) case IteratorBlock(iterator) => - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + val countIterator = new CountingIterator(iterator) + val putResult = blockManager.putIterator(blockId, countIterator, storageLevel, + tellMaster = true) + numRecords = countIterator.count + putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) case o => @@ -79,7 +92,7 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } - BlockManagerBasedStoreResult(blockId) + BlockManagerBasedStoreResult(blockId, numRecords) } def cleanupOldBlocks(threshTime: Long) { @@ -96,7 +109,8 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, - segment: WriteAheadLogFileSegment + numRecords: Option[Long], + walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -116,10 +130,6 @@ private[streaming] class WriteAheadLogBasedBlockHandler( private val blockStoreTimeout = conf.getInt( "spark.streaming.receiver.blockStoreTimeout", 30).seconds - private val rollingInterval = conf.getInt( - "spark.streaming.receiver.writeAheadLog.rollingInterval", 60) - private val maxFailures = conf.getInt( - "spark.streaming.receiver.writeAheadLog.maxFailures", 3) private val effectiveStorageLevel = { if (storageLevel.deserialized) { @@ -139,18 +149,14 @@ private[streaming] class WriteAheadLogBasedBlockHandler( s"$effectiveStorageLevel when write ahead log is enabled") } - // Manages rolling log files - private val logManager = new WriteAheadLogManager( - checkpointDirToLogDir(checkpointDir, streamId), - hadoopConf, rollingInterval, maxFailures, - callerName = this.getClass.getSimpleName, - clock = clock - ) + // Write ahead log manages + private val writeAheadLog = WriteAheadLogUtils.createLogForReceiver( + conf, checkpointDirToLogDir(checkpointDir, streamId), hadoopConf) // For processing futures used in parallel block storing into block manager and write ahead log // # threads = 2, so that both writing to BM and WAL can proceed in parallel implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) + ThreadUtils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) /** * This implementation stores the block into the block manager as well as a write ahead log. @@ -159,12 +165,17 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => + numRecords = Some(arrayBuffer.size.toLong) blockManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => - blockManager.dataSerialize(blockId, iterator) + val countIterator = new CountingIterator(iterator) + val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + numRecords = countIterator.count + serializedBlock case ByteBufferBlock(byteBuffer) => byteBuffer case _ => @@ -183,21 +194,22 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in write ahead log val storeInWriteAheadLogFuture = Future { - logManager.writeToLog(serializedBlock) + writeAheadLog.write(serializedBlock, clock.getTimeMillis()) } - // Combine the futures, wait for both to complete, and return the write ahead log segment + // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) - val segment = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, segment) + val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) + WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { - logManager.cleanupOldLogs(threshTime, waitForCompletion = false) + writeAheadLog.clean(threshTime, false) } def stop() { - logManager.stop() + writeAheadLog.close() + executionContext.shutdown() } } @@ -206,3 +218,23 @@ private[streaming] object WriteAheadLogBasedBlockHandler { new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString } } + +/** + * A utility that will wrap the Iterator to get the count + */ +private[streaming] class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { + private var _count = 0 + + private def isFullyConsumed: Boolean = !iterator.hasNext + + def hasNext(): Boolean = iterator.hasNext + + def count(): Option[Long] = { + if (isFullyConsumed) Some(_count) else None + } + + def next(): T = { + _count += 1 + iterator.next() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5acf8a9a811e..2252e28f22af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi @@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * being pushed into Spark's memory. */ def store(dataItem: T) { - executor.pushSingle(dataItem) + supervisor.pushSingle(dataItem) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def store(dataBuffer: ArrayBuffer[T]) { - executor.pushArrayBuffer(dataBuffer, None, None) + supervisor.pushArrayBuffer(dataBuffer, None, None) } /** @@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataBuffer: ArrayBuffer[T], metadata: Any) { - executor.pushArrayBuffer(dataBuffer, Some(metadata), None) + supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator.asScala, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator.asScala, None, None) } /** @@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** @@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * that Spark is configured to use. */ def store(bytes: ByteBuffer) { - executor.pushBytes(bytes, None, None) + supervisor.pushBytes(bytes, None, None) } /** @@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(bytes: ByteBuffer, metadata: Any) { - executor.pushBytes(bytes, Some(metadata), None) + supervisor.pushBytes(bytes, Some(metadata), None) } /** Report exceptions in receiving data. */ def reportError(message: String, throwable: Throwable) { - executor.reportError(message, throwable) + supervisor.reportError(message, throwable) } /** @@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` will be reported to the driver. */ def restart(message: String) { - executor.restartReceiver(message) + supervisor.restartReceiver(message) } /** @@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` and `exception` will be reported to the driver. */ def restart(message: String, error: Throwable) { - executor.restartReceiver(message, Some(error)) + supervisor.restartReceiver(message, Some(error)) } /** @@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * in a background thread. */ def restart(message: String, error: Throwable, millisecond: Int) { - executor.restartReceiver(message, Some(error), millisecond) + supervisor.restartReceiver(message, Some(error), millisecond) } /** Stop the receiver completely. */ def stop(message: String) { - executor.stop(message, None) + supervisor.stop(message, None) } /** Stop the receiver completely due to an exception */ def stop(message: String, error: Throwable) { - executor.stop(message, Some(error)) + supervisor.stop(message, Some(error)) } /** Check if the receiver has started or not. */ def isStarted(): Boolean = { - executor.isReceiverStarted() + supervisor.isReceiverStarted() } /** @@ -238,14 +238,14 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * the receiving of data should be stopped. */ def isStopped(): Boolean = { - executor.isReceiverStopped() + supervisor.isReceiverStopped() } /** * Get the unique identifier the receiver input stream that this * receiver is associated with. */ - def streamId = id + def streamId: Int = id /* * ================= @@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - private[streaming] var executor_ : ReceiverSupervisor = null + @transient private var _supervisor : ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ private[streaming] def setReceiverId(id_ : Int) { @@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Attach Network Receiver executor to this receiver. */ - private[streaming] def attachExecutor(exec: ReceiverSupervisor) { - assert(executor_ == null) - executor_ = exec + private[streaming] def attachSupervisor(exec: ReceiverSupervisor) { + assert(_supervisor == null) + _supervisor = exec } - /** Get the attached executor. */ - private def executor = { - assert(executor_ != null, "Executor has not been attached to this receiver") - executor_ + /** Get the attached supervisor. */ + private[streaming] def supervisor: ReceiverSupervisor = { + assert(_supervisor != null, + "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " + + "some computation in the receiver before the Receiver.onStart() has been called.") + _supervisor } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index 7bf3c3331949..1eb55affaa9d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -23,4 +23,5 @@ import org.apache.spark.streaming.Time private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage - +private[streaming] case class UpdateRateLimit(elementsPerSecond: Long) + extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 1f0244c251eb..158d1ba2f183 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -18,14 +18,15 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer +import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ +import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkEnv, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import java.util.concurrent.CountDownLatch -import scala.concurrent._ -import ExecutionContext.Implicits.global +import org.apache.spark.util.{Utils, ThreadUtils} /** * Abstract class that is responsible for supervising a Receiver in the worker. @@ -36,15 +37,18 @@ private[streaming] abstract class ReceiverSupervisor( conf: SparkConf ) extends Logging { - /** Enumeration to identify current state of the StreamingContext */ + /** Enumeration to identify current state of the Receiver */ object ReceiverState extends Enumeration { type CheckpointState = Value val Initialized, Started, Stopped = Value } import ReceiverState._ - // Attach the executor to the receiver - receiver.attachExecutor(this) + // Attach the supervisor to the receiver + receiver.attachSupervisor(this) + + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) /** Receiver id */ protected val streamId = receiver.streamId @@ -55,6 +59,9 @@ private[streaming] abstract class ReceiverSupervisor( /** Time between a receiver is stopped and started again */ private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) + /** The current maximum rate limit for this receiver. */ + private[streaming] def getCurrentRateLimit: Long = Long.MaxValue + /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null @@ -85,17 +92,34 @@ private[streaming] abstract class ReceiverSupervisor( optionalBlockId: Option[StreamBlockId] ) + /** + * Create a custom [[BlockGenerator]] that the receiver implementation can directly control + * using their provided [[BlockGeneratorListener]]. + * + * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl` + * will take care of it. + */ + def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator + /** Report errors. */ def reportError(message: String, throwable: Throwable) - /** Called when supervisor is started */ + /** + * Called when supervisor is started. + * Note that this must be called before the receiver.onStart() is called to ensure + * things like [[BlockGenerator]]s are started before the receiver starts sending data. + */ protected def onStart() { } - /** Called when supervisor is stopped */ + /** + * Called when supervisor is stopped. + * Note that this must be called after the receiver.onStop() is called to ensure + * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data. + */ protected def onStop(message: String, error: Option[Throwable]) { } - /** Called when receiver is started */ - protected def onReceiverStart() { } + /** Called when receiver is started. Return true if the driver accepts us */ + protected def onReceiverStart(): Boolean /** Called when receiver is stopped */ protected def onReceiverStop(message: String, error: Option[Throwable]) { } @@ -111,19 +135,24 @@ private[streaming] abstract class ReceiverSupervisor( stoppingError = error.orNull stopReceiver(message, error) onStop(message, error) + futureExecutionContext.shutdownNow() stopLatch.countDown() } /** Start receiver */ def startReceiver(): Unit = synchronized { try { - logInfo("Starting receiver") - receiver.onStart() - logInfo("Called receiver onStart") - onReceiverStart() - receiverState = Started + if (onReceiverStart()) { + logInfo("Starting receiver") + receiverState = Started + receiver.onStart() + logInfo("Called receiver onStart") + } else { + // The driver refused us + stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) + } } catch { - case t: Throwable => + case NonFatal(t) => stop("Error starting receiver " + streamId, Some(t)) } } @@ -132,12 +161,19 @@ private[streaming] abstract class ReceiverSupervisor( def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized { try { logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse("")) - receiverState = Stopped - receiver.onStop() - logInfo("Called receiver onStop") - onReceiverStop(message, error) + receiverState match { + case Initialized => + logWarning("Skip stopping receiver because it has not yet stared") + case Started => + receiverState = Stopped + receiver.onStop() + logInfo("Called receiver onStop") + onReceiverStop(message, error) + case Stopped => + logWarning("Receiver has been stopped") + } } catch { - case t: Throwable => + case NonFatal(t) => logError("Error stopping receiver " + streamId + t.getStackTraceString) } } @@ -150,6 +186,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Restart receiver with delay */ def restartReceiver(message: String, error: Option[Throwable], delay: Int) { Future { + // This is a blocking action so we should use "futureExecutionContext" which is a cached + // thread pool. logWarning("Restarting receiver with delay " + delay + " ms: " + message, error.getOrElse(null)) stopReceiver("Restarting receiver with delay " + delay + "ms: " + message, error) @@ -158,17 +196,17 @@ private[streaming] abstract class ReceiverSupervisor( logInfo("Starting receiver again") startReceiver() logInfo("Receiver started again") - } + }(futureExecutionContext) } - /** Check if receiver has been marked for stopping */ - def isReceiverStarted() = { + /** Check if receiver has been marked for starting */ + def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started } /** Check if receiver has been marked for stopping */ - def isReceiverStopped() = { + def isReceiverStopped(): Boolean = { logDebug("state = " + receiverState) receiverState == Stopped } @@ -176,12 +214,12 @@ private[streaming] abstract class ReceiverSupervisor( /** Wait the thread until the supervisor is stopped */ def awaitTermination() { + logInfo("Waiting for receiver to be stopped") stopLatch.await() - logInfo("Waiting for executor stop is over") if (stoppingError != null) { - logError("Stopped executor with error: " + stoppingError) + logError("Stopped receiver with error: " + stoppingError) } else { - logWarning("Stopped executor without error") + logInfo("Stopped receiver without error") } if (stoppingError != null) { throw stoppingError diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 7d29ed88cfcb..59ef58d232ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -20,19 +20,19 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import akka.actor.{Actor, Props} -import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.util.RpcUtils +import org.apache.spark.{Logging, SparkEnv, SparkException} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -47,8 +47,10 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { + private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val receivedBlockHandler: ReceivedBlockHandler = { - if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { throw new SparkException( "Cannot enable receiver write-ahead log without checkpoint directory set. " + @@ -63,43 +65,37 @@ private[streaming] class ReceiverSupervisorImpl( } - /** Remote Akka actor for the ReceiverTracker */ - private val trackerActor = { - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = AkkaUtils.address( - AkkaUtils.protocol(env.actorSystem), - SparkEnv.driverActorSystemName, - ip, - port, - "ReceiverTracker") - env.actorSystem.actorSelection(url) - } - - /** Timeout for Akka actor messages */ - private val askTimeout = AkkaUtils.askTimeout(env.conf) + /** Remote RpcEndpointRef for the ReceiverTracker */ + private val trackerEndpoint = RpcUtils.makeDriverRef("ReceiverTracker", env.conf, env.rpcEnv) - /** Akka actor for receiving messages from the ReceiverTracker in the driver */ - private val actor = env.actorSystem.actorOf( - Props(new Actor { + /** RpcEndpointRef for receiving messages from the ReceiverTracker in the driver */ + private val endpoint = env.rpcEnv.setupEndpoint( + "Receiver-" + streamId + "-" + System.currentTimeMillis(), new ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = env.rpcEnv - override def receive() = { + override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") - stop("Stopped by driver", None) + ReceiverSupervisorImpl.this.stop("Stopped by driver", None) case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) + case UpdateRateLimit(eps) => + logInfo(s"Received a new rate limit: $eps.") + registeredBlockGenerators.foreach { bg => + bg.updateRate(eps) + } } - - def ref = self - }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) + }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) + private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] + with mutable.SynchronizedBuffer[BlockGenerator] + /** Divides received data records into data blocks for pushing in BlockManager. */ - private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + private val defaultBlockGeneratorListener = new BlockGeneratorListener { def onAddData(data: Any, metadata: Any): Unit = { } def onGenerateBlock(blockId: StreamBlockId): Unit = { } @@ -111,11 +107,15 @@ private[streaming] class ReceiverSupervisorImpl( def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { pushArrayBuffer(arrayBuffer, None, Some(blockId)) } - }, streamId, env.conf) + } + private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener) + + /** Get the current rate limit of the default block generator */ + override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit /** Push a single record of received data into block generator. */ def pushSingle(data: Any) { - blockGenerator.addData(data) + defaultBlockGenerator.addData(data) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ @@ -152,53 +152,54 @@ private[streaming] class ReceiverSupervisorImpl( blockIdOption: Option[StreamBlockId] ) { val blockId = blockIdOption.getOrElse(nextBlockId) - val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size - case _ => -1 - } - val time = System.currentTimeMillis val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") - - val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) - Await.result(future, askTimeout) + val numRecords = blockStoreResult.numRecords + val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) + trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } /** Report error to the receiver tracker */ def reportError(message: String, error: Throwable) { val errorString = Option(error).map(Throwables.getStackTraceAsString).getOrElse("") - trackerActor ! ReportError(streamId, message, errorString) + trackerEndpoint.send(ReportError(streamId, message, errorString)) logWarning("Reported error " + message + " - " + error) } override protected def onStart() { - blockGenerator.start() + registeredBlockGenerators.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - blockGenerator.stop() - env.actorSystem.stop(actor) + registeredBlockGenerators.foreach { _.stop() } + env.rpcEnv.stop(endpoint) } - override protected def onReceiverStart() { + override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), actor) - val future = trackerActor.ask(msg)(askTimeout) - Await.result(future, askTimeout) + streamId, receiver.getClass.getSimpleName, hostPort, endpoint) + trackerEndpoint.askWithRetry[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - val future = trackerActor.ask( - DeregisterReceiver(streamId, message, errorString))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithRetry[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + // Cleanup BlockGenerators that have already been stopped + registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + + val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) + registeredBlockGenerators += newBlockGenerator + newBlockGenerator + } + /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 92dc113f397c..9922b6bc1201 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,6 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch + * @param streamIdToInputInfo A map of input stream id to its input info * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -32,12 +33,15 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]], + streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] ) { + @deprecated("Use streamIdToInputInfo instead", "1.5.0") + def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) + /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is @@ -58,4 +62,9 @@ case class BatchInfo( */ def totalDelay: Option[Long] = schedulingDelay.zip(processingDelay) .map(x => x._1 + x._2).headOption + + /** + * The number of recorders received by the receivers in this batch. + */ + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala new file mode 100644 index 000000000000..deb15d075975 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.{Time, StreamingContext} + +/** + * :: DeveloperApi :: + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + */ +@DeveloperApi +case class StreamInputInfo( + inputStreamId: Int, numRecords: Long, metadata: Map[String, Any] = Map.empty) { + require(numRecords >= 0, "numRecords must not be negative") + + def metadataDescription: Option[String] = + metadata.get(StreamInputInfo.METADATA_KEY_DESCRIPTION).map(_.toString) +} + +@DeveloperApi +object StreamInputInfo { + + /** + * The key for description in `StreamInputInfo.metadata`. + */ + val METADATA_KEY_DESCRIPTION: String = "Description" +} + +/** + * This class manages all the input streams as well as their input data statistics. The information + * will be exposed through StreamingListener for monitoring. + */ +private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { + + // Map to track all the InputInfo related to specific batch time and input stream. + private val batchTimeToInputInfos = + new mutable.HashMap[Time, mutable.HashMap[Int, StreamInputInfo]] + + /** Report the input information with batch time to the tracker */ + def reportInfo(batchTime: Time, inputInfo: StreamInputInfo): Unit = synchronized { + val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, + new mutable.HashMap[Int, StreamInputInfo]()) + + if (inputInfos.contains(inputInfo.inputStreamId)) { + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + + s"$batchTime is already added into InputInfoTracker, this is a illegal state") + } + inputInfos += ((inputInfo.inputStreamId, inputInfo)) + } + + /** Get the all the input stream's information of specified batch time */ + def getInfo(batchTime: Time): Map[Int, StreamInputInfo] = synchronized { + val inputInfos = batchTimeToInputInfos.get(batchTime) + // Convert mutable HashMap to immutable Map for the caller + inputInfos.map(_.toMap).getOrElse(Map[Int, StreamInputInfo]()) + } + + /** Cleanup the tracked input information older than threshold batch time */ + def cleanup(batchThreshTime: Time): Unit = synchronized { + val timesToCleanup = batchTimeToInputInfos.keys.filter(_ < batchThreshTime) + logInfo(s"remove old batch metadata: ${timesToCleanup.mkString(" ")}") + batchTimeToInputInfos --= timesToCleanup + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 7e0f6b2cdfc0..3c481bf3491f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -25,16 +25,50 @@ import scala.util.Try */ private[streaming] class Job(val time: Time, func: () => _) { - var id: String = _ - var result: Try[_] = null + private var _id: String = _ + private var _outputOpId: Int = _ + private var isSet = false + private var _result: Try[_] = null def run() { - result = Try(func()) + _result = Try(func()) } - def setId(number: Int) { - id = "streaming job " + time + "." + number + def result: Try[_] = { + if (_result == null) { + throw new IllegalStateException("Cannot access result before job finishes") + } + _result } - override def toString = id + /** + * @return the global unique id of this Job. + */ + def id: String = { + if (!isSet) { + throw new IllegalStateException("Cannot access id before calling setId") + } + _id + } + + /** + * @return the output op id of this Job. Each Job has a unique output op id in the same JobSet. + */ + def outputOpId: Int = { + if (!isSet) { + throw new IllegalStateException("Cannot access number before calling setId") + } + _outputOpId + } + + def setOutputOpId(outputOpId: Int) { + if (isSet) { + throw new IllegalStateException("Cannot call setOutputOpId more than once") + } + isSet = true + _id = s"streaming job $time.$outputOpId" + _outputOpId = outputOpId + } + + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8632c94349bf..2de035d166e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,17 +19,17 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import akka.actor.{ActorRef, Props, Actor} - import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} -import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent -private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent +private[scheduler] case class DoCheckpoint( + time: Time, clearCheckpointDataLater: Boolean) extends JobGeneratorEvent private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent /** @@ -45,12 +45,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clock = { val clockClass = ssc.sc.conf.get( - "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") - Class.forName(clockClass).newInstance().asInstanceOf[Clock] + "spark.streaming.clock", "org.apache.spark.util.SystemClock") + try { + Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] + } catch { + case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => + val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") + Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] + } } private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator") // This is marked lazy so that this is initialized after checkpoint duration has been set // in the context and the generator has been started. @@ -62,22 +68,30 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { null } - // eventActor is created when generator starts. + // eventLoop is created when generator starts. // This not being null means the scheduler has been started and not stopped - private var eventActor: ActorRef = null + private var eventLoop: EventLoop[JobGeneratorEvent] = null // last batch whose completion,checkpointing and metadata cleanup has been completed private var lastProcessedBatch: Time = null /** Start generation of jobs */ def start(): Unit = synchronized { - if (eventActor != null) return // generator has already been started + if (eventLoop != null) return // generator has already been started + + // Call checkpointWriter here to initialize it before eventLoop uses it to avoid a deadlock. + // See SPARK-10125 + checkpointWriter + + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { + override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobGeneratorEvent => processEvent(event) + override protected def onError(e: Throwable): Unit = { + jobScheduler.reportError("Error in job generator", e) } - }), "JobGenerator") + } + eventLoop.start() + if (ssc.isCheckpointPresent) { restart() } else { @@ -91,22 +105,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * checkpoints written. */ def stop(processReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // generator has already been stopped + if (eventLoop == null) return // generator has already been stopped if (processReceivedData) { logInfo("Stopping JobGenerator gracefully") val timeWhenStopStarted = System.currentTimeMillis() - val stopTimeout = conf.getLong( - "spark.streaming.gracefulStopTimeout", - 10 * ssc.graph.batchDuration.milliseconds - ) + val stopTimeoutMs = conf.getTimeAsMs( + "spark.streaming.gracefulStopTimeout", s"${10 * ssc.graph.batchDuration.milliseconds}ms") val pollTime = 100 // To prevent graceful stop to get stuck permanently - def hasTimedOut = { - val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + def hasTimedOut: Boolean = { + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeoutMs if (timedOut) { - logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") + logWarning("Timed out while stopping the job generator (timeout = " + stopTimeoutMs + ")") } timedOut } @@ -125,7 +137,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Stopped generation timer") // Wait for the jobs to complete and checkpoints to be written - def haveAllBatchesBeenProcessed = { + def haveAllBatchesBeenProcessed: Boolean = { lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime } logInfo("Waiting for jobs to be processed and checkpoints to be written") @@ -140,9 +152,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.stop() } - // Stop the actor and checkpoint writer + // Stop the event loop and checkpoint writer if (shouldCheckpoint) checkpointWriter.stop() - ssc.env.actorSystem.stop(eventActor) + eventLoop.stop() logInfo("Stopped JobGenerator") } @@ -150,14 +162,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { - eventActor ! ClearMetadata(time) + eventLoop.post(ClearMetadata(time)) } /** * Callback called when the checkpoint of a batch has been written. */ - def onCheckpointCompletion(time: Time) { - eventActor ! ClearCheckpointData(time) + def onCheckpointCompletion(time: Time, clearCheckpointDataLater: Boolean) { + if (clearCheckpointDataLater) { + eventLoop.post(ClearCheckpointData(time)) + } } /** Processes all events */ @@ -166,7 +180,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { event match { case GenerateJobs(time) => generateJobs(time) case ClearMetadata(time) => clearMetadata(time) - case DoCheckpoint(time) => doCheckpoint(time) + case DoCheckpoint(time, clearCheckpointDataLater) => + doCheckpoint(time, clearCheckpointDataLater) case ClearCheckpointData(time) => clearCheckpointData(time) } } @@ -232,13 +247,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.generateJobs(time) // generate jobs using allocated block } match { case Success(jobs) => - val receivedBlockInfos = - jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) + val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } - eventActor ! DoCheckpoint(time) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } /** Clear DStream metadata for the given `time`. */ @@ -248,13 +262,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { - eventActor ! DoCheckpoint(time) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = true)) } else { // If checkpointing is not enabled, then delete metadata information about // received blocks (block data not saved in any case). Otherwise, wait for // checkpointing of this batch to complete. val maxRememberDuration = graph.getMaxInputStreamRememberDuration() jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) + jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -267,15 +282,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // been saved to checkpoints, so its safe to delete block metadata and data WAL files val maxRememberDuration = graph.getMaxInputStreamRememberDuration() jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) + jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration) markBatchFullyProcessed(time) } /** Perform checkpoint for the give `time`. */ - private def doCheckpoint(time: Time) { + private def doCheckpoint(time: Time, clearCheckpointDataLater: Boolean) { if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) - checkpointWriter.write(new Checkpoint(ssc, time)) + checkpointWriter.write(new Checkpoint(ssc, time), clearCheckpointDataLater) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 0e0f5bd3b9db..0cd39594ee92 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,13 +17,15 @@ package org.apache.spark.streaming.scheduler -import scala.util.{Failure, Success, Try} -import scala.collection.JavaConversions._ -import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import akka.actor.{ActorRef, Actor, Props} -import org.apache.spark.{SparkException, Logging, SparkEnv} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} + +import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -38,42 +40,55 @@ private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends J private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { - private val jobSets = new ConcurrentHashMap[Time, JobSet] + // Use of ConcurrentHashMap.keySet later causes an odd runtime problem due to Java 7/8 diff + // https://gist.github.com/AlainODea/1375759b8720a3f9f094 + private val jobSets: java.util.Map[Time, JobSet] = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = + ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() // These two are created only when scheduler starts. - // eventActor not being null means the scheduler has been started and not stopped + // eventLoop not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null - private var eventActor: ActorRef = null + // A tracker to track all the input stream information as well as processed record number + var inputInfoTracker: InputInfoTracker = null + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { - if (eventActor != null) return // scheduler has already been started + if (eventLoop != null) return // scheduler has already been started logDebug("Starting JobScheduler") - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobSchedulerEvent => processEvent(event) - } - }), "JobScheduler") + eventLoop = new EventLoop[JobSchedulerEvent]("JobScheduler") { + override protected def onReceive(event: JobSchedulerEvent): Unit = processEvent(event) - listenerBus.start() + override protected def onError(e: Throwable): Unit = reportError("Error in job scheduler", e) + } + eventLoop.start() + + // attach rate controllers of input streams to receive batch completion updates + for { + inputDStream <- ssc.graph.getInputStreams + rateController <- inputDStream.rateController + } ssc.addStreamingListener(rateController) + + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) + inputInfoTracker = new InputInfoTracker(ssc) receiverTracker.start() jobGenerator.start() logInfo("Started JobScheduler") } def stop(processAllReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // scheduler has already been stopped + if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") // First, stop receiving - receiverTracker.stop() + receiverTracker.stop(processAllReceivedData) // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. @@ -96,8 +111,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // Stop everything else listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - eventActor = null + eventLoop.stop() + eventLoop = null logInfo("Stopped JobScheduler") } @@ -105,6 +120,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (jobSet.jobs.isEmpty) { logInfo("No jobs added for time " + jobSet.time) } else { + listenerBus.post(StreamingListenerBatchSubmitted(jobSet.toBatchInfo)) jobSets.put(jobSet.time, jobSet) jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + jobSet.time) @@ -112,11 +128,15 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def getPendingTimes(): Seq[Time] = { - jobSets.keySet.toSeq + jobSets.asScala.keys.toSeq } def reportError(msg: String, e: Throwable) { - eventActor ! ErrorReported(msg, e) + eventLoop.post(ErrorReported(msg, e)) + } + + def isStarted(): Boolean = synchronized { + eventLoop != null } private def processEvent(event: JobSchedulerEvent) { @@ -134,10 +154,13 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) - if (!jobSet.hasStarted) { + val isFirstJobOfJobSet = !jobSet.hasStarted + jobSet.handleJobStart(job) + if (isFirstJobOfJobSet) { + // "StreamingListenerBatchStarted" should be posted after calling "handleJobStart" to get the + // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } @@ -166,16 +189,39 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ssc.waiter.notifyError(e) } - private class JobHandler(job: Job) extends Runnable { + private class JobHandler(job: Job) extends Runnable with Logging { def run() { - eventActor ! JobStarted(job) - // Disable checks for existing output directories in jobs launched by the streaming scheduler, - // since we may need to write output to an existing directory during checkpoint recovery; - // see SPARK-4835 for more details. - PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - job.run() + ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) + ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + try { + // We need to assign `eventLoop` to a temp variable. Otherwise, because + // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then + // it's possible that when `post` is called, `eventLoop` happens to null. + var _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobStarted(job)) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } + _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobCompleted(job)) + } + } else { + // JobScheduler has been stopped. + } + } finally { + ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) + ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) } - eventActor ! JobCompleted(job) } } } + +private[streaming] object JobScheduler { + val BATCH_TIME_PROPERTY_KEY = "spark.streaming.internal.batchTime" + val OUTPUT_OP_ID_PROPERTY_KEY = "spark.streaming.internal.outputOpId" +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 8c15a75b1b0e..95833efc9417 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,15 +28,14 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty - ) { + streamIdToInputInfo: Map[Int, StreamInputInfo] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted private var processingStartTime = -1L // when the first job of this jobset started processing private var processingEndTime = -1L // when the last job of this jobset finished processing - jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } + jobs.zipWithIndex.foreach { case (job, i) => job.setOutputOpId(i) } incompleteJobs ++= jobs def handleJobStart(job: Job) { @@ -48,24 +47,24 @@ case class JobSet( if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted = processingStartTime > 0 + def hasStarted: Boolean = processingStartTime > 0 - def hasCompleted = incompleteJobs.isEmpty + def hasCompleted: Boolean = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. not including the time they wait in the streaming scheduler queue) - def processingDelay = processingEndTime - processingStartTime + def processingDelay: Long = processingEndTime - processingStartTime // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay = { + def totalDelay: Long = { processingEndTime - time.milliseconds } def toBatchInfo: BatchInfo = { new BatchInfo( time, - receivedBlockInfo, + streamIdToInputInfo, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 000000000000..a46c0c1b25e7 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.io.ObjectInputStream +import java.util.concurrent.atomic.AtomicLong + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + init() + + protected def publish(rate: Long): Unit + + @transient + implicit private var executionContext: ExecutionContext = _ + + @transient + private var rateLimit: AtomicLong = _ + + /** + * An initialization method called both from the constructor and Serialization code. + */ + private def init() { + executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + rateLimit = new AtomicLong(-1L) + } + + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + init() + } + + /** + * Compute the new rate limit and publish it asynchronously. + */ + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for { + processingEnd <- batchCompleted.batchInfo.processingEndTime + workDelay <- batchCompleted.batchInfo.processingDelay + waitDelay <- batchCompleted.batchInfo.schedulingDelay + elems <- elements.get(streamUID).map(_.numRecords) + } computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } +} + +object RateController { + def isBackPressureEnabled(conf: SparkConf): Boolean = + conf.getBoolean("spark.streaming.backpressure.enabled", false) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala index 94beb590f52d..656ac80df897 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala @@ -17,12 +17,40 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.streaming.receiver.ReceivedBlockStoreResult +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver.{ReceivedBlockStoreResult, WriteAheadLogBasedStoreResult} +import org.apache.spark.streaming.util.WriteAheadLogRecordHandle /** Information about blocks received by the receiver */ private[streaming] case class ReceivedBlockInfo( streamId: Int, - numRecords: Long, + numRecords: Option[Long], + metadataOption: Option[Any], blockStoreResult: ReceivedBlockStoreResult - ) + ) { + + require(numRecords.isEmpty || numRecords.get >= 0, "numRecords must not be negative") + + @volatile private var _isBlockIdValid = true + + def blockId: StreamBlockId = blockStoreResult.blockId + + def walRecordHandleOption: Option[WriteAheadLogRecordHandle] = { + blockStoreResult match { + case walStoreResult: WriteAheadLogBasedStoreResult => Some(walStoreResult.walRecordHandle) + case _ => None + } + } + + /** Is the block ID valid, that is, is the block present in the Spark executors. */ + def isBlockIdValid(): Boolean = _isBlockIdValid + + /** + * Set the block ID as invalid. This is useful when it is known that the block is not present + * in the Spark executors. + */ + def setBlockIdInvalid(): Unit = { + _isBlockIdValid = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index e19ac939f9ac..f2711d1355e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -19,16 +19,17 @@ package org.apache.spark.streaming.scheduler import java.nio.ByteBuffer +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} -import org.apache.spark.util.Utils +import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} +import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -45,7 +46,7 @@ private[streaming] case class BatchCleanupEvent(times: Seq[Time]) private[streaming] case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = { - streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty) + streamIdToAllocatedBlocks.getOrElse(streamId, Seq.empty) } } @@ -63,6 +64,7 @@ private[streaming] class ReceivedBlockTracker( hadoopConf: Configuration, streamIds: Seq[Int], clock: Clock, + recoverFromWriteAheadLog: Boolean, checkpointDirOption: Option[String]) extends Logging { @@ -70,12 +72,14 @@ private[streaming] class ReceivedBlockTracker( private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] - private val logManagerOption = createLogManager() + private val writeAheadLogOption = createWriteAheadLog() private var lastAllocatedBatchTime: Time = null // Recover block information from write ahead logs - recoverFromWriteAheadLogs() + if (recoverFromWriteAheadLog) { + recoverPastEvents() + } /** Add received block. This event will get written to the write ahead log (if enabled). */ def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { @@ -150,27 +154,28 @@ private[streaming] class ReceivedBlockTracker( * returns only after the files are cleaned up. */ def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { - assert(cleanupThreshTime.milliseconds < clock.currentTime()) + require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup - logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) } /** Stop the block tracker. */ def stop() { - logManagerOption.foreach { _.stop() } + writeAheadLogOption.foreach { _.close() } } /** * Recover all the tracker actions from the write ahead logs to recover the state (unallocated * and allocated block info) prior to failure. */ - private def recoverFromWriteAheadLogs(): Unit = synchronized { + private def recoverPastEvents(): Unit = synchronized { // Insert the recovered block information def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) { logTrace(s"Recovery: Inserting added block $receivedBlockInfo") + receivedBlockInfo.setBlockIdInvalid() getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo } @@ -190,11 +195,12 @@ private[streaming] class ReceivedBlockTracker( timeToAllocatedBlocks --= batchTimes } - logManagerOption.foreach { logManager => + writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - logManager.readFromLog().foreach { byteBuffer => + writeAheadLog.readAll().asScala.foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + Utils.deserialize[ReceivedBlockTrackerLogEvent]( + byteBuffer.array, Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => @@ -208,10 +214,10 @@ private[streaming] class ReceivedBlockTracker( /** Write an update to the tracker to the write ahead log */ private def writeToLog(record: ReceivedBlockTrackerLogEvent) { - if (isLogManagerEnabled) { + if (isWriteAheadLogEnabled) { logDebug(s"Writing to log $record") - logManagerOption.foreach { logManager => - logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + writeAheadLogOption.foreach { logManager => + logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) } } } @@ -222,28 +228,15 @@ private[streaming] class ReceivedBlockTracker( } /** Optionally create the write ahead log manager only if the feature is enabled */ - private def createLogManager(): Option[WriteAheadLogManager] = { - if (conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { - if (checkpointDirOption.isEmpty) { - throw new SparkException( - "Cannot enable receiver write-ahead log without checkpoint directory set. " + - "Please use streamingContext.checkpoint() to set the checkpoint directory. " + - "See documentation for more details.") - } + private def createWriteAheadLog(): Option[WriteAheadLog] = { + checkpointDirOption.map { checkpointDir => val logDir = ReceivedBlockTracker.checkpointDirToLogDir(checkpointDirOption.get) - val rollingIntervalSecs = conf.getInt( - "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) - val logManager = new WriteAheadLogManager(logDir, hadoopConf, - rollingIntervalSecs = rollingIntervalSecs, clock = clock, - callerName = "ReceivedBlockHandlerMaster") - Some(logManager) - } else { - None + WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) } } - /** Check if the log manager is enabled. This is only used for testing purposes. */ - private[streaming] def isLogManagerEnabled: Boolean = logManagerOption.nonEmpty + /** Check if the write ahead log is enabled. This is only used for testing purposes. */ + private[streaming] def isWriteAheadLogEnabled: Boolean = writeAheadLogOption.nonEmpty } private[streaming] object ReceivedBlockTracker { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index d7e39c528c51..59df892397fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import akka.actor.ActorRef import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef /** * :: DeveloperApi :: @@ -28,10 +28,10 @@ import org.apache.spark.annotation.DeveloperApi case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val actor: ActorRef, active: Boolean, location: String, lastErrorMessage: String = "", - lastError: String = "" + lastError: String = "", + lastErrorTime: Long = -1L ) { } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala new file mode 100644 index 000000000000..10b5a7f57a80 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.Map +import scala.collection.mutable + +import org.apache.spark.streaming.receiver.Receiver + +/** + * A class that tries to schedule receivers with evenly distributed. There are two phases for + * scheduling receivers. + * + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers with evenly distributed. ReceiverTracker should update its + * receiverTrackingInfoMap according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledExecutors` for each receiver will set to an executor list that + * contains the scheduled locations. Then when a receiver is starting, it will send a register + * request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled executors is set, it should check + * if the location of this receiver is one of the scheduled executors, if not, the register will + * be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * executors mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are + * still alive in the list of scheduled executors, then use them to launch the receiver job. + * - If a receiver is restarting without a scheduled executors list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should + * not set `ReceiverTrackingInfo.scheduledExecutors` for this executor, instead, it should clear + * it. Then when this receiver is registering, we can know this is a local scheduling, and + * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching + * location is matching. + * + * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, + * otherwise do local scheduling. + */ +private[streaming] class ReceiverSchedulingPolicy { + + /** + * Try our best to schedule receivers with evenly distributed. However, if the + * `preferredLocation`s of receivers are not even, we may not be able to schedule them evenly + * because we have to respect them. + * + * Here is the approach to schedule executors: + *
      + *
    1. First, schedule all the receivers with preferred locations (hosts), evenly among the + * executors running on those host.
    2. + *
    3. Then, schedule all other receivers evenly among all the executors such that overall + * distribution over all the receivers is even.
    4. + *
    + * + * This method is called when we start to launch receivers at the first time. + */ + def scheduleReceivers( + receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + if (receivers.isEmpty) { + return Map.empty + } + + if (executors.isEmpty) { + return receivers.map(_.streamId -> Seq.empty).toMap + } + + val hostToExecutors = executors.groupBy(_.split(":")(0)) + val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // Set the initial value to 0 + executors.foreach(e => numReceiversOnExecutor(e) = 0) + + // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", + // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. + for (i <- 0 until receivers.length) { + // Note: preferredLocation is host but executors are host:port + receivers(i).preferredLocation.foreach { host => + hostToExecutors.get(host) match { + case Some(executorsOnHost) => + // preferredLocation is a known host. Select an executor that has the least receivers in + // this host + val leastScheduledExecutor = + executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) + scheduledExecutors(i) += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = + numReceiversOnExecutor(leastScheduledExecutor) + 1 + case None => + // preferredLocation is an unknown host. + // Note: There are two cases: + // 1. This executor is not up. But it may be up later. + // 2. This executor is dead, or it's not a host in the cluster. + // Currently, simply add host to the scheduled executors. + scheduledExecutors(i) += host + } + } + } + + // For those receivers that don't have preferredLocation, make sure we assign at least one + // executor to them. + for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + // Select the executor that has the least receivers + val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) + scheduledExecutorsForOneReceiver += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 + } + + // Assign idle executors to receivers that have less executors + val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) + for (executor <- idleExecutors) { + // Assign an idle executor to the receiver that has least candidate executors. + val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + leastScheduledExecutors += executor + } + + receivers.map(_.streamId).zip(scheduledExecutors).toMap + } + + /** + * Return a list of candidate executors to run the receiver. If the list is empty, the caller can + * run this receiver in arbitrary executor. + * + * This method tries to balance executors' load. Here is the approach to schedule executors + * for a receiver. + *
      + *
    1. + * If preferredLocation is set, preferredLocation should be one of the candidate executors. + *
    2. + *
    3. + * Every executor will be assigned to a weight according to the receivers running or + * scheduling on it. + *
        + *
      • + * If a receiver is running on an executor, it contributes 1.0 to the executor's weight. + *
      • + *
      • + * If a receiver is scheduled to an executor but has not yet run, it contributes + * `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.
      • + *
      + * At last, if there are any idle executors (weight = 0), returns all idle executors. + * Otherwise, returns the executors that have the minimum weight. + *
    4. + *
    + * + * This method is called when a receiver is registering with ReceiverTracker or is restarting. + */ + def rescheduleReceiver( + receiverId: Int, + preferredLocation: Option[String], + receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], + executors: Seq[String]): Seq[String] = { + if (executors.isEmpty) { + return Seq.empty + } + + // Always try to schedule to the preferred locations + val scheduledExecutors = mutable.Set[String]() + scheduledExecutors ++= preferredLocation + + val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) + } + }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + + val idleExecutors = executors.toSet -- executorWeights.keys + if (idleExecutors.nonEmpty) { + scheduledExecutors ++= idleExecutors + } else { + // There is no idle executor. So select all executors that have the minimum weight. + val sortedExecutors = executorWeights.toSeq.sortBy(_._2) + if (sortedExecutors.nonEmpty) { + val minWeight = sortedExecutors(0)._2 + scheduledExecutors ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) + } else { + // This should not happen since "executors" is not empty + } + } + scheduledExecutors.toSeq + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 4f998869731e..3d532a675db0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,15 +17,27 @@ package org.apache.spark.streaming.scheduler +import java.util.concurrent.{TimeUnit, CountDownLatch} -import scala.collection.mutable.{HashMap, SynchronizedMap} +import scala.collection.mutable.HashMap +import scala.concurrent.ExecutionContext import scala.language.existentials +import scala.util.{Failure, Success} -import akka.actor._ - -import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} + + +/** Enumeration to identify current state of a Receiver */ +private[streaming] object ReceiverState extends Enumeration { + type ReceiverState = Value + val INACTIVE, SCHEDULED, ACTIVE = Value +} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -35,8 +47,8 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - host: String, - receiverActor: ActorRef + hostPort: String, + receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) extends ReceiverTrackerMessage @@ -44,6 +56,39 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage +/** + * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally. + */ +private[streaming] sealed trait ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver. + */ +private[streaming] case class RestartReceiver(receiver: Receiver[_]) + extends ReceiverTrackerLocalMessage + +/** + * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers + * at the first time. + */ +private[streaming] case class StartAllReceivers(receiver: Seq[Receiver[_]]) + extends ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered + * receivers. + */ +private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage + +/** + * A message used by ReceiverTracker to ask all receiver's ids still stored in + * ReceiverTrackerEndpoint. + */ +private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage + +private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) + extends ReceiverTrackerLocalMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -56,46 +101,97 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private val receiverInputStreams = ssc.graph.getReceiverInputStreams() private val receiverInputStreamIds = receiverInputStreams.map { _.id } - private val receiverExecutor = new ReceiverLauncher() - private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] private val receivedBlockTracker = new ReceivedBlockTracker( ssc.sparkContext.conf, ssc.sparkContext.hadoopConfiguration, receiverInputStreamIds, ssc.scheduler.clock, + ssc.isCheckpointPresent, Option(ssc.checkpointDir) ) private val listenerBus = ssc.scheduler.listenerBus - // actor is created when generator starts. + /** Enumeration to identify current state of the ReceiverTracker */ + object TrackerState extends Enumeration { + type TrackerState = Value + val Initialized, Started, Stopping, Stopped = Value + } + import TrackerState._ + + /** State of the tracker. Protected by "trackerStateLock" */ + @volatile private var trackerState = Initialized + + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null + + private val schedulingPolicy = new ReceiverSchedulingPolicy() - /** Start the actor and receiver execution thread. */ - def start() = synchronized { - if (actor != null) { + // Track the active receiver job number. When a receiver job exits ultimately, countDown will + // be called. + private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size) + + /** + * Track all receivers' information. The key is the receiver id, the value is the receiver info. + * It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo] + + /** + * Store all preferred locations for all receivers. We need this information to schedule + * receivers. It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverPreferredLocations = new HashMap[Int, Option[String]] + + /** Start the endpoint and receiver execution thread. */ + def start(): Unit = synchronized { + if (isTrackerStarted) { throw new SparkException("ReceiverTracker already started") } if (!receiverInputStreams.isEmpty) { - actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), - "ReceiverTracker") - if (!skipReceiverLaunch) receiverExecutor.start() + endpoint = ssc.env.rpcEnv.setupEndpoint( + "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) + if (!skipReceiverLaunch) launchReceivers() logInfo("ReceiverTracker started") + trackerState = Started } } /** Stop the receiver execution thread. */ - def stop() = synchronized { - if (!receiverInputStreams.isEmpty && actor != null) { + def stop(graceful: Boolean): Unit = synchronized { + if (isTrackerStarted) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop() + trackerState = Stopping + if (!skipReceiverLaunch) { + // Send the stop signal to all the receivers + endpoint.askWithRetry[Boolean](StopAllReceivers) + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + receiverJobExitLatch.await(10, TimeUnit.SECONDS) + + if (graceful) { + logInfo("Waiting for receiver job to terminate gracefully") + receiverJobExitLatch.await() + logInfo("Waited for receiver job to terminate gracefully") + } + + // Check if all the receivers have been deregistered or not + val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds) + if (receivers.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receivers) + } else { + logInfo("All of the receivers have deregistered successfully") + } + } - // Finally, stop the actor - ssc.env.actorSystem.stop(actor) - actor = null + // Finally, stop the endpoint + ssc.env.rpcEnv.stop(endpoint) + endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") + trackerState = Stopped } } @@ -113,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Get the blocks allocated to the given batch and stream. */ def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { - synchronized { - receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) - } + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) } /** @@ -127,10 +221,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) // Signal the receivers to delete old block data - if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.actor) } - .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) } + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -138,30 +235,64 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - host: String, - receiverActor: ActorRef, - sender: ActorRef - ) { + hostPort: String, + receiverEndpoint: RpcEndpointRef, + senderAddress: RpcAddress + ): Boolean = { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } - receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverActor, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) + + if (isTrackerStopping || isTrackerStopped) { + return false + } + + val scheduledExecutors = receiverTrackingInfos(streamId).scheduledExecutors + val accetableExecutors = if (scheduledExecutors.nonEmpty) { + // This receiver is registering and it's scheduled by + // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledExecutors" to check it. + scheduledExecutors.get + } else { + // This receiver is scheduled by "ReceiverSchedulingPolicy.rescheduleReceiver", so calling + // "ReceiverSchedulingPolicy.rescheduleReceiver" again to check it. + scheduleReceiver(streamId) + } + + if (!accetableExecutors.contains(hostPort)) { + // Refuse it since it's scheduled to a wrong executor + false + } else { + val name = s"${typ}-${streamId}" + val receiverTrackingInfo = ReceiverTrackingInfo( + streamId, + ReceiverState.ACTIVE, + scheduledExecutors = None, + runningExecutor = Some(hostPort), + name = Some(name), + endpoint = Some(receiverEndpoint)) + receiverTrackingInfos.put(streamId, receiverTrackingInfo) + listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo)) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + true + } } /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val lastErrorTime = + if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() + val errorInfo = ReceiverErrorInfo( + lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime) + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) + oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo -= streamId - listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo)) + receiverTrackingInfos -= streamId + listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -170,6 +301,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logError(s"Deregistered receiver for stream $streamId: $messageWithError") } + /** Update a receiver's maximum ingestion rate */ + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + } + } + /** Add new blocks for the given stream */ private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { receivedBlockTracker.addBlock(receivedBlockInfo) @@ -177,15 +315,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Report error sent by a receiver */ private def reportError(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(lastErrorMessage = message, lastError = error) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L)) + oldInfo.copy(errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo(streamId) = newReceiverInfo - listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + + receiverTrackingInfos(streamId) = newReceiverTrackingInfo + listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -194,117 +338,266 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } + private def scheduleReceiver(receiverId: Int): Seq[String] = { + val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiverId, preferredLocation, receiverTrackingInfos, getExecutors) + updateReceiverScheduledExecutors(receiverId, scheduledExecutors) + scheduledExecutors + } + + private def updateReceiverScheduledExecutors( + receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { + case Some(oldInfo) => + oldInfo.copy(state = ReceiverState.SCHEDULED, + scheduledExecutors = Some(scheduledExecutors)) + case None => + ReceiverTrackingInfo( + receiverId, + ReceiverState.SCHEDULED, + Some(scheduledExecutors), + runningExecutor = None) + } + receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) + } + /** Check if any blocks are left to be processed */ def hasUnallocatedBlocks: Boolean = { receivedBlockTracker.hasUnallocatedReceivedBlocks } - /** Actor to receive messages from the receivers. */ - private class ReceiverTrackerActor extends Actor { - def receive = { - case RegisterReceiver(streamId, typ, host, receiverActor) => - registerReceiver(streamId, typ, host, receiverActor, sender) - sender ! true - case AddBlock(receivedBlockInfo) => - sender ! addBlock(receivedBlockInfo) - case ReportError(streamId, message, error) => - reportError(streamId, message, error) - case DeregisterReceiver(streamId, message, error) => - deregisterReceiver(streamId, message, error) - sender ! true + /** + * Get the list of executors excluding driver + */ + private def getExecutors: Seq[String] = { + if (ssc.sc.isLocal) { + Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + } else { + ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => + blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location + }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq } } - /** This thread class runs all the receivers on the cluster. */ - class ReceiverLauncher { - @transient val env = ssc.env - @transient val thread = new Thread() { - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") - } - } + /** + * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the + * receivers to be scheduled on the same node. + * + * TODO Should poll the executor number and wait for executors according to + * "spark.scheduler.minRegisteredResourcesRatio" and + * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job. + */ + private def runDummySparkJob(): Unit = { + if (!ssc.sparkContext.isLocal) { + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() } + assert(getExecutors.nonEmpty) + } - def start() { - thread.start() - } + /** + * Get the receivers from the ReceiverInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. + */ + private def launchReceivers(): Unit = { + val receivers = receiverInputStreams.map(nis => { + val rcvr = nis.getReceiver() + rcvr.setReceiverId(nis.id) + rcvr + }) - def stop() { - // Send the stop signal to all the receivers - stopReceivers() + runDummySparkJob() - // Wait for the Spark job that runs the receivers to be over - // That is, for the receivers to quit gracefully. - thread.join(10000) + logInfo("Starting " + receivers.length + " receivers") + endpoint.send(StartAllReceivers(receivers)) + } - // Check if all the receivers have been deregistered or not - if (!receiverInfo.isEmpty) { - logWarning("All of the receivers have not deregistered, " + receiverInfo) - } else { - logInfo("All of the receivers have deregistered successfully") - } + /** Check if tracker has been marked for starting */ + private def isTrackerStarted: Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping: Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped: Boolean = trackerState == Stopped + + /** RpcEndpoint to receive messages from the receivers. */ + private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + + // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged + private val submitJobThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + + override def receive: PartialFunction[Any, Unit] = { + // Local messages + case StartAllReceivers(receivers) => + val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + for (receiver <- receivers) { + val executors = scheduledExecutors(receiver.streamId) + updateReceiverScheduledExecutors(receiver.streamId, executors) + receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation + startReceiver(receiver, executors) + } + case RestartReceiver(receiver) => + // Old scheduled executors minus the ones that are not active any more + val oldScheduledExecutors = getStoredScheduledExecutors(receiver.streamId) + val scheduledExecutors = if (oldScheduledExecutors.nonEmpty) { + // Try global scheduling again + oldScheduledExecutors + } else { + val oldReceiverInfo = receiverTrackingInfos(receiver.streamId) + // Clear "scheduledExecutors" to indicate we are going to do local scheduling + val newReceiverInfo = oldReceiverInfo.copy( + state = ReceiverState.INACTIVE, scheduledExecutors = None) + receiverTrackingInfos(receiver.streamId) = newReceiverInfo + schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + } + // Assume there is one receiver restarting at one time, so we don't need to update + // receiverTrackingInfos + startReceiver(receiver, scheduledExecutors) + case c: CleanupOldBlocks => + receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) + case UpdateReceiverRateLimit(streamUID, newRate) => + for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) { + eP.send(UpdateRateLimit(newRate)) + } + // Remote messages + case ReportError(streamId, message, error) => + reportError(streamId, message, error) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Remote messages + case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => + val successful = + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) + context.reply(successful) + case AddBlock(receivedBlockInfo) => + context.reply(addBlock(receivedBlockInfo)) + case DeregisterReceiver(streamId, message, error) => + deregisterReceiver(streamId, message, error) + context.reply(true) + // Local messages + case AllReceiverIds => + context.reply(receiverTrackingInfos.keys.toSeq) + case StopAllReceivers => + assert(isTrackerStopping || isTrackerStopped) + stopReceivers() + context.reply(true) } /** - * Get the receivers from the ReceiverInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. + * Return the stored scheduled executors that are still alive. */ - private def startReceivers() { - val receivers = receiverInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setReceiverId(nis.id) - rcvr - }) - - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get))) - ssc.sc.makeRDD[Receiver[_]](receiversWithPreferences) + private def getStoredScheduledExecutors(receiverId: Int): Seq[String] = { + if (receiverTrackingInfos.contains(receiverId)) { + val scheduledExecutors = receiverTrackingInfos(receiverId).scheduledExecutors + if (scheduledExecutors.nonEmpty) { + val executors = getExecutors.toSet + // Only return the alive executors + scheduledExecutors.get.filter(executors) } else { - ssc.sc.makeRDD(receivers, receivers.size) + Nil } + } else { + Nil + } + } + + /** + * Start a receiver along with its scheduled executors + */ + private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + def shouldStartReceiver: Boolean = { + // It's okay to start when trackerState is Initialized or Started + !(isTrackerStopping || isTrackerStopped) + } + + val receiverId = receiver.streamId + if (!shouldStartReceiver) { + onReceiverJobFinish(receiverId) + return + } val checkpointDirOption = Option(ssc.checkpointDir) - val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + val serializableHadoopConf = + new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[Receiver[_]]) => { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") + val startReceiverFunc: Iterator[Receiver[_]] => Unit = + (iterator: Iterator[Receiver[_]]) => { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } } - val receiver = iterator.next() - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. - if (!ssc.sparkContext.isLocal) { - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - } - // Distribute the receivers and start them - logInfo("Starting " + receivers.length + " receivers") - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - logInfo("All of the receivers have been terminated") + // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + val receiverRDD: RDD[Receiver[_]] = + if (scheduledExecutors.isEmpty) { + ssc.sc.makeRDD(Seq(receiver), 1) + } else { + ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) + } + receiverRDD.setName(s"Receiver $receiverId") + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( + receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) + // We will keep restarting the receiver job until ReceiverTracker is stopped + future.onComplete { + case Success(_) => + if (!shouldStartReceiver) { + onReceiverJobFinish(receiverId) + } else { + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + case Failure(e) => + if (!shouldStartReceiver) { + onReceiverJobFinish(receiverId) + } else { + logError("Receiver has been stopped. Try to restart it.", e) + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + }(submitJobThreadPool) + logInfo(s"Receiver ${receiver.streamId} started") + } + + override def onStop(): Unit = { + submitJobThreadPool.shutdownNow() } - /** Stops the receivers. */ + /** + * Call when a receiver is terminated. It means we won't restart its Spark job. + */ + private def onReceiverJobFinish(receiverId: Int): Unit = { + receiverJobExitLatch.countDown() + receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo => + if (receiverTrackingInfo.state == ReceiverState.ACTIVE) { + logWarning(s"Receiver $receiverId exited but didn't deregister") + } + } + } + + /** Send stop signal to the receivers. */ private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.actor)} - .foreach { _ ! StopReceiver } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") + receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers") } } + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala new file mode 100644 index 000000000000..043ff4d0ff05 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.streaming.scheduler.ReceiverState._ + +private[streaming] case class ReceiverErrorInfo( + lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L) + +/** + * Class having information about a receiver. + * + * @param receiverId the unique receiver id + * @param state the current Receiver state + * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param runningExecutor the running executor if the receiver is active + * @param name the receiver name + * @param endpoint the receiver endpoint. It can be used to send messages to the receiver + * @param errorInfo the receiver error information if it fails + */ +private[streaming] case class ReceiverTrackingInfo( + receiverId: Int, + state: ReceiverState, + scheduledExecutors: Option[Seq[String]], + runningExecutor: Option[String], + name: Option[String] = None, + endpoint: Option[RpcEndpointRef] = None, + errorInfo: Option[ReceiverErrorInfo] = None) { + + def toReceiverInfo: ReceiverInfo = ReceiverInfo( + receiverId, + name.getOrElse(""), + state == ReceiverState.ACTIVE, + location = runningExecutor.getOrElse(""), + lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), + lastError = errorInfo.map(_.lastError).getOrElse(""), + lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) + ) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala new file mode 100644 index 000000000000..84a3ca9d74e5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.Logging + +/** + * Implements a proportional-integral-derivative (PID) controller which acts on + * the speed of ingestion of elements into Spark Streaming. A PID controller works + * by calculating an '''error''' between a measured output and a desired value. In the + * case of Spark Streaming the error is the difference between the measured processing + * rate (number of elements/processing delay) and the previous rate. + * + * @see https://en.wikipedia.org/wiki/PID_controller + * + * @param batchIntervalMillis the batch duration, in milliseconds + * @param proportional how much the correction should depend on the current + * error. This term usually provides the bulk of correction and should be positive or zero. + * A value too large would make the controller overshoot the setpoint, while a small value + * would make the controller too insensitive. The default value is 1. + * @param integral how much the correction should depend on the accumulation + * of past errors. This value should be positive or 0. This term accelerates the movement + * towards the desired value, but a large value may lead to overshooting. The default value + * is 0.2. + * @param derivative how much the correction should depend on a prediction + * of future errors, based on current rate of change. This value should be positive or 0. + * This term is not used very often, as it impacts stability of the system. The default + * value is 0. + * @param minRate what is the minimum rate that can be estimated. + * This must be greater than zero, so that the system always receives some data for rate + * estimation to work. + */ +private[streaming] class PIDRateEstimator( + batchIntervalMillis: Long, + proportional: Double, + integral: Double, + derivative: Double, + minRate: Double + ) extends RateEstimator with Logging { + + private var firstRun: Boolean = true + private var latestTime: Long = -1L + private var latestRate: Double = -1D + private var latestError: Double = -1L + + require( + batchIntervalMillis > 0, + s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.") + require( + proportional >= 0, + s"Proportional term $proportional in PIDRateEstimator should be >= 0.") + require( + integral >= 0, + s"Integral term $integral in PIDRateEstimator should be >= 0.") + require( + derivative >= 0, + s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + require( + minRate > 0, + s"Minimum rate in PIDRateEstimator should be > 0") + + logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " + + s"derivative = $derivative, min rate = $minRate") + + def compute( + time: Long, // in milliseconds + numElements: Long, + processingDelay: Long, // in milliseconds + schedulingDelay: Long // in milliseconds + ): Option[Double] = { + logTrace(s"\ntime = $time, # records = $numElements, " + + s"processing time = $processingDelay, scheduling delay = $schedulingDelay") + this.synchronized { + if (time > latestTime && numElements > 0 && processingDelay > 0) { + + // in seconds, should be close to batchDuration + val delaySinceUpdate = (time - latestTime).toDouble / 1000 + + // in elements/second + val processingRate = numElements.toDouble / processingDelay * 1000 + + // In our system `error` is the difference between the desired rate and the measured rate + // based on the latest batch information. We consider the desired rate to be latest rate, + // which is what this estimator calculated for the previous batch. + // in elements/second + val error = latestRate - processingRate + + // The error integral, based on schedulingDelay as an indicator for accumulated errors. + // A scheduling delay s corresponds to s * processingRate overflowing elements. Those + // are elements that couldn't be processed in previous batches, leading to this delay. + // In the following, we assume the processingRate didn't change too much. + // From the number of overflowing elements we can calculate the rate at which they would be + // processed by dividing it by the batch interval. This rate is our "historical" error, + // or integral part, since if we subtracted this rate from the previous "calculated rate", + // there wouldn't have been any overflowing elements, and the scheduling delay would have + // been zero. + // (in elements/second) + val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis + + // in elements/(second ^ 2) + val dError = (error - latestError) / delaySinceUpdate + + val newRate = (latestRate - proportional * error - + integral * historicalError - + derivative * dError).max(minRate) + logTrace(s""" + | latestRate = $latestRate, error = $error + | latestError = $latestError, historicalError = $historicalError + | delaySinceUpdate = $delaySinceUpdate, dError = $dError + """.stripMargin) + + latestTime = time + if (firstRun) { + latestRate = processingRate + latestError = 0D + firstRun = false + logTrace("First run, rate estimation skipped") + None + } else { + latestRate = newRate + latestError = error + logTrace(s"New rate = $newRate") + Some(newRate) + } + } else { + logTrace("Rate estimation skipped") + None + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 000000000000..d7210f64fcc3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.Duration + +/** + * A component that estimates the rate at wich an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + * + * @param time The timetamp of the current batch interval that just finished + * @param elements The number of elements that were processed in this batch + * @param processingDelay The time in ms that took for the job to complete + * @param schedulingDelay The time in ms that the job spent in the scheduling queue + */ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] +} + +object RateEstimator { + + /** + * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * + * The only known estimator right now is `pid`. + * + * @return An instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators. + */ + def create(conf: SparkConf, batchInterval: Duration): RateEstimator = + conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { + case "pid" => + val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) + val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) + val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) + val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) + + case estimator => + throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala new file mode 100644 index 000000000000..f702bd5bc946 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import java.text.SimpleDateFormat +import java.util.Date + +import scala.xml.Node + +import org.apache.spark.ui.{UIUtils => SparkUIUtils} + +private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) { + + protected def columns: Seq[Node] = { +
    Batch TimeInput SizeScheduling Delay + {SparkUIUtils.tooltip("Time taken by Streaming scheduler to submit jobs of a batch", "top")} + Processing Time + {SparkUIUtils.tooltip("Time taken to process all jobs of a batch", "top")} + + {formattedBatchTime} + + {eventCount.toString} events + {formattedSchedulingDelay} + + {formattedProcessingTime} +
    + + {columns} + + + {renderRows} + +
    + } + + def toNodeSeq: Seq[Node] = { + batchTable + } + + /** + * Return HTML for all rows of this table. + */ + protected def renderRows: Seq[Node] +} + +private[ui] class ActiveBatchTable( + runningBatches: Seq[BatchUIData], + waitingBatches: Seq[BatchUIData], + batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { + + override protected def columns: Seq[Node] = super.columns ++ Status + + override protected def renderRows: Seq[Node] = { + // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display + // waiting batches before running batches + waitingBatches.flatMap(batch => {waitingBatchRow(batch)}) ++ + runningBatches.flatMap(batch => {runningBatchRow(batch)}) + } + + private def runningBatchRow(batch: BatchUIData): Seq[Node] = { + baseRow(batch) ++ processing + } + + private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { + baseRow(batch) ++ queued + } +} + +private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) + extends BatchTableBase("completed-batches-table", batchInterval) { + + override protected def columns: Seq[Node] = super.columns ++ + Total Delay + {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} + + override protected def renderRows: Seq[Node] = { + batches.flatMap(batch => {completedBatchRow(batch)}) + } + + private def completedBatchRow(batch: BatchUIData): Seq[Node] = { + val totalDelay = batch.totalDelay + val formattedTotalDelay = totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + baseRow(batch) ++ + + {formattedTotalDelay} + + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala new file mode 100644 index 000000000000..90d1b0fadecf --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node, Text, Unparsed} + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.streaming.Time +import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.streaming.ui.StreamingJobProgressListener.{SparkJobId, OutputOpId} +import org.apache.spark.ui.jobs.UIData.JobUIData + +private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) + +private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { + private val streamingListener = parent.listener + private val sparkListener = parent.ssc.sc.jobProgressListener + + private def columns: Seq[Node] = { + Output Op Id + Description + Duration + Job Id + Duration + Stages: Succeeded/Total + Tasks (for all stages): Succeeded/Total + Error + } + + private def generateJobRow( + outputOpId: OutputOpId, + outputOpDescription: Seq[Node], + formattedOutputOpDuration: String, + numSparkJobRowsInOutputOp: Int, + isFirstRow: Boolean, + sparkJob: SparkJobIdWithUIData): Seq[Node] = { + if (sparkJob.jobUIData.isDefined) { + generateNormalJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, + numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) + } else { + generateDroppedJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, + numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) + } + } + + /** + * Generate a row for a Spark Job. Because duplicated output op infos needs to be collapsed into + * one cell, we use "rowspan" for the first row of a output op. + */ + private def generateNormalJobRow( + outputOpId: OutputOpId, + outputOpDescription: Seq[Node], + formattedOutputOpDuration: String, + numSparkJobRowsInOutputOp: Int, + isFirstRow: Boolean, + sparkJob: JobUIData): Seq[Node] = { + val duration: Option[Long] = { + sparkJob.submissionTime.map { start => + val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) + end - start + } + } + val lastFailureReason = + sparkJob.stageIds.sorted.reverse.flatMap(sparkListener.stageIdToInfo.get). + dropWhile(_.failureReason == None).take(1). // get the first info that contains failure + flatMap(info => info.failureReason).headOption.getOrElse("") + val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") + val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + + // In the first row, output op id and its information needs to be shown. In other rows, these + // cells will be taken up due to "rowspan". + // scalastyle:off + val prefixCells = + if (isFirstRow) { + {outputOpId.toString} + + {outputOpDescription} + + {formattedOutputOpDuration} + } else { + Nil + } + // scalastyle:on + + + {prefixCells} + + + {sparkJob.jobId}{sparkJob.jobGroup.map(id => s"($id)").getOrElse("")} + + + + {formattedDuration} + + + {sparkJob.completedStageIndices.size}/{sparkJob.stageIds.size - sparkJob.numSkippedStages} + {if (sparkJob.numFailedStages > 0) s"(${sparkJob.numFailedStages} failed)"} + {if (sparkJob.numSkippedStages > 0) s"(${sparkJob.numSkippedStages} skipped)"} + + + { + SparkUIUtils.makeProgressBar( + started = sparkJob.numActiveTasks, + completed = sparkJob.numCompletedTasks, + failed = sparkJob.numFailedTasks, + skipped = sparkJob.numSkippedTasks, + total = sparkJob.numTasks - sparkJob.numSkippedTasks) + } + + {failureReasonCell(lastFailureReason)} + + } + + /** + * If a job is dropped by sparkListener due to exceeding the limitation, we only show the job id + * with "-" cells. + */ + private def generateDroppedJobRow( + outputOpId: OutputOpId, + outputOpDescription: Seq[Node], + formattedOutputOpDuration: String, + numSparkJobRowsInOutputOp: Int, + isFirstRow: Boolean, + jobId: Int): Seq[Node] = { + // In the first row, output op id and its information needs to be shown. In other rows, these + // cells will be taken up due to "rowspan". + // scalastyle:off + val prefixCells = + if (isFirstRow) { + {outputOpId.toString} + {outputOpDescription} + {formattedOutputOpDuration} + } else { + Nil + } + // scalastyle:on + + + {prefixCells} + + {jobId.toString} + + + - + + - + + - + + - + + } + + private def generateOutputOpIdRow( + outputOpId: OutputOpId, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { + // We don't count the durations of dropped jobs + val sparkJobDurations = sparkJobs.filter(_.jobUIData.nonEmpty).map(_.jobUIData.get). + map(sparkJob => { + sparkJob.submissionTime.map { start => + val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) + end - start + } + }) + val formattedOutputOpDuration = + if (sparkJobDurations.isEmpty || sparkJobDurations.exists(_ == None)) { + // If no job or any job does not finish, set "formattedOutputOpDuration" to "-" + "-" + } else { + SparkUIUtils.formatDuration(sparkJobDurations.flatMap(x => x).sum) + } + + val description = generateOutputOpDescription(sparkJobs) + + generateJobRow( + outputOpId, description, formattedOutputOpDuration, sparkJobs.size, true, sparkJobs.head) ++ + sparkJobs.tail.map { sparkJob => + generateJobRow( + outputOpId, description, formattedOutputOpDuration, sparkJobs.size, false, sparkJob) + }.flatMap(x => x) + } + + private def generateOutputOpDescription(sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { + val lastStageInfo = + sparkJobs.flatMap(_.jobUIData).headOption. // Get the first JobUIData + flatMap { sparkJob => // For the first job, get the latest Stage info + if (sparkJob.stageIds.isEmpty) { + None + } else { + sparkListener.stageIdToInfo.get(sparkJob.stageIds.max) + } + } + val lastStageData = lastStageInfo.flatMap { s => + sparkListener.stageIdToData.get((s.stageId, s.attemptId)) + } + + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") + + + {lastStageDescription} + ++ Text(lastStageName) + } + + private def failureReasonCell(failureReason: String): Seq[Node] = { + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ +

    + // scalastyle:on + } else { + "" + } + {failureReasonSummary}{details} + } + + private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { + sparkListener.activeJobs.get(sparkJobId).orElse { + sparkListener.completedJobs.find(_.jobId == sparkJobId).orElse { + sparkListener.failedJobs.find(_.jobId == sparkJobId) + } + } + } + + /** + * Generate the job table for the batch. + */ + private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId).toSeq. + sortBy(_._1). // sorted by OutputOpId + map { case (outputOpId, outputOpIdAndSparkJobIds) => + // sort SparkJobIds for each OutputOpId + (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) + } + sparkListener.synchronized { + val outputOpIdWithJobs: Seq[(OutputOpId, Seq[SparkJobIdWithUIData])] = + outputOpIdToSparkJobIds.map { case (outputOpId, sparkJobIds) => + (outputOpId, + sparkJobIds.map(sparkJobId => SparkJobIdWithUIData(sparkJobId, getJobData(sparkJobId)))) + } + + + + {columns} + + + { + outputOpIdWithJobs.map { + case (outputOpId, sparkJobIds) => generateOutputOpIdRow(outputOpId, sparkJobIds) + } + } + +
    + } + } + + def render(request: HttpServletRequest): Seq[Node] = { + val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { + throw new IllegalArgumentException(s"Missing id parameter") + } + val formattedBatchTime = + UIUtils.formatBatchTime(batchTime.milliseconds, streamingListener.batchDuration) + + val batchUIData = streamingListener.getBatchUIData(batchTime).getOrElse { + throw new IllegalArgumentException(s"Batch $formattedBatchTime does not exist") + } + + val formattedSchedulingDelay = + batchUIData.schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val formattedProcessingTime = + batchUIData.processingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val formattedTotalDelay = batchUIData.totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + + val inputMetadatas = batchUIData.streamIdToInputInfo.values.flatMap { inputInfo => + inputInfo.metadataDescription.map(desc => inputInfo.inputStreamId -> desc) + }.toSeq + val summary: NodeSeq = +
    +
      +
    • + Batch Duration: + {SparkUIUtils.formatDuration(streamingListener.batchDuration)} +
    • +
    • + Input data size: + {batchUIData.numRecords} records +
    • +
    • + Scheduling delay: + {formattedSchedulingDelay} +
    • +
    • + Processing time: + {formattedProcessingTime} +
    • +
    • + Total delay: + {formattedTotalDelay} +
    • + { + if (inputMetadatas.nonEmpty) { +
    • + Input Metadata:{generateInputMetadataTable(inputMetadatas)} +
    • + } + } +
    +
    + + val jobTable = + if (batchUIData.outputOpIdSparkJobIdPairs.isEmpty) { +
    Cannot find any job for Batch {formattedBatchTime}.
    + } else { + generateJobTable(batchUIData) + } + + val content = summary ++ jobTable + + SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + } + + def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { + + + + + + + + + {inputMetadatas.flatMap(generateInputMetadataRow)} + +
    InputMetadata
    + } + + def generateInputMetadataRow(inputMetadata: (Int, String)): Seq[Node] = { + val streamId = inputMetadata._1 + + + {streamingListener.streamName(streamId).getOrElse(s"Stream-$streamId")} + {metadataDescriptionToHTML(inputMetadata._2)} + + } + + private def metadataDescriptionToHTML(metadataDescription: String): Seq[Node] = { + // tab to 4 spaces and "\n" to "
    " + Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription). + replaceAllLiterally("\t", "    ").replaceAllLiterally("\n", "
    ")) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala new file mode 100644 index 000000000000..ae508c0e9577 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming.ui + +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} +import org.apache.spark.streaming.ui.StreamingJobProgressListener._ + +private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) + +private[ui] case class BatchUIData( + val batchTime: Time, + val streamIdToInputInfo: Map[Int, StreamInputInfo], + val submissionTime: Long, + val processingStartTime: Option[Long], + val processingEndTime: Option[Long], + var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { + + /** + * Time taken for the first job of this batch to start processing from the time this batch + * was submitted to the streaming scheduler. Essentially, it is + * `processingStartTime` - `submissionTime`. + */ + def schedulingDelay: Option[Long] = processingStartTime.map(_ - submissionTime) + + /** + * Time taken for the all jobs of this batch to finish processing from the time they started + * processing. Essentially, it is `processingEndTime` - `processingStartTime`. + */ + def processingDelay: Option[Long] = { + for (start <- processingStartTime; + end <- processingEndTime) + yield end - start + } + + /** + * Time taken for all the jobs of this batch to finish processing from the time they + * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + */ + def totalDelay: Option[Long] = processingEndTime.map(_ - submissionTime) + + /** + * The number of recorders received by the receivers in this batch. + */ + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum +} + +private[ui] object BatchUIData { + + def apply(batchInfo: BatchInfo): BatchUIData = { + new BatchUIData( + batchInfo.batchTime, + batchInfo.streamIdToInputInfo, + batchInfo.submissionTime, + batchInfo.processingStartTime, + batchInfo.processingEndTime + ) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 5ee53a5c5f56..78aeb004e18b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,28 +17,57 @@ package org.apache.spark.streaming.ui +import java.util.LinkedHashMap +import java.util.{Map => JMap} +import java.util.Properties + +import scala.collection.mutable.{ArrayBuffer, Queue, HashMap, SynchronizedBuffer} + +import org.apache.spark.scheduler._ import org.apache.spark.streaming.{Time, StreamingContext} import org.apache.spark.streaming.scheduler._ -import scala.collection.mutable.{Queue, HashMap} import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted -import org.apache.spark.streaming.scheduler.BatchInfo import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted -import org.apache.spark.util.Distribution private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) - extends StreamingListener { + extends StreamingListener with SparkListener { - private val waitingBatchInfos = new HashMap[Time, BatchInfo] - private val runningBatchInfos = new HashMap[Time, BatchInfo] - private val completedaBatchInfos = new Queue[BatchInfo] - private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + private val waitingBatchUIData = new HashMap[Time, BatchUIData] + private val runningBatchUIData = new HashMap[Time, BatchUIData] + private val completedBatchUIData = new Queue[BatchUIData] + private val batchUIDataLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) private var totalCompletedBatches = 0L private var totalReceivedRecords = 0L private var totalProcessedRecords = 0L private val receiverInfos = new HashMap[Int, ReceiverInfo] + // Because onJobStart and onBatchXXX messages are processed in different threads, + // we may not be able to get the corresponding BatchUIData when receiving onJobStart. So here we + // cannot use a map of (Time, BatchUIData). + private[ui] val batchTimeToOutputOpIdSparkJobIdPair = + new LinkedHashMap[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]] { + override def removeEldestEntry( + p1: JMap.Entry[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]]): Boolean = { + // If a lot of "onBatchCompleted"s happen before "onJobStart" (image if + // SparkContext.listenerBus is very slow), "batchTimeToOutputOpIdToSparkJobIds" + // may add some information for a removed batch when processing "onJobStart". It will be a + // memory leak. + // + // To avoid the memory leak, we control the size of "batchTimeToOutputOpIdToSparkJobIds" and + // evict the eldest one. + // + // Note: if "onJobStart" happens before "onBatchSubmitted", the size of + // "batchTimeToOutputOpIdToSparkJobIds" may be greater than the number of the retained + // batches temporarily, so here we use "10" to handle such case. This is not a perfect + // solution, but at least it can handle most of cases. + size() > + waitingBatchUIData.size + runningBatchUIData.size + completedBatchUIData.size + 10 + } + } + + val batchDuration = ssc.graph.batchDuration.milliseconds override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { @@ -59,33 +88,72 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + synchronized { + waitingBatchUIData(batchSubmitted.batchInfo.batchTime) = + BatchUIData(batchSubmitted.batchInfo) + } + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { + val batchUIData = BatchUIData(batchStarted.batchInfo) + runningBatchUIData(batchStarted.batchInfo.batchTime) = BatchUIData(batchStarted.batchInfo) + waitingBatchUIData.remove(batchStarted.batchInfo.batchTime) + + totalReceivedRecords += batchUIData.numRecords } - override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { - runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo - waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + synchronized { + waitingBatchUIData.remove(batchCompleted.batchInfo.batchTime) + runningBatchUIData.remove(batchCompleted.batchInfo.batchTime) + val batchUIData = BatchUIData(batchCompleted.batchInfo) + completedBatchUIData.enqueue(batchUIData) + if (completedBatchUIData.size > batchUIDataLimit) { + val removedBatch = completedBatchUIData.dequeue() + batchTimeToOutputOpIdSparkJobIdPair.remove(removedBatch.batchTime) + } + totalCompletedBatches += 1L - batchStarted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalReceivedRecords += infos.map(_.numRecords).sum + totalProcessedRecords += batchUIData.numRecords } } - override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { - waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) - runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() - totalCompletedBatches += 1L + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { + getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => + var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) + if (outputOpIdToSparkJobIds == null) { + outputOpIdToSparkJobIds = + new ArrayBuffer[OutputOpIdAndSparkJobId]() + with SynchronizedBuffer[OutputOpIdAndSparkJobId] + batchTimeToOutputOpIdSparkJobIdPair.put(batchTime, outputOpIdToSparkJobIds) + } + outputOpIdToSparkJobIds += OutputOpIdAndSparkJobId(outputOpId, jobStart.jobId) + } + } - batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalProcessedRecords += infos.map(_.numRecords).sum + private def getBatchTimeAndOutputOpId(properties: Properties): Option[(Time, Int)] = { + val batchTime = properties.getProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY) + if (batchTime == null) { + // Not submitted from JobScheduler + None + } else { + val outputOpId = properties.getProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY) + assert(outputOpId != null) + Some(Time(batchTime.toLong) -> outputOpId.toInt) } } - def numReceivers = synchronized { - ssc.graph.getReceiverInputStreams().size + def numReceivers: Int = synchronized { + receiverInfos.size + } + + def numActiveReceivers: Int = synchronized { + receiverInfos.count(_._2.active) + } + + def numInactiveReceivers: Int = { + ssc.graph.getReceiverInputStreams().size - numActiveReceivers } def numTotalCompletedBatches: Long = synchronized { @@ -101,78 +169,94 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def numUnprocessedBatches: Long = synchronized { - waitingBatchInfos.size + runningBatchInfos.size - } - - def waitingBatches: Seq[BatchInfo] = synchronized { - waitingBatchInfos.values.toSeq - } - - def runningBatches: Seq[BatchInfo] = synchronized { - runningBatchInfos.values.toSeq + waitingBatchUIData.size + runningBatchUIData.size } - def retainedCompletedBatches: Seq[BatchInfo] = synchronized { - completedaBatchInfos.toSeq + def waitingBatches: Seq[BatchUIData] = synchronized { + waitingBatchUIData.values.toSeq } - def processingDelayDistribution: Option[Distribution] = synchronized { - extractDistribution(_.processingDelay) + def runningBatches: Seq[BatchUIData] = synchronized { + runningBatchUIData.values.toSeq } - def schedulingDelayDistribution: Option[Distribution] = synchronized { - extractDistribution(_.schedulingDelay) + def retainedCompletedBatches: Seq[BatchUIData] = synchronized { + completedBatchUIData.toSeq } - def totalDelayDistribution: Option[Distribution] = synchronized { - extractDistribution(_.totalDelay) + def streamName(streamId: Int): Option[String] = { + ssc.graph.getInputStreamName(streamId) } - def receivedRecordsDistributions: Map[Int, Option[Distribution]] = synchronized { - val latestBatchInfos = retainedBatches.reverse.take(batchInfoLimit) - val latestBlockInfos = latestBatchInfos.map(_.receivedBlockInfo) - (0 until numReceivers).map { receiverId => - val blockInfoOfParticularReceiver = latestBlockInfos.map { batchInfo => - batchInfo.get(receiverId).getOrElse(Array.empty) - } - val recordsOfParticularReceiver = blockInfoOfParticularReceiver.map { blockInfo => - // calculate records per second for each batch - blockInfo.map(_.numRecords).sum.toDouble * 1000 / batchDuration + /** + * Return all InputDStream Ids + */ + def streamIds: Seq[Int] = ssc.graph.getInputStreams().map(_.id) + + /** + * Return all of the event rates for each InputDStream in each batch. The key of the return value + * is the stream id, and the value is a sequence of batch time with its event rate. + */ + def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { + val _retainedBatches = retainedBatches + val latestBatches = _retainedBatches.map { batchUIData => + (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) + } + streamIds.map { streamId => + val eventRates = latestBatches.map { + case (batchTime, streamIdToNumRecords) => + val numRecords = streamIdToNumRecords.getOrElse(streamId, 0L) + (batchTime, numRecords * 1000.0 / batchDuration) } - val distributionOption = Distribution(recordsOfParticularReceiver) - (receiverId, distributionOption) + (streamId, eventRates) }.toMap } - def lastReceivedBatchRecords: Map[Int, Long] = { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) + def lastReceivedBatchRecords: Map[Int, Long] = synchronized { + val lastReceivedBlockInfoOption = + lastReceivedBatch.map(_.streamIdToInputInfo.mapValues(_.numRecords)) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => - (0 until numReceivers).map { receiverId => - (receiverId, lastReceivedBlockInfo(receiverId).map(_.numRecords).sum) + streamIds.map { streamId => + (streamId, lastReceivedBlockInfo.getOrElse(streamId, 0L)) }.toMap }.getOrElse { - (0 until numReceivers).map(receiverId => (receiverId, 0L)).toMap + streamIds.map(streamId => (streamId, 0L)).toMap } } - def receiverInfo(receiverId: Int): Option[ReceiverInfo] = { + def receiverInfo(receiverId: Int): Option[ReceiverInfo] = synchronized { receiverInfos.get(receiverId) } - def lastCompletedBatch: Option[BatchInfo] = { - completedaBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + def lastCompletedBatch: Option[BatchUIData] = synchronized { + completedBatchUIData.sortBy(_.batchTime)(Time.ordering).lastOption } - def lastReceivedBatch: Option[BatchInfo] = { + def lastReceivedBatch: Option[BatchUIData] = synchronized { retainedBatches.lastOption } - private def retainedBatches: Seq[BatchInfo] = synchronized { - (waitingBatchInfos.values.toSeq ++ - runningBatchInfos.values.toSeq ++ completedaBatchInfos).sortBy(_.batchTime)(Time.ordering) + def retainedBatches: Seq[BatchUIData] = synchronized { + (waitingBatchUIData.values.toSeq ++ + runningBatchUIData.values.toSeq ++ completedBatchUIData).sortBy(_.batchTime)(Time.ordering) } - private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { - Distribution(completedaBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + def getBatchUIData(batchTime: Time): Option[BatchUIData] = synchronized { + val batchUIData = waitingBatchUIData.get(batchTime).orElse { + runningBatchUIData.get(batchTime).orElse { + completedBatchUIData.find(batch => batch.batchTime == batchTime) + } + } + batchUIData.foreach { _batchUIData => + val outputOpIdToSparkJobIds = + Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime)).getOrElse(Seq.empty) + _batchUIData.outputOpIdSparkJobIdPairs = outputOpIdToSparkJobIds + } + batchUIData } } + +private[streaming] object StreamingJobProgressListener { + type SparkJobId = Int + type OutputOpId = Int +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 98e9a2e639e2..96d943e75d27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -17,177 +17,544 @@ package org.apache.spark.streaming.ui -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.collection.mutable.ArrayBuffer +import scala.xml.{Node, Unparsed} import org.apache.spark.Logging import org.apache.spark.ui._ -import org.apache.spark.ui.UIUtils._ -import org.apache.spark.util.Distribution +import org.apache.spark.ui.{UIUtils => SparkUIUtils} + +/** + * A helper class to generate JavaScript and HTML for both timeline and histogram graphs. + * + * @param timelineDivId the timeline `id` used in the html `div` tag + * @param histogramDivId the timeline `id` used in the html `div` tag + * @param data the data for the graph + * @param minX the min value of X axis + * @param maxX the max value of X axis + * @param minY the min value of Y axis + * @param maxY the max value of Y axis + * @param unitY the unit of Y axis + * @param batchInterval if `batchInterval` is not None, we will draw a line for `batchInterval` in + * the graph + */ +private[ui] class GraphUIData( + timelineDivId: String, + histogramDivId: String, + data: Seq[(Long, Double)], + minX: Long, + maxX: Long, + minY: Double, + maxY: Double, + unitY: String, + batchInterval: Option[Double] = None) { + + private var dataJavaScriptName: String = _ + + def generateDataJs(jsCollector: JsCollector): Unit = { + val jsForData = data.map { case (x, y) => + s"""{"x": $x, "y": $y}""" + }.mkString("[", ",", "]") + dataJavaScriptName = jsCollector.nextVariableName + jsCollector.addPreparedStatement(s"var $dataJavaScriptName = $jsForData;") + } + + def generateTimelineHtml(jsCollector: JsCollector): Seq[Node] = { + jsCollector.addPreparedStatement(s"registerTimeline($minY, $maxY);") + if (batchInterval.isDefined) { + jsCollector.addStatement( + "drawTimeline(" + + s"'#$timelineDivId', $dataJavaScriptName, $minX, $maxX, $minY, $maxY, '$unitY'," + + s" ${batchInterval.get}" + + ");") + } else { + jsCollector.addStatement( + s"drawTimeline('#$timelineDivId', $dataJavaScriptName, $minX, $maxX, $minY, $maxY," + + s" '$unitY');") + } +
    + } + + def generateHistogramHtml(jsCollector: JsCollector): Seq[Node] = { + val histogramData = s"$dataJavaScriptName.map(function(d) { return d.y; })" + jsCollector.addPreparedStatement(s"registerHistogram($histogramData, $minY, $maxY);") + if (batchInterval.isDefined) { + jsCollector.addStatement( + "drawHistogram(" + + s"'#$histogramDivId', $histogramData, $minY, $maxY, '$unitY', ${batchInterval.get}" + + ");") + } else { + jsCollector.addStatement( + s"drawHistogram('#$histogramDivId', $histogramData, $minY, $maxY, '$unitY');") + } +
    + } +} + +/** + * A helper class for "scheduling delay", "processing time" and "total delay" to generate data that + * will be used in the timeline and histogram graphs. + * + * @param data (batchTime, milliseconds). "milliseconds" is something like "processing time". + */ +private[ui] class MillisecondsStatUIData(data: Seq[(Long, Long)]) { + + /** + * Converting the original data as per `unit`. + */ + def timelineData(unit: TimeUnit): Seq[(Long, Double)] = + data.map(x => x._1 -> UIUtils.convertToTimeUnit(x._2, unit)) + + /** + * Converting the original data as per `unit`. + */ + def histogramData(unit: TimeUnit): Seq[Double] = + data.map(x => UIUtils.convertToTimeUnit(x._2, unit)) + + val avg: Option[Long] = if (data.isEmpty) None else Some(data.map(_._2).sum / data.size) + + val formattedAvg: String = StreamingPage.formatDurationOption(avg) + + val max: Option[Long] = if (data.isEmpty) None else Some(data.map(_._2).max) +} + +/** + * A helper class for "input rate" to generate data that will be used in the timeline and histogram + * graphs. + * + * @param data (batchTime, event-rate). + */ +private[ui] class EventRateUIData(val data: Seq[(Long, Double)]) { + + val avg: Option[Double] = if (data.isEmpty) None else Some(data.map(_._2).sum / data.size) + + val formattedAvg: String = avg.map(_.formatted("%.2f")).getOrElse("-") + + val max: Option[Double] = if (data.isEmpty) None else Some(data.map(_._2).max) +} /** Page for Spark Web UI that shows statistics of a streaming job */ private[ui] class StreamingPage(parent: StreamingTab) extends WebUIPage("") with Logging { + import StreamingPage._ + private val listener = parent.listener - private val startTime = Calendar.getInstance().getTime() - private val emptyCell = "-" + private val startTime = System.currentTimeMillis() /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val content = - generateBasicStats() ++

    ++ -

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ - generateReceiverStats() ++ - generateBatchStatsTable() - UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) - } - - /** Generate basic stats of the streaming program */ - private def generateBasicStats(): Seq[Node] = { - val timeSinceStart = System.currentTimeMillis() - startTime.getTime -
      -
    • - Started at: {startTime.toString} -
    • -
    • - Time since start: {formatDurationVerbose(timeSinceStart)} -
    • -
    • - Network receivers: {listener.numReceivers} -
    • -
    • - Batch interval: {formatDurationVerbose(listener.batchDuration)} -
    • -
    • - Processed batches: {listener.numTotalCompletedBatches} -
    • -
    • - Waiting batches: {listener.numUnprocessedBatches} -
    • -
    • - Received records: {listener.numTotalReceivedRecords} -
    • -
    • - Processed records: {listener.numTotalProcessedRecords} -
    • -
    - } - - /** Generate stats of data received by the receivers in the streaming program */ - private def generateReceiverStats(): Seq[Node] = { - val receivedRecordDistributions = listener.receivedRecordsDistributions - val lastBatchReceivedRecord = listener.lastReceivedBatchRecords - val table = if (receivedRecordDistributions.size > 0) { - val headerRow = Seq( - "Receiver", - "Status", - "Location", - "Records in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", - "Minimum rate\n[records/sec]", - "Median rate\n[records/sec]", - "Maximum rate\n[records/sec]", - "Last Error" - ) - val dataRows = (0 until listener.numReceivers).map { receiverId => - val receiverInfo = listener.receiverInfo(receiverId) - val receiverName = receiverInfo.map(_.name).getOrElse(s"Receiver-$receiverId") - val receiverActive = receiverInfo.map { info => - if (info.active) "ACTIVE" else "INACTIVE" - }.getOrElse(emptyCell) - val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) - val receiverLastBatchRecords = formatNumber(lastBatchReceivedRecord(receiverId)) - val receivedRecordStats = receivedRecordDistributions(receiverId).map { d => - d.getQuantiles(Seq(0.0, 0.5, 1.0)).map(r => formatNumber(r.toLong)) - }.getOrElse { - Seq(emptyCell, emptyCell, emptyCell, emptyCell, emptyCell) - } - val receiverLastError = listener.receiverInfo(receiverId).map { info => - val msg = s"${info.lastErrorMessage} - ${info.lastError}" - if (msg.size > 100) msg.take(97) + "..." else msg - }.getOrElse(emptyCell) - Seq(receiverName, receiverActive, receiverLocation, receiverLastBatchRecords) ++ - receivedRecordStats ++ Seq(receiverLastError) + val resources = generateLoadResources() + val basicInfo = generateBasicInfo() + val content = resources ++ + basicInfo ++ + listener.synchronized { + generateStatTable() ++ + generateBatchListTables() } - Some(listingTable(headerRow, dataRows)) - } else { - None - } + SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + } - val content = -
    Receiver Statistics
    ++ -
    {table.getOrElse("No receivers")}
    + /** + * Generate html that will load css/js files for StreamingPage + */ + private def generateLoadResources(): Seq[Node] = { + // scalastyle:off + + + + // scalastyle:on + } - content + /** Generate basic information of the streaming program */ + private def generateBasicInfo(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime +
    Running batches of + + {SparkUIUtils.formatDurationVerbose(listener.batchDuration)} + + for + + {SparkUIUtils.formatDurationVerbose(timeSinceStart)} + + since + + {SparkUIUtils.formatDate(startTime)} + + ({listener.numTotalCompletedBatches} + completed batches, {listener.numTotalReceivedRecords} records) +
    +
    } - /** Generate stats of batch jobs of the streaming program */ - private def generateBatchStatsTable(): Seq[Node] = { - val numBatches = listener.retainedCompletedBatches.size - val lastCompletedBatch = listener.lastCompletedBatch - val table = if (numBatches > 0) { - val processingDelayQuantilesRow = { - Seq( - "Processing Time", - formatDurationOption(lastCompletedBatch.flatMap(_.processingDelay)) - ) ++ getQuantiles(listener.processingDelayDistribution) - } - val schedulingDelayQuantilesRow = { - Seq( - "Scheduling Delay", - formatDurationOption(lastCompletedBatch.flatMap(_.schedulingDelay)) - ) ++ getQuantiles(listener.schedulingDelayDistribution) - } - val totalDelayQuantilesRow = { - Seq( - "Total Delay", - formatDurationOption(lastCompletedBatch.flatMap(_.totalDelay)) - ) ++ getQuantiles(listener.totalDelayDistribution) - } - val headerRow = Seq("Metric", "Last batch", "Minimum", "25th percentile", - "Median", "75th percentile", "Maximum") - val dataRows: Seq[Seq[String]] = Seq( - processingDelayQuantilesRow, - schedulingDelayQuantilesRow, - totalDelayQuantilesRow - ) - Some(listingTable(headerRow, dataRows)) - } else { - None + /** + * Generate a global "timeFormat" dictionary in the JavaScript to store the time and its formatted + * string. Because we cannot specify a timezone in JavaScript, to make sure the server and client + * use the same timezone, we use the "timeFormat" dictionary to format all time values used in the + * graphs. + * + * @param times all time values that will be used in the graphs. + */ + private def generateTimeMap(times: Seq[Long]): Seq[Node] = { + val js = "var timeFormat = {};\n" + times.map { time => + val formattedTime = + UIUtils.formatBatchTime(time, listener.batchDuration, showYYYYMMSS = false) + s"timeFormat[$time] = '$formattedTime';" + }.mkString("\n") + + + } + + private def generateStatTable(): Seq[Node] = { + val batches = listener.retainedBatches + + val batchTimes = batches.map(_.batchTime.milliseconds) + val minBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.min + val maxBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.max + + val eventRateForAllStreams = new EventRateUIData(batches.map { batchInfo => + (batchInfo.batchTime.milliseconds, batchInfo.numRecords * 1000.0 / listener.batchDuration) + }) + + val schedulingDelay = new MillisecondsStatUIData(batches.flatMap { batchInfo => + batchInfo.schedulingDelay.map(batchInfo.batchTime.milliseconds -> _) + }) + val processingTime = new MillisecondsStatUIData(batches.flatMap { batchInfo => + batchInfo.processingDelay.map(batchInfo.batchTime.milliseconds -> _) + }) + val totalDelay = new MillisecondsStatUIData(batches.flatMap { batchInfo => + batchInfo.totalDelay.map(batchInfo.batchTime.milliseconds -> _) + }) + + // Use the max value of "schedulingDelay", "processingTime", and "totalDelay" to make the + // Y axis ranges same. + val _maxTime = + (for (m1 <- schedulingDelay.max; m2 <- processingTime.max; m3 <- totalDelay.max) yield + m1 max m2 max m3).getOrElse(0L) + // Should start at 0 + val minTime = 0L + val (maxTime, normalizedUnit) = UIUtils.normalizeDuration(_maxTime) + val formattedUnit = UIUtils.shortTimeUnitString(normalizedUnit) + + // Use the max input rate for all InputDStreams' graphs to make the Y axis ranges same. + // If it's not an integral number, just use its ceil integral number. + val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) + val minEventRate = 0L + + val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit) + + val jsCollector = new JsCollector + + val graphUIDataForEventRateOfAllStreams = + new GraphUIData( + "all-stream-events-timeline", + "all-stream-events-histogram", + eventRateForAllStreams.data, + minBatchTime, + maxBatchTime, + minEventRate, + maxEventRate, + "events/sec") + graphUIDataForEventRateOfAllStreams.generateDataJs(jsCollector) + + val graphUIDataForSchedulingDelay = + new GraphUIData( + "scheduling-delay-timeline", + "scheduling-delay-histogram", + schedulingDelay.timelineData(normalizedUnit), + minBatchTime, + maxBatchTime, + minTime, + maxTime, + formattedUnit) + graphUIDataForSchedulingDelay.generateDataJs(jsCollector) + + val graphUIDataForProcessingTime = + new GraphUIData( + "processing-time-timeline", + "processing-time-histogram", + processingTime.timelineData(normalizedUnit), + minBatchTime, + maxBatchTime, + minTime, + maxTime, + formattedUnit, Some(batchInterval)) + graphUIDataForProcessingTime.generateDataJs(jsCollector) + + val graphUIDataForTotalDelay = + new GraphUIData( + "total-delay-timeline", + "total-delay-histogram", + totalDelay.timelineData(normalizedUnit), + minBatchTime, + maxBatchTime, + minTime, + maxTime, + formattedUnit) + graphUIDataForTotalDelay.generateDataJs(jsCollector) + + // It's false before the user registers the first InputDStream + val hasStream = listener.streamIds.nonEmpty + + val numCompletedBatches = listener.retainedCompletedBatches.size + val numActiveBatches = batchTimes.length - numCompletedBatches + val numReceivers = listener.numInactiveReceivers + listener.numActiveReceivers + val table = + // scalastyle:off + + + + + + + + + + + + + + {if (hasStream) { + + + + }} + + + + + + + + + + + + + + + + +
    Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed)Histograms
    +
    +
    + { + if (hasStream) { + + + + Input Rate + + + } else { + Input Rate + } + } +
    + { + if (numReceivers > 0) { +
    Receivers: {listener.numActiveReceivers} / {numReceivers} active
    + } + } +
    Avg: {eventRateForAllStreams.formattedAvg} events/sec
    +
    +
    {graphUIDataForEventRateOfAllStreams.generateTimelineHtml(jsCollector)}{graphUIDataForEventRateOfAllStreams.generateHistogramHtml(jsCollector)}
    +
    +
    Scheduling Delay {SparkUIUtils.tooltip("Time taken by Streaming scheduler to submit jobs of a batch", "right")}
    +
    Avg: {schedulingDelay.formattedAvg}
    +
    +
    {graphUIDataForSchedulingDelay.generateTimelineHtml(jsCollector)}{graphUIDataForSchedulingDelay.generateHistogramHtml(jsCollector)}
    +
    +
    Processing Time {SparkUIUtils.tooltip("Time taken to process all jobs of a batch", "right")}
    +
    Avg: {processingTime.formattedAvg}
    +
    +
    {graphUIDataForProcessingTime.generateTimelineHtml(jsCollector)}{graphUIDataForProcessingTime.generateHistogramHtml(jsCollector)}
    +
    +
    Total Delay {SparkUIUtils.tooltip("Total time taken to handle a batch", "right")}
    +
    Avg: {totalDelay.formattedAvg}
    +
    +
    {graphUIDataForTotalDelay.generateTimelineHtml(jsCollector)}{graphUIDataForTotalDelay.generateHistogramHtml(jsCollector)}
    + // scalastyle:on + + generateTimeMap(batchTimes) ++ table ++ jsCollector.toHtml + } + + private def generateInputDStreamsTable( + jsCollector: JsCollector, + minX: Long, + maxX: Long, + minY: Double, + maxY: Double): Seq[Node] = { + val content = listener.receivedEventRateWithBatchTime.map { case (streamId, eventRates) => + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + }.foldLeft[Seq[Node]](Nil)(_ ++ _) + + // scalastyle:off + + + + + + + + + + + + {content} + +
    Status
    Location
    Last Error Time
    Last Error Message
    + // scalastyle:on + } + + private def generateInputDStreamRow( + jsCollector: JsCollector, + streamId: Int, + eventRates: Seq[(Long, Double)], + minX: Long, + maxX: Long, + minY: Double, + maxY: Double): Seq[Node] = { + // If this is a ReceiverInputDStream, we need to show the receiver info. Or we only need the + // InputDStream name. + val receiverInfo = listener.receiverInfo(streamId) + val receiverName = receiverInfo.map(_.name). + orElse(listener.streamName(streamId)).getOrElse(s"Stream-$streamId") + val receiverActive = receiverInfo.map { info => + if (info.active) "ACTIVE" else "INACTIVE" + }.getOrElse(emptyCell) + val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLastError = receiverInfo.map { info => + val msg = s"${info.lastErrorMessage} - ${info.lastError}" + if (msg.size > 100) msg.take(97) + "..." else msg + }.getOrElse(emptyCell) + val receiverLastErrorTime = receiverInfo.map { + r => if (r.lastErrorTime < 0) "-" else SparkUIUtils.formatDate(r.lastErrorTime) + }.getOrElse(emptyCell) + val receivedRecords = new EventRateUIData(eventRates) + + val graphUIDataForEventRate = + new GraphUIData( + s"stream-$streamId-events-timeline", + s"stream-$streamId-events-histogram", + receivedRecords.data, + minX, + maxX, + minY, + maxY, + "events/sec") + graphUIDataForEventRate.generateDataJs(jsCollector) + + + +
    +
    {receiverName}
    +
    Avg: {receivedRecords.formattedAvg} events/sec
    +
    + + {receiverActive} + {receiverLocation} + {receiverLastErrorTime} +
    {receiverLastError}
    + + + + {graphUIDataForEventRate.generateTimelineHtml(jsCollector)} + + {graphUIDataForEventRate.generateHistogramHtml(jsCollector)} + + } + + private def generateBatchListTables(): Seq[Node] = { + val runningBatches = listener.runningBatches.sortBy(_.batchTime.milliseconds).reverse + val waitingBatches = listener.waitingBatches.sortBy(_.batchTime.milliseconds).reverse + val completedBatches = listener.retainedCompletedBatches. + sortBy(_.batchTime.milliseconds).reverse + + val activeBatchesContent = { +

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ + new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq } - val content = -
    Batch Processing Statistics
    ++ -
    -
      - {table.getOrElse("No statistics have been generated yet.")} -
    -
    + val completedBatchesContent = { +

    + Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) +

    ++ + new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq + } - content + activeBatchesContent ++ completedBatchesContent } +} + +private[ui] object StreamingPage { + val BLACK_RIGHT_TRIANGLE_HTML = "▶" + val BLACK_DOWN_TRIANGLE_HTML = "▼" + val emptyCell = "-" /** * Returns a human-readable string representing a duration such as "5 second 35 ms" */ - private def formatDurationOption(msOption: Option[Long]): String = { - msOption.map(formatDurationVerbose).getOrElse(emptyCell) + def formatDurationOption(msOption: Option[Long]): String = { + msOption.map(SparkUIUtils.formatDurationVerbose).getOrElse(emptyCell) } - /** Get quantiles for any time distribution */ - private def getQuantiles(timeDistributionOption: Option[Distribution]) = { - timeDistributionOption.get.getQuantiles().map { ms => formatDurationVerbose(ms.toLong) } +} + +/** + * A helper class that allows the user to add JavaScript statements which will be executed when the + * DOM has finished loading. + */ +private[ui] class JsCollector { + + private var variableId = 0 + + /** + * Return the next unused JavaScript variable name + */ + def nextVariableName: String = { + variableId += 1 + "v" + variableId } - /** Generate HTML table from string data */ - private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { - def generateDataRow(data: Seq[String]): Seq[Node] = { - {data.map(d => {d})} - } - UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) + /** + * JavaScript statements that will execute before `statements` + */ + private val preparedStatements = ArrayBuffer[String]() + + /** + * JavaScript statements that will execute after `preparedStatements` + */ + private val statements = ArrayBuffer[String]() + + def addPreparedStatement(js: String): Unit = { + preparedStatements += js + } + + def addStatement(js: String): Unit = { + statements += js + } + + /** + * Generate a html snippet that will execute all scripts when the DOM has finished loading. + */ + def toHtml: Seq[Node] = { + val js = + s""" + |$$(document).ready(function(){ + | ${preparedStatements.mkString("\n")} + | ${statements.mkString("\n")} + |});""".stripMargin + + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index d9d04cd706a0..bc53f2a31f6d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -27,15 +27,28 @@ import StreamingTab._ * Spark Web UI tab that shows statistics of a streaming job. * This assumes the given SparkContext has enabled its SparkUI. */ -private[spark] class StreamingTab(ssc: StreamingContext) +private[spark] class StreamingTab(val ssc: StreamingContext) extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" + val parent = getSparkUI(ssc) val listener = ssc.progressListener ssc.addStreamingListener(listener) + ssc.sc.addSparkListener(listener) attachPage(new StreamingPage(this)) - parent.attachTab(this) + attachPage(new BatchPage(this)) + + def attach() { + getSparkUI(ssc).attachTab(this) + getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") + } + + def detach() { + getSparkUI(ssc).detachTab(this) + getSparkUI(ssc).removeStaticHandler("/static/streaming") + } } private object StreamingTab { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala new file mode 100644 index 000000000000..86cfb1fa4737 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import java.text.SimpleDateFormat +import java.util.TimeZone +import java.util.concurrent.TimeUnit + +private[streaming] object UIUtils { + + /** + * Return the short string for a `TimeUnit`. + */ + def shortTimeUnitString(unit: TimeUnit): String = unit match { + case TimeUnit.NANOSECONDS => "ns" + case TimeUnit.MICROSECONDS => "us" + case TimeUnit.MILLISECONDS => "ms" + case TimeUnit.SECONDS => "sec" + case TimeUnit.MINUTES => "min" + case TimeUnit.HOURS => "hrs" + case TimeUnit.DAYS => "days" + } + + /** + * Find the best `TimeUnit` for converting milliseconds to a friendly string. Return the value + * after converting, also with its TimeUnit. + */ + def normalizeDuration(milliseconds: Long): (Double, TimeUnit) = { + if (milliseconds < 1000) { + return (milliseconds, TimeUnit.MILLISECONDS) + } + val seconds = milliseconds.toDouble / 1000 + if (seconds < 60) { + return (seconds, TimeUnit.SECONDS) + } + val minutes = seconds / 60 + if (minutes < 60) { + return (minutes, TimeUnit.MINUTES) + } + val hours = minutes / 60 + if (hours < 24) { + return (hours, TimeUnit.HOURS) + } + val days = hours / 24 + (days, TimeUnit.DAYS) + } + + /** + * Convert `milliseconds` to the specified `unit`. We cannot use `TimeUnit.convert` because it + * will discard the fractional part. + */ + def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { + case TimeUnit.NANOSECONDS => milliseconds * 1000 * 1000 + case TimeUnit.MICROSECONDS => milliseconds * 1000 + case TimeUnit.MILLISECONDS => milliseconds + case TimeUnit.SECONDS => milliseconds / 1000.0 + case TimeUnit.MINUTES => milliseconds / 1000.0 / 60.0 + case TimeUnit.HOURS => milliseconds / 1000.0 / 60.0 / 60.0 + case TimeUnit.DAYS => milliseconds / 1000.0 / 60.0 / 60.0 / 24.0 + } + + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. + private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + } + + private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + } + + /** + * If `batchInterval` is less than 1 second, format `batchTime` with milliseconds. Otherwise, + * format `batchTime` without milliseconds. + * + * @param batchTime the batch time to be formatted + * @param batchInterval the batch interval + * @param showYYYYMMSS if showing the `yyyy/MM/dd` part. If it's false, the return value wll be + * only `HH:mm:ss` or `HH:mm:ss.SSS` depending on `batchInterval` + * @param timezone only for test + */ + def formatBatchTime( + batchTime: Long, + batchInterval: Long, + showYYYYMMSS: Boolean = true, + timezone: TimeZone = null): String = { + val oldTimezones = + (batchTimeFormat.get.getTimeZone, batchTimeFormatWithMilliseconds.get.getTimeZone) + if (timezone != null) { + batchTimeFormat.get.setTimeZone(timezone) + batchTimeFormatWithMilliseconds.get.setTimeZone(timezone) + } + try { + val formattedBatchTime = + if (batchInterval < 1000) { + batchTimeFormatWithMilliseconds.get.format(batchTime) + } else { + // If batchInterval >= 1 second, don't show milliseconds + batchTimeFormat.get.format(batchTime) + } + if (showYYYYMMSS) { + formattedBatchTime + } else { + formattedBatchTime.substring(formattedBatchTime.indexOf(' ') + 1) + } + } finally { + if (timezone != null) { + batchTimeFormat.get.setTimeZone(oldTimezones._1) + batchTimeFormatWithMilliseconds.get.setTimeZone(oldTimezones._2) + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala deleted file mode 100644 index d6d96d7ba00f..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.util - -private[streaming] -trait Clock { - def currentTime(): Long - def waitTillTime(targetTime: Long): Long -} - -private[streaming] -class SystemClock() extends Clock { - - val minPollTime = 25L - - def currentTime(): Long = { - System.currentTimeMillis() - } - - def waitTillTime(targetTime: Long): Long = { - var currentTime = 0L - currentTime = System.currentTimeMillis() - - var waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - - val pollTime = math.max(waitTime / 10.0, minPollTime).toLong - - while (true) { - currentTime = System.currentTimeMillis() - waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - val sleepTime = math.min(waitTime, pollTime) - Thread.sleep(sleepTime) - } - -1 - } -} - -private[streaming] -class ManualClock() extends Clock { - - private var time = 0L - - def currentTime() = this.synchronized { - time - } - - def setTime(timeToSet: Long) = { - this.synchronized { - time = timeToSet - this.notifyAll() - } - } - - def addToTime(timeToAdd: Long) = { - this.synchronized { - time += timeToAdd - this.notifyAll() - } - } - def waitTillTime(targetTime: Long): Long = { - this.synchronized { - while (time < targetTime) { - this.wait(100) - } - } - currentTime() - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala new file mode 100644 index 000000000000..9f4a4d6806ab --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util.{Iterator => JIterator} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.language.postfixOps + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.util.{CompletionIterator, ThreadUtils} +import org.apache.spark.{Logging, SparkConf} + +/** + * This class manages write ahead log files. + * - Writes records (bytebuffers) to periodically rotating log files. + * - Recovers the log files and the reads the recovered records upon failures. + * - Cleans up old log files. + * + * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write + * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. + * + * @param logDirectory Directory when rotating log files will be created. + * @param hadoopConf Hadoop configuration for reading/writing log files. + */ +private[streaming] class FileBasedWriteAheadLog( + conf: SparkConf, + logDirectory: String, + hadoopConf: Configuration, + rollingIntervalSecs: Int, + maxFailures: Int + ) extends WriteAheadLog with Logging { + + import FileBasedWriteAheadLog._ + + private val pastLogs = new ArrayBuffer[LogInfo] + private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") + + private val threadpoolName = s"WriteAheadLogManager $callerNameTag" + implicit private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) + override protected val logName = s"WriteAheadLogManager $callerNameTag" + + private var currentLogPath: Option[String] = None + private var currentLogWriter: FileBasedWriteAheadLogWriter = null + private var currentLogWriterStartTime: Long = -1L + private var currentLogWriterStopTime: Long = -1L + + initializeOrRecover() + + /** + * Write a byte buffer to the log file. This method synchronously writes the data in the + * ByteBuffer to HDFS. When this method returns, the data is guaranteed to have been flushed + * to HDFS, and will be available for readers to read. + */ + def write(byteBuffer: ByteBuffer, time: Long): FileBasedWriteAheadLogSegment = synchronized { + var fileSegment: FileBasedWriteAheadLogSegment = null + var failures = 0 + var lastException: Exception = null + var succeeded = false + while (!succeeded && failures < maxFailures) { + try { + fileSegment = getLogWriter(time).write(byteBuffer) + succeeded = true + } catch { + case ex: Exception => + lastException = ex + logWarning("Failed to write to write ahead log") + resetWriter() + failures += 1 + } + } + if (fileSegment == null) { + logError(s"Failed to write to write ahead log after $failures failures") + throw lastException + } + fileSegment + } + + def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + val fileSegment = segment.asInstanceOf[FileBasedWriteAheadLogSegment] + var reader: FileBasedWriteAheadLogRandomReader = null + var byteBuffer: ByteBuffer = null + try { + reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) + byteBuffer = reader.read(fileSegment) + } finally { + reader.close() + } + byteBuffer + } + + /** + * Read all the existing logs from the log directory. + * + * Note that this is typically called when the caller is initializing and wants + * to recover past state from the write ahead logs (that is, before making any writes). + * If this is called after writes have been made using this manager, then it may not return + * the latest the records. This does not deal with currently active log files, and + * hence the implementation is kept simple. + */ + def readAll(): JIterator[ByteBuffer] = synchronized { + val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath + logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) + + logFilesToRead.iterator.map { file => + logDebug(s"Creating log reader with $file") + val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) + CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) + }.flatten.asJava + } + + /** + * Delete the log files that are older than the threshold time. + * + * Its important to note that the threshold time is based on the time stamps used in the log + * files, which is usually based on the local system time. So if there is coordination necessary + * between the node calculating the threshTime (say, driver node), and the local system time + * (say, worker node), the caller has to take account of possible time skew. + * + * If waitForCompletion is set to true, this method will return only after old logs have been + * deleted. This should be set to true only for testing. Else the files will be deleted + * asynchronously. + */ + def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } + logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") + + def deleteFiles() { + oldLogFiles.foreach { logInfo => + try { + val path = new Path(logInfo.path) + val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) + fs.delete(path, true) + synchronized { pastLogs -= logInfo } + logDebug(s"Cleared log file $logInfo") + } catch { + case ex: Exception => + logWarning(s"Error clearing write ahead log file $logInfo", ex) + } + } + logInfo(s"Cleared log files in $logDirectory older than $threshTime") + } + if (!executionContext.isShutdown) { + val f = Future { deleteFiles() } + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } + } + + + /** Stop the manager, close any open log writer */ + def close(): Unit = synchronized { + if (currentLogWriter != null) { + currentLogWriter.close() + } + executionContext.shutdown() + logInfo("Stopped write ahead log manager") + } + + /** Get the current log writer while taking care of rotation */ + private def getLogWriter(currentTime: Long): FileBasedWriteAheadLogWriter = synchronized { + if (currentLogWriter == null || currentTime > currentLogWriterStopTime) { + resetWriter() + currentLogPath.foreach { + pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) + } + currentLogWriterStartTime = currentTime + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + val newLogPath = new Path(logDirectory, + timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) + currentLogPath = Some(newLogPath.toString) + currentLogWriter = new FileBasedWriteAheadLogWriter(currentLogPath.get, hadoopConf) + } + currentLogWriter + } + + /** Initialize the log directory or recover existing logs inside the directory */ + private def initializeOrRecover(): Unit = synchronized { + val logDirectoryPath = new Path(logDirectory) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + + if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) + pastLogs.clear() + pastLogs ++= logFileInfo + logInfo(s"Recovered ${logFileInfo.size} write ahead log files from $logDirectory") + logDebug(s"Recovered files are:\n${logFileInfo.map(_.path).mkString("\n")}") + } + } + + private def resetWriter(): Unit = synchronized { + if (currentLogWriter != null) { + currentLogWriter.close() + currentLogWriter = null + } + } +} + +private[streaming] object FileBasedWriteAheadLog { + + case class LogInfo(startTime: Long, endTime: Long, path: String) + + val logFileRegex = """log-(\d+)-(\d+)""".r + + def timeToLogFile(startTime: Long, stopTime: Long): String = { + s"log-$startTime-$stopTime" + } + + def getCallerName(): Option[String] = { + val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + } + + /** Convert a sequence of files to a sequence of sorted LogInfo objects */ + def logFilesTologInfo(files: Seq[Path]): Seq[LogInfo] = { + files.flatMap { file => + logFileRegex.findFirstIn(file.getName()) match { + case Some(logFileRegex(startTimeStr, stopTimeStr)) => + val startTime = startTimeStr.toLong + val stopTime = stopTimeStr.toLong + Some(LogInfo(startTime, stopTime, file.toString)) + case None => + None + } + }.sortBy { _.startTime } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala new file mode 100644 index 000000000000..f7168229ec15 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.util + +import java.io.Closeable +import java.nio.ByteBuffer + +import org.apache.hadoop.conf.Configuration + +/** + * A random access reader for reading write ahead log files written using + * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. Given the file segment info, + * this reads the record (ByteBuffer) from the log file. + */ +private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: Configuration) + extends Closeable { + + private val instream = HdfsUtils.getInputStream(path, conf) + private var closed = false + + def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized { + assertOpen() + instream.seek(segment.offset) + val nextLength = instream.readInt() + HdfsUtils.checkState(nextLength == segment.length, + s"Expected message length to be ${segment.length}, but was $nextLength") + val buffer = new Array[Byte](nextLength) + instream.readFully(buffer) + ByteBuffer.wrap(buffer) + } + + override def close(): Unit = synchronized { + closed = true + instream.close() + } + + private def assertOpen() { + HdfsUtils.checkState(!closed, "Stream is closed. Create a new Reader to read from the file.") + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala new file mode 100644 index 000000000000..c3bb59f3fef9 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.util + +import java.io.{Closeable, EOFException} +import java.nio.ByteBuffer + +import org.apache.hadoop.conf.Configuration +import org.apache.spark.Logging + +/** + * A reader for reading write ahead log files written using + * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. This reads + * the records (bytebuffers) in the log file sequentially and return them as an + * iterator of bytebuffers. + */ +private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Configuration) + extends Iterator[ByteBuffer] with Closeable with Logging { + + private val instream = HdfsUtils.getInputStream(path, conf) + private var closed = false + private var nextItem: Option[ByteBuffer] = None + + override def hasNext: Boolean = synchronized { + if (closed) { + return false + } + + if (nextItem.isDefined) { // handle the case where hasNext is called without calling next + true + } else { + try { + val length = instream.readInt() + val buffer = new Array[Byte](length) + instream.readFully(buffer) + nextItem = Some(ByteBuffer.wrap(buffer)) + logTrace("Read next item " + nextItem.get) + true + } catch { + case e: EOFException => + logDebug("Error reading next item, EOF reached", e) + close() + false + case e: Exception => + logWarning("Error while trying to read data from HDFS.", e) + close() + throw e + } + } + } + + override def next(): ByteBuffer = synchronized { + val data = nextItem.getOrElse { + close() + throw new IllegalStateException( + "next called without calling hasNext or after hasNext returned false") + } + nextItem = None // Ensure the next hasNext call loads new data. + data + } + + override def close(): Unit = synchronized { + if (!closed) { + instream.close() + } + closed = true + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala new file mode 100644 index 000000000000..2e1f1528fad2 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.util + +/** Class for representing a segment of data in a write ahead log file */ +private[streaming] case class FileBasedWriteAheadLogSegment(path: String, offset: Long, length: Int) + extends WriteAheadLogRecordHandle diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala new file mode 100644 index 000000000000..e146bec32a45 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.util + +import java.io._ +import java.nio.ByteBuffer + +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataOutputStream + +/** + * A writer for writing byte-buffers to a write ahead log file. + */ +private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: Configuration) + extends Closeable { + + private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) + + private lazy val hadoopFlushMethod = { + // Use reflection to get the right flush operation + val cls = classOf[FSDataOutputStream] + Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption + } + + private var nextOffset = stream.getPos() + private var closed = false + + /** Write the bytebuffer to the log file */ + def write(data: ByteBuffer): FileBasedWriteAheadLogSegment = synchronized { + assertOpen() + data.rewind() // Rewind to ensure all data in the buffer is retrieved + val lengthToWrite = data.remaining() + val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) + stream.writeInt(lengthToWrite) + if (data.hasArray) { + stream.write(data.array()) + } else { + // If the buffer is not backed by an array, we transfer using temp array + // Note that despite the extra array copy, this should be faster than byte-by-byte copy + while (data.hasRemaining) { + val array = new Array[Byte](data.remaining) + data.get(array) + stream.write(array) + } + } + flush() + nextOffset = stream.getPos() + segment + } + + override def close(): Unit = synchronized { + closed = true + stream.close() + } + + private def flush() { + hadoopFlushMethod.foreach { _.invoke(stream) } + // Useful for local file system where hflush/sync does not work (HADOOP-7844) + stream.getWrappedStream.flush() + } + + private def assertOpen() { + HdfsUtils.checkState(!closed, "Stream is closed. Create a new Writer to write to file.") + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 858ba3c9eb4e..f60688f173c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -27,7 +27,7 @@ private[streaming] object HdfsUtils { // If the file exists and we have append support, append instead of creating a new file val stream: FSDataOutputStream = { if (dfs.isFile(dfsPath)) { - if (conf.getBoolean("hdfs.append.support", false)) { + if (conf.getBoolean("hdfs.append.support", false) || dfs.isInstanceOf[RawLocalFileSystem]) { dfs.append(dfsPath) } else { throw new IllegalStateException("File exists and there is no append support!") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index a73d6f3bf066..408936653c79 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.util.collection.OpenHashMap -import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { @@ -29,7 +27,7 @@ object RawTextHelper { * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OpenHashMap[String,Long] + val map = new OpenHashMap[String, Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -71,7 +69,7 @@ object RawTextHelper { var count = 0 while(data.hasNext) { - value = data.next + value = data.next() if (value != null) { count += 1 if (len == 0) { @@ -100,7 +98,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { + for (i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) @@ -108,9 +106,13 @@ object RawTextHelper { } } - def add(v1: Long, v2: Long) = (v1 + v2) + def add(v1: Long, v2: Long): Long = { + v1 + v2 + } - def subtract(v1: Long, v2: Long) = (v1 - v2) + def subtract(v1: Long, v2: Long): Long = { + v1 - v2 + } - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long): Long = math.max(v1, v2) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index a7850812bd61..6addb9675203 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -35,7 +35,9 @@ private[streaming] object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { + // scalastyle:off println System.err.println("Usage: RawTextSender ") + // scalastyle:on println System.exit(1) } // Parse the arguments using a pattern match @@ -72,7 +74,8 @@ object RawTextSender extends Logging { } catch { case e: IOException => logError("Client disconnected") - socket.close() + } finally { + socket.close() } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 1a616a0434f2..dd32ad5ad811 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.Logging +import org.apache.spark.util.{Clock, SystemClock} private[streaming] class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) @@ -38,7 +39,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: * current system time. */ def getStartTime(): Long = { - (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + (math.floor(clock.getTimeMillis().toDouble / period) + 1).toLong * period } /** @@ -48,7 +49,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: * more than current time. */ def getRestartTime(originalStartTime: Long): Long = { - val gap = clock.currentTime - originalStartTime + val gap = clock.getTimeMillis() - originalStartTime (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime } @@ -105,7 +106,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: } private[streaming] -object RecurringTimer { +object RecurringTimer extends Logging { def main(args: Array[String]) { var lastRecurTime = 0L @@ -113,7 +114,7 @@ object RecurringTimer { def onRecur(time: Long) { val currentTime = System.currentTimeMillis() - println("" + currentTime + ": " + (currentTime - lastRecurTime)) + logInfo("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala deleted file mode 100644 index 1005a2c8ec30..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.util - -/** Class for representing a segment of data in a write ahead log file */ -private[streaming] case class WriteAheadLogFileSegment (path: String, offset: Long, length: Int) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala deleted file mode 100644 index 166661b7496d..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.util - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration.Duration -import scala.concurrent.{Await, ExecutionContext, Future} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.spark.Logging -import org.apache.spark.util.Utils -import WriteAheadLogManager._ - -/** - * This class manages write ahead log files. - * - Writes records (bytebuffers) to periodically rotating log files. - * - Recovers the log files and the reads the recovered records upon failures. - * - Cleans up old log files. - * - * Uses [[org.apache.spark.streaming.util.WriteAheadLogWriter]] to write - * and [[org.apache.spark.streaming.util.WriteAheadLogReader]] to read. - * - * @param logDirectory Directory when rotating log files will be created. - * @param hadoopConf Hadoop configuration for reading/writing log files. - * @param rollingIntervalSecs The interval in seconds with which logs will be rolled over. - * Default is one minute. - * @param maxFailures Max number of failures that is tolerated for every attempt to write to log. - * Default is three. - * @param callerName Optional name of the class who is using this manager. - * @param clock Optional clock that is used to check for rotation interval. - */ -private[streaming] class WriteAheadLogManager( - logDirectory: String, - hadoopConf: Configuration, - rollingIntervalSecs: Int = 60, - maxFailures: Int = 3, - callerName: String = "", - clock: Clock = new SystemClock - ) extends Logging { - - private val pastLogs = new ArrayBuffer[LogInfo] - private val callerNameTag = - if (callerName.nonEmpty) s" for $callerName" else "" - private val threadpoolName = s"WriteAheadLogManager $callerNameTag" - implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(1, threadpoolName)) - override protected val logName = s"WriteAheadLogManager $callerNameTag" - - private var currentLogPath: Option[String] = None - private var currentLogWriter: WriteAheadLogWriter = null - private var currentLogWriterStartTime: Long = -1L - private var currentLogWriterStopTime: Long = -1L - - initializeOrRecover() - - /** - * Write a byte buffer to the log file. This method synchronously writes the data in the - * ByteBuffer to HDFS. When this method returns, the data is guaranteed to have been flushed - * to HDFS, and will be available for readers to read. - */ - def writeToLog(byteBuffer: ByteBuffer): WriteAheadLogFileSegment = synchronized { - var fileSegment: WriteAheadLogFileSegment = null - var failures = 0 - var lastException: Exception = null - var succeeded = false - while (!succeeded && failures < maxFailures) { - try { - fileSegment = getLogWriter(clock.currentTime).write(byteBuffer) - succeeded = true - } catch { - case ex: Exception => - lastException = ex - logWarning("Failed to write to write ahead log") - resetWriter() - failures += 1 - } - } - if (fileSegment == null) { - logError(s"Failed to write to write ahead log after $failures failures") - throw lastException - } - fileSegment - } - - /** - * Read all the existing logs from the log directory. - * - * Note that this is typically called when the caller is initializing and wants - * to recover past state from the write ahead logs (that is, before making any writes). - * If this is called after writes have been made using this manager, then it may not return - * the latest the records. This does not deal with currently active log files, and - * hence the implementation is kept simple. - */ - def readFromLog(): Iterator[ByteBuffer] = synchronized { - val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath - logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) - logFilesToRead.iterator.map { file => - logDebug(s"Creating log reader with $file") - new WriteAheadLogReader(file, hadoopConf) - } flatMap { x => x } - } - - /** - * Delete the log files that are older than the threshold time. - * - * Its important to note that the threshold time is based on the time stamps used in the log - * files, which is usually based on the local system time. So if there is coordination necessary - * between the node calculating the threshTime (say, driver node), and the local system time - * (say, worker node), the caller has to take account of possible time skew. - * - * If waitForCompletion is set to true, this method will return only after old logs have been - * deleted. This should be set to true only for testing. Else the files will be deleted - * asynchronously. - */ - def cleanupOldLogs(threshTime: Long, waitForCompletion: Boolean): Unit = { - val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } - logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + - s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") - - def deleteFiles() { - oldLogFiles.foreach { logInfo => - try { - val path = new Path(logInfo.path) - val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) - fs.delete(path, true) - synchronized { pastLogs -= logInfo } - logDebug(s"Cleared log file $logInfo") - } catch { - case ex: Exception => - logWarning(s"Error clearing write ahead log file $logInfo", ex) - } - } - logInfo(s"Cleared log files in $logDirectory older than $threshTime") - } - if (!executionContext.isShutdown) { - val f = Future { deleteFiles() } - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) - } - } - } - - - /** Stop the manager, close any open log writer */ - def stop(): Unit = synchronized { - if (currentLogWriter != null) { - currentLogWriter.close() - } - executionContext.shutdown() - logInfo("Stopped write ahead log manager") - } - - /** Get the current log writer while taking care of rotation */ - private def getLogWriter(currentTime: Long): WriteAheadLogWriter = synchronized { - if (currentLogWriter == null || currentTime > currentLogWriterStopTime) { - resetWriter() - currentLogPath.foreach { - pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) - } - currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) - val newLogPath = new Path(logDirectory, - timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) - currentLogPath = Some(newLogPath.toString) - currentLogWriter = new WriteAheadLogWriter(currentLogPath.get, hadoopConf) - } - currentLogWriter - } - - /** Initialize the log directory or recover existing logs inside the directory */ - private def initializeOrRecover(): Unit = synchronized { - val logDirectoryPath = new Path(logDirectory) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { - val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) - pastLogs.clear() - pastLogs ++= logFileInfo - logInfo(s"Recovered ${logFileInfo.size} write ahead log files from $logDirectory") - logDebug(s"Recovered files are:\n${logFileInfo.map(_.path).mkString("\n")}") - } - } - - private def resetWriter(): Unit = synchronized { - if (currentLogWriter != null) { - currentLogWriter.close() - currentLogWriter = null - } - } -} - -private[util] object WriteAheadLogManager { - - case class LogInfo(startTime: Long, endTime: Long, path: String) - - val logFileRegex = """log-(\d+)-(\d+)""".r - - def timeToLogFile(startTime: Long, stopTime: Long): String = { - s"log-$startTime-$stopTime" - } - - /** Convert a sequence of files to a sequence of sorted LogInfo objects */ - def logFilesTologInfo(files: Seq[Path]): Seq[LogInfo] = { - files.flatMap { file => - logFileRegex.findFirstIn(file.getName()) match { - case Some(logFileRegex(startTimeStr, stopTimeStr)) => - val startTime = startTimeStr.toLong - val stopTime = stopTimeStr.toLong - Some(LogInfo(startTime, stopTime, file.toString)) - case None => - None - } - }.sortBy { _.startTime } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala deleted file mode 100644 index 003989092a42..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.util - -import java.io.Closeable -import java.nio.ByteBuffer - -import org.apache.hadoop.conf.Configuration - -/** - * A random access reader for reading write ahead log files written using - * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. Given the file segment info, - * this reads the record (bytebuffer) from the log file. - */ -private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configuration) - extends Closeable { - - private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false - - def read(segment: WriteAheadLogFileSegment): ByteBuffer = synchronized { - assertOpen() - instream.seek(segment.offset) - val nextLength = instream.readInt() - HdfsUtils.checkState(nextLength == segment.length, - s"Expected message length to be ${segment.length}, but was $nextLength") - val buffer = new Array[Byte](nextLength) - instream.readFully(buffer) - ByteBuffer.wrap(buffer) - } - - override def close(): Unit = synchronized { - closed = true - instream.close() - } - - private def assertOpen() { - HdfsUtils.checkState(!closed, "Stream is closed. Create a new Reader to read from the file.") - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala deleted file mode 100644 index 2afc0d1551ac..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.util - -import java.io.{Closeable, EOFException} -import java.nio.ByteBuffer - -import org.apache.hadoop.conf.Configuration -import org.apache.spark.Logging - -/** - * A reader for reading write ahead log files written using - * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. This reads - * the records (bytebuffers) in the log file sequentially and return them as an - * iterator of bytebuffers. - */ -private[streaming] class WriteAheadLogReader(path: String, conf: Configuration) - extends Iterator[ByteBuffer] with Closeable with Logging { - - private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false - private var nextItem: Option[ByteBuffer] = None - - override def hasNext: Boolean = synchronized { - if (closed) { - return false - } - - if (nextItem.isDefined) { // handle the case where hasNext is called without calling next - true - } else { - try { - val length = instream.readInt() - val buffer = new Array[Byte](length) - instream.readFully(buffer) - nextItem = Some(ByteBuffer.wrap(buffer)) - logTrace("Read next item " + nextItem.get) - true - } catch { - case e: EOFException => - logDebug("Error reading next item, EOF reached", e) - close() - false - case e: Exception => - logWarning("Error while trying to read data from HDFS.", e) - close() - throw e - } - } - } - - override def next(): ByteBuffer = synchronized { - val data = nextItem.getOrElse { - close() - throw new IllegalStateException( - "next called without calling hasNext or after hasNext returned false") - } - nextItem = None // Ensure the next hasNext call loads new data. - data - } - - override def close(): Unit = synchronized { - if (!closed) { - instream.close() - } - closed = true - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala new file mode 100644 index 000000000000..7f6ff12c58d4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkException} + +/** A helper class with utility functions related to the WriteAheadLog interface */ +private[streaming] object WriteAheadLogUtils extends Logging { + val RECEIVER_WAL_ENABLE_CONF_KEY = "spark.streaming.receiver.writeAheadLog.enable" + val RECEIVER_WAL_CLASS_CONF_KEY = "spark.streaming.receiver.writeAheadLog.class" + val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY = + "spark.streaming.receiver.writeAheadLog.rollingIntervalSecs" + val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures" + + val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class" + val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = + "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" + val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + + val DEFAULT_ROLLING_INTERVAL_SECS = 60 + val DEFAULT_MAX_FAILURES = 3 + + def enableReceiverLog(conf: SparkConf): Boolean = { + conf.getBoolean(RECEIVER_WAL_ENABLE_CONF_KEY, false) + } + + def getRollingIntervalSecs(conf: SparkConf, isDriver: Boolean): Int = { + if (isDriver) { + conf.getInt(DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS) + } else { + conf.getInt(RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS) + } + } + + def getMaxFailures(conf: SparkConf, isDriver: Boolean): Int = { + if (isDriver) { + conf.getInt(DRIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES) + } else { + conf.getInt(RECEIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES) + } + } + + /** + * Create a WriteAheadLog for the driver. If configured with custom WAL class, it will try + * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. + */ + def createLogForDriver( + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + createLog(true, sparkConf, fileWalLogDirectory, fileWalHadoopConf) + } + + /** + * Create a WriteAheadLog for the receiver. If configured with custom WAL class, it will try + * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. + */ + def createLogForReceiver( + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + createLog(false, sparkConf, fileWalLogDirectory, fileWalHadoopConf) + } + + /** + * Create a WriteAheadLog based on the value of the given config key. The config key is used + * to get the class name from the SparkConf. If the class is configured, it will try to + * create instance of that class by first trying `new CustomWAL(sparkConf, logDir)` then trying + * `new CustomWAL(sparkConf)`. If either fails, it will fail. If no class is configured, then + * it will create the default FileBasedWriteAheadLog. + */ + private def createLog( + isDriver: Boolean, + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + + val classNameOption = if (isDriver) { + sparkConf.getOption(DRIVER_WAL_CLASS_CONF_KEY) + } else { + sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) + } + classNameOption.map { className => + try { + instantiateClass( + Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) + } catch { + case NonFatal(e) => + throw new SparkException(s"Could not create a write ahead log of class $className", e) + } + }.getOrElse { + new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf, + getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver)) + } + } + + /** Instantiate the class, either using single arg constructor or zero arg constructor */ + private def instantiateClass(cls: Class[_ <: WriteAheadLog], conf: SparkConf): WriteAheadLog = { + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf) + } catch { + case nsme: NoSuchMethodException => + cls.getConstructor().newInstance() + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala deleted file mode 100644 index 679f6a6dfd7c..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.util - -import java.io._ -import java.net.URI -import java.nio.ByteBuffer - -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataOutputStream, FileSystem} - -/** - * A writer for writing byte-buffers to a write ahead log file. - */ -private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configuration) - extends Closeable { - - private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) - - private lazy val hadoopFlushMethod = { - // Use reflection to get the right flush operation - val cls = classOf[FSDataOutputStream] - Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption - } - - private var nextOffset = stream.getPos() - private var closed = false - - /** Write the bytebuffer to the log file */ - def write(data: ByteBuffer): WriteAheadLogFileSegment = synchronized { - assertOpen() - data.rewind() // Rewind to ensure all data in the buffer is retrieved - val lengthToWrite = data.remaining() - val segment = new WriteAheadLogFileSegment(path, nextOffset, lengthToWrite) - stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } - flush() - nextOffset = stream.getPos() - segment - } - - override def close(): Unit = synchronized { - closed = true - stream.close() - } - - private def flush() { - hadoopFlushMethod.foreach { _.invoke(stream) } - // Useful for local file system where hflush/sync does not work (HADOOP-7844) - stream.getWrappedStream.flush() - } - - private def assertOpen() { - HdfsUtils.checkState(!closed, "Stream is closed. Create a new Writer to write to file.") - } -} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 2df8cf6a8a3d..e0718f73aa13 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -21,11 +21,13 @@ import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; + import scala.Tuple2; import org.junit.Assert; @@ -45,6 +47,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; +import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -67,6 +70,20 @@ public void testInitialization() { Assert.assertNotNull(ssc.sparkContext()); } + @SuppressWarnings("unchecked") + @Test + public void testContextState() { + List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); + Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream); + Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + ssc.start(); + Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + ssc.stop(); + Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + } + @SuppressWarnings("unchecked") @Test public void testCount() { @@ -316,6 +333,7 @@ public void testReduceByWindowWithoutInverse() { testReduceByWindow(false); } + @SuppressWarnings("unchecked") private void testReduceByWindow(boolean withInverse) { List> inputData = Arrays.asList( Arrays.asList(1,2,3), @@ -346,6 +364,14 @@ private void testReduceByWindow(boolean withInverse) { @SuppressWarnings("unchecked") @Test public void testQueueStream() { + ssc.stop(); + // Create a new JavaStreamingContext without checkpointing + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + List> expected = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -684,6 +710,7 @@ public void testStreamingContextTransform(){ JavaDStream transformed1 = ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { + @Override public JavaRDD call(List> listOfRDDs, Time time) { Assert.assertEquals(2, listOfRDDs.size()); return null; @@ -697,6 +724,7 @@ public JavaRDD call(List> listOfRDDs, Time time) { JavaPairDStream> transformed2 = ssc.transformToPair( listOfDStreams2, new Function2>, Time, JavaPairRDD>>() { + @Override public JavaPairRDD> call(List> listOfRDDs, Time time) { Assert.assertEquals(3, listOfRDDs.size()); JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); @@ -926,7 +954,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -984,12 +1012,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1120,7 +1148,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1141,14 +1169,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1246,17 +1274,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1289,17 +1317,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1325,7 +1353,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1704,6 +1732,56 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } + @SuppressWarnings("unchecked") + @Test + public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); + + final SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("newContext", "true"); + + File emptyDir = Files.createTempDir(); + emptyDir.deleteOnExit(); + StreamingContextSuite contextSuite = new StreamingContextSuite(); + String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); + String checkpointDir = contextSuite.createValidCheckpoint(); + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + final AtomicBoolean newContextCreated = new AtomicBoolean(false); + Function0 creatingFunc = new Function0() { + public JavaStreamingContext call() { + newContextCreated.set(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); + } + }; + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration(), true); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", !newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + JavaSparkContext sc = new JavaSparkContext(conf); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", !newContextCreated.get()); + ssc.stop(); + } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") @@ -1828,4 +1906,23 @@ private List> fileTestPrepare(File testDir) throws IOException { return expected; } + + @SuppressWarnings("unchecked") + // SPARK-5795: no logic assertions, just testing that intended API invocations compile + private void compileSaveAsJavaAPI(JavaPairDStream pds) { + pds.saveAsNewAPIHadoopFiles( + "", "", LongWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + pds.saveAsHadoopFiles( + "", "", LongWritable.class, Text.class, + org.apache.hadoop.mapred.SequenceFileOutputFormat.class); + // Checks that a previous common workaround for this API still compiles + pds.saveAsNewAPIHadoopFiles( + "", "", LongWritable.class, Text.class, + (Class) org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + pds.saveAsHadoopFiles( + "", "", LongWritable.class, Text.class, + (Class) org.apache.hadoop.mapred.SequenceFileOutputFormat.class); + } + } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index c0ea0491c313..57b50bdfd652 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -17,16 +17,13 @@ package org.apache.spark.streaming -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import java.util.{List => JList} -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext} /** Exposes streaming test functionality in a Java-friendly way. */ trait JavaTestBase extends TestSuiteBase { @@ -39,7 +36,7 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, data: JList[JList[T]], numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) + val seqData = data.asScala.map(_.asScala) implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] @@ -70,10 +67,9 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out + res.map(_.asJava).asJava } /** @@ -89,12 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[JList[V]]]() - res.map{entry => - val lists = entry.map(new ArrayList[V](_)) - out.append(new ArrayList[JList[V]](lists)) - } - out + res.map(entry => entry.map(_.asJava).asJava).asJava } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java new file mode 100644 index 000000000000..175b8a496b4e --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import com.google.common.base.Function; +import com.google.common.collect.Iterators; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.util.WriteAheadLog; +import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; +import org.apache.spark.streaming.util.WriteAheadLogUtils; + +import org.junit.Test; +import org.junit.Assert; + +public class JavaWriteAheadLogSuite extends WriteAheadLog { + + static class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { + int index = -1; + JavaWriteAheadLogSuiteHandle(int idx) { + index = idx; + } + } + + static class Record { + long time; + int index; + ByteBuffer buffer; + + Record(long tym, int idx, ByteBuffer buf) { + index = idx; + time = tym; + buffer = buf; + } + } + private int index = -1; + private final List records = new ArrayList<>(); + + + // Methods for WriteAheadLog + @Override + public WriteAheadLogRecordHandle write(ByteBuffer record, long time) { + index += 1; + records.add(new Record(time, index, record)); + return new JavaWriteAheadLogSuiteHandle(index); + } + + @Override + public ByteBuffer read(WriteAheadLogRecordHandle handle) { + if (handle instanceof JavaWriteAheadLogSuiteHandle) { + int reqdIndex = ((JavaWriteAheadLogSuiteHandle) handle).index; + for (Record record: records) { + if (record.index == reqdIndex) { + return record.buffer; + } + } + } + return null; + } + + @Override + public Iterator readAll() { + return Iterators.transform(records.iterator(), new Function() { + @Override + public ByteBuffer apply(Record input) { + return input.buffer; + } + }); + } + + @Override + public void clean(long threshTime, boolean waitForCompletion) { + for (int i = 0; i < records.size(); i++) { + if (records.get(i).time < threshTime) { + records.remove(i); + i--; + } + } + } + + @Override + public void close() { + records.clear(); + } + + @Test + public void testCustomWAL() { + SparkConf conf = new SparkConf(); + conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); + + String data1 = "data1"; + WriteAheadLogRecordHandle handle = + wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); + Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); + Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); + + wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); + wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); + wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); + wal.clean(1236, false); + + Iterator dataIterator = wal.readAll(); + List readData = new ArrayList<>(); + while (dataIterator.hasNext()) { + readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); + } + Assert.assertEquals(readData, Arrays.asList("data3", "data4")); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index e8f4a7779ec2..255376807c95 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -22,13 +22,12 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag -import util.ManualClock - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} +import org.apache.spark.util.{Clock, ManualClock} import org.apache.spark.HashPartitioner class BasicOperationsSuite extends TestSuiteBase { @@ -82,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { @@ -172,7 +173,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -254,7 +257,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) + s1.map(x => (x, 1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -426,9 +429,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("updateStateByKey - object lifecycle") { val inputData = Seq( - Seq("a","b"), + Seq("a", "b"), null, - Seq("a","c","a"), + Seq("a", "c", "a"), Seq("c"), null, null @@ -475,7 +478,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } @@ -556,6 +559,9 @@ class BasicOperationsSuite extends TestSuiteBase { withTestServer(new TestServer()) { testServer => withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => testServer.start() + + val batchCounter = new BatchCounter(ssc) + // Set up the streaming context and input streams val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) @@ -586,7 +592,11 @@ class BasicOperationsSuite extends TestSuiteBase { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(200) - clock.addToTime(batchDuration.milliseconds) + val numCompletedBatches = batchCounter.getNumCompletedBatches + clock.advance(batchDuration.milliseconds) + if (!batchCounter.waitUntilBatchesCompleted(numCompletedBatches + 1, 5000)) { + fail("Batch took more than 5 seconds to complete") + } collectRddInfo() } @@ -637,8 +647,8 @@ class BasicOperationsSuite extends TestSuiteBase { ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] if (rememberDuration != null) ssc.remember(rememberDuration) val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.currentTime() === Seconds(10).milliseconds) + val clock = ssc.scheduler.clock.asInstanceOf[Clock] + assert(clock.getTimeMillis() === Seconds(10).milliseconds) assert(output.size === numExpectedOutput) operatedStream } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8f8bc61437ba..1bba7a143edf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming import java.io.File -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag import com.google.common.base.Charsets @@ -30,10 +30,11 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.util.Utils +import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** * This test suites tests the checkpointing functionality of DStreams - @@ -44,7 +45,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -61,7 +62,7 @@ class CheckpointSuite extends TestSuiteBase { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") val stateStreamCheckpointInterval = Seconds(1) val fs = FileSystem.getLocal(new Configuration()) @@ -73,7 +74,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -147,7 +148,7 @@ class CheckpointSuite extends TestSuiteBase { // This tests whether spark conf persists through checkpoints, and certain // configs gets scrubbed - test("persistence of conf through checkpoints") { + test("recovery of conf through checkpoints") { val key = "spark.mykey" val value = "myvalue" System.setProperty(key, value) @@ -155,7 +156,7 @@ class CheckpointSuite extends TestSuiteBase { val originalConf = ssc.conf val cp = new Checkpoint(ssc, Time(1000)) - val cpConf = cp.sparkConf + val cpConf = cp.createSparkConf() assert(cpConf.get("spark.master") === originalConf.get("spark.master")) assert(cpConf.get("spark.app.name") === originalConf.get("spark.app.name")) assert(cpConf.get(key) === value) @@ -164,7 +165,8 @@ class CheckpointSuite extends TestSuiteBase { // Serialize/deserialize to simulate write to storage and reading it back val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - val newCpConf = newCp.sparkConf + // Verify new SparkConf has all the previous properties + val newCpConf = newCp.createSparkConf() assert(newCpConf.get("spark.master") === originalConf.get("spark.master")) assert(newCpConf.get("spark.app.name") === originalConf.get("spark.app.name")) assert(newCpConf.get(key) === value) @@ -175,17 +177,79 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(null, newCp, null) val restoredConf = ssc.conf assert(restoredConf.get(key) === value) + ssc.stop() + + // Verify new SparkConf picks up new master url if it is set in the properties. See SPARK-6331. + try { + val newMaster = "local[100]" + System.setProperty("spark.master", newMaster) + val newCpConf = newCp.createSparkConf() + assert(newCpConf.get("spark.master") === newMaster) + assert(newCpConf.get("spark.app.name") === originalConf.get("spark.app.name")) + ssc = new StreamingContext(null, newCp, null) + assert(ssc.sparkContext.master === newMaster) + } finally { + System.clearProperty("spark.master") + } } + // This tests if "spark.driver.host" and "spark.driver.port" is set by user, can be recovered + // with correct value. + test("get correct spark.driver.[host|port] from checkpoint") { + val conf = Map("spark.driver.host" -> "localhost", "spark.driver.port" -> "9999") + conf.foreach(kv => System.setProperty(kv._1, kv._2)) + ssc = new StreamingContext(master, framework, batchDuration) + val originalConf = ssc.conf + assert(originalConf.get("spark.driver.host") === "localhost") + assert(originalConf.get("spark.driver.port") === "9999") + + val cp = new Checkpoint(ssc, Time(1000)) + ssc.stop() + + // Serialize/deserialize to simulate write to storage and reading it back + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + + val newCpConf = newCp.createSparkConf() + assert(newCpConf.contains("spark.driver.host")) + assert(newCpConf.contains("spark.driver.port")) + assert(newCpConf.get("spark.driver.host") === "localhost") + assert(newCpConf.get("spark.driver.port") === "9999") + + // Check if all the parameters have been restored + ssc = new StreamingContext(null, newCp, null) + val restoredConf = ssc.conf + assert(restoredConf.get("spark.driver.host") === "localhost") + assert(restoredConf.get("spark.driver.port") === "9999") + ssc.stop() + + // If spark.driver.host and spark.driver.host is not set in system property, these two + // parameters should not be presented in the newly recovered conf. + conf.foreach(kv => System.clearProperty(kv._1)) + val newCpConf1 = newCp.createSparkConf() + assert(!newCpConf1.contains("spark.driver.host")) + assert(!newCpConf1.contains("spark.driver.port")) - // This tests whether the systm can recover from a master failure with simple + // Spark itself will dispatch a random, not-used port for spark.driver.port if it is not set + // explicitly. + ssc = new StreamingContext(null, newCp, null) + val restoredConf1 = ssc.conf + assert(restoredConf1.get("spark.driver.host") === "localhost") + assert(restoredConf1.get("spark.driver.port") !== "9999") + } + + // This tests whether the system can recover from a master failure with simple // non-stateful operations. This assumes as reliable, replayable input // source - TestInputDStream. test("recovery with map and reduceByKey operations") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -198,7 +262,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -208,7 +273,7 @@ class CheckpointSuite extends TestSuiteBase { } test("recovery with saveAsHadoopFiles operation") { - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -222,7 +287,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -231,7 +302,7 @@ class CheckpointSuite extends TestSuiteBase { } test("recovery with saveAsNewAPIHadoopFiles operation") { - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -245,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -269,7 +346,7 @@ class CheckpointSuite extends TestSuiteBase { // // After SPARK-5079 is addressed, should be able to remove this test since a strengthened // version of the other saveAsHadoopFile* tests would prevent regressions for this issue. - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -284,7 +361,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -310,6 +393,30 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery maintains rate controller") { + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir) + + val dstream = new RateTestInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200))) + } + + val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) + output.register() + runStreams(ssc, 5, 5) + + ssc = new StreamingContext(checkpointDir) + ssc.start() + val outputNew = advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(10.seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200) + } + ssc.stop() + } + // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window operation. // It also tests whether batches, whose processing was incomplete due to the @@ -324,13 +431,13 @@ class CheckpointSuite extends TestSuiteBase { * Writes a file named `i` (which contains the number `i`) to the test directory and sets its * modification time to `clock`'s current time. */ - def writeFile(i: Int, clock: ManualClock): Unit = { + def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) Files.write(i + "\n", file, Charsets.UTF_8) - assert(file.setLastModified(clock.currentTime())) + assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date is actually the value we wrote, since rounding or // truncation will break the test: - assert(file.lastModified() === clock.currentTime()) + assert(file.lastModified() === clock.getTimeMillis()) } /** @@ -372,13 +479,13 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Advance half a batch so that the first file is created after the StreamingContext starts - clock.addToTime(batchDuration.milliseconds / 2) + clock.advance(batchDuration.milliseconds / 2) // Create files and advance manual clock to process them for (i <- Seq(1, 2, 3)) { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) if (i != 3) { // Since we want to shut down while the 3rd batch is processing eventually(eventuallyTimeout) { @@ -386,15 +493,14 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.addToTime(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } + clock.advance(batchDuration.milliseconds) // Wait for a checkpoint to be written - val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed @@ -410,16 +516,19 @@ class CheckpointSuite extends TestSuiteBase { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } // Recover context from checkpoint file and verify whether the files that were // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") withStreamingContext(new StreamingContext(checkpointDir)) { ssc => - // So that the restarted StreamingContext's clock has gone forward in time since failure - ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.currentTime() + // "batchDuration.milliseconds * 3" has gone before restarting StreamingContext. And because + // the recovery time is read from the checkpoint time but the original clock doesn't align + // with the batch time, we need to add the offset "batchDuration.milliseconds / 2". + ssc.conf.set("spark.streaming.manualClock.jump", + (batchDuration.milliseconds / 2 + batchDuration.milliseconds * 3).toString) + val oldClockTime = clock.getTimeMillis() // 15000ms clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -430,10 +539,10 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.currentTime() === oldClockTime + assert(clock.getTimeMillis() === oldClockTime) } - // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) - val numBatchesAfterRestart = 4 + // There are 5 batches between 6000ms and 15000ms (inclusive). + val numBatchesAfterRestart = 5 eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) } @@ -441,12 +550,11 @@ class CheckpointSuite extends TestSuiteBase { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.addToTime(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() @@ -519,17 +627,18 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.currentTime()) + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.currentTime()) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -538,4 +647,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala new file mode 100644 index 000000000000..9b5e4dc819a2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.NotSerializableException + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.ReturnStatementInClosureException + +/** + * Test that closures passed to DStream operations are actually cleaned. + */ +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { + private var ssc: StreamingContext = null + + override def beforeAll(): Unit = { + val sc = new SparkContext("local", "test") + ssc = new StreamingContext(sc, Seconds(1)) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + ssc = null + } + + test("user provided closures are actually cleaned") { + val dstream = new DummyInputDStream(ssc) + val pairDstream = dstream.map { i => (i, i) } + // DStream + testMap(dstream) + testFlatMap(dstream) + testFilter(dstream) + testMapPartitions(dstream) + testReduce(dstream) + testForeach(dstream) + testForeachRDD(dstream) + testTransform(dstream) + testTransformWith(dstream) + testReduceByWindow(dstream) + // PairDStreamFunctions + testReduceByKey(pairDstream) + testCombineByKey(pairDstream) + testReduceByKeyAndWindow(pairDstream) + testUpdateStateByKey(pairDstream) + testMapValues(pairDstream) + testFlatMapValues(pairDstream) + // StreamingContext + testTransform2(ssc, dstream) + } + + /** + * Verify that the expected exception is thrown. + * + * We use return statements as an indication that a closure is actually being cleaned. + * We expect closure cleaner to find the return statements in the user provided closures. + */ + private def expectCorrectException(body: => Unit): Unit = { + try { + body + } catch { + case rse: ReturnStatementInClosureException => // Success! + case e @ (_: NotSerializableException | _: SparkException) => + throw new TestException( + s"Expected ReturnStatementInClosureException, but got $e.\n" + + "This means the closure provided by user is not actually cleaned.") + } + } + + // DStream operations + private def testMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.map { _ => return; 1 } + } + private def testFlatMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.flatMap { _ => return; Seq.empty } + } + private def testFilter(ds: DStream[Int]): Unit = expectCorrectException { + ds.filter { _ => return; true } + } + private def testMapPartitions(ds: DStream[Int]): Unit = expectCorrectException { + ds.mapPartitions { _ => return; Seq.empty.toIterator } + } + private def testReduce(ds: DStream[Int]): Unit = expectCorrectException { + ds.reduce { case (_, _) => return; 1 } + } + private def testForeach(ds: DStream[Int]): Unit = { + val foreachF1 = (rdd: RDD[Int], t: Time) => return + val foreachF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreach(foreachF1) } + expectCorrectException { ds.foreach(foreachF2) } + } + private def testForeachRDD(ds: DStream[Int]): Unit = { + val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return + val foreachRDDF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreachRDD(foreachRDDF1) } + expectCorrectException { ds.foreachRDD(foreachRDDF2) } + } + private def testTransform(ds: DStream[Int]): Unit = { + val transformF1 = (rdd: RDD[Int]) => { return; rdd } + val transformF2 = (rdd: RDD[Int], time: Time) => { return; rdd } + expectCorrectException { ds.transform(transformF1) } + expectCorrectException { ds.transform(transformF2) } + } + private def testTransformWith(ds: DStream[Int]): Unit = { + val transformF1 = (rdd1: RDD[Int], rdd2: RDD[Int]) => { return; rdd1 } + val transformF2 = (rdd1: RDD[Int], rdd2: RDD[Int], time: Time) => { return; rdd2 } + expectCorrectException { ds.transformWith(ds, transformF1) } + expectCorrectException { ds.transformWith(ds, transformF2) } + } + private def testReduceByWindow(ds: DStream[Int]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByWindow(reduceF, reduceF, Seconds(1), Seconds(2)) } + } + + // PairDStreamFunctions operations + private def testReduceByKey(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByKey(reduceF) } + expectCorrectException { ds.reduceByKey(reduceF, 5) } + expectCorrectException { ds.reduceByKey(reduceF, new HashPartitioner(5)) } + } + private def testCombineByKey(ds: DStream[(Int, Int)]): Unit = { + expectCorrectException { + ds.combineByKey[Int]( + { _: Int => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + new HashPartitioner(5) + ) + } + } + private def testReduceByKeyAndWindow(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + val filterF = (_: (Int, Int)) => { return; false } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), 5) } + expectCorrectException { + ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), new HashPartitioner(5)) + } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, reduceF, Seconds(2)) } + expectCorrectException { + ds.reduceByKeyAndWindow( + reduceF, reduceF, Seconds(2), Seconds(3), new HashPartitioner(5), filterF) + } + } + private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = { + val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) } + val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator } + val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) } + expectCorrectException { ds.updateStateByKey(updateF1) } + expectCorrectException { ds.updateStateByKey(updateF1, 5) } + expectCorrectException { ds.updateStateByKey(updateF1, new HashPartitioner(5)) } + expectCorrectException { + ds.updateStateByKey(updateF1, new HashPartitioner(5), initialRDD) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD) + } + } + private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.mapValues { _ => return; 1 } + } + private def testFlatMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.flatMapValues { _ => return; Seq.empty } + } + + // StreamingContext operations + private def testTransform2(ssc: StreamingContext, ds: DStream[Int]): Unit = { + val transformF = (rdds: Seq[RDD[_]], time: Time) => { return; ssc.sparkContext.emptyRDD[Int] } + expectCorrectException { ssc.transform(Seq(ds), transformF) } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala new file mode 100644 index 000000000000..8844c9d74b93 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.ui.UIUtils + +/** + * Tests whether scope information is passed from DStream operations to RDDs correctly. + */ +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + private var ssc: StreamingContext = null + private val batchDuration: Duration = Seconds(1) + + override def beforeAll(): Unit = { + ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + } + + before { assertPropertiesNotSet() } + after { assertPropertiesNotSet() } + + test("dstream without scope") { + val dummyStream = new DummyDStream(ssc) + dummyStream.initialize(Time(0)) + + // This DStream is not instantiated in any scope, so all RDDs + // created by this stream should similarly not have a scope + assert(dummyStream.baseScope === None) + assert(dummyStream.getOrCompute(Time(1000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(2000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(3000)).get.scope === None) + } + + test("input dstream without scope") { + val inputStream = new DummyInputDStream(ssc) + inputStream.initialize(Time(0)) + + val baseScope = inputStream.baseScope.map(RDDOperationScope.fromJson) + val scope1 = inputStream.getOrCompute(Time(1000)).get.scope + val scope2 = inputStream.getOrCompute(Time(2000)).get.scope + val scope3 = inputStream.getOrCompute(Time(3000)).get.scope + + // This DStream is not instantiated in any scope, so all RDDs + assertDefined(baseScope, scope1, scope2, scope3) + assert(baseScope.get.name.startsWith("dummy stream")) + assertScopeCorrect(baseScope.get, scope1.get, 1000) + assertScopeCorrect(baseScope.get, scope2.get, 2000) + assertScopeCorrect(baseScope.get, scope3.get, 3000) + } + + test("scoping simple operations") { + val inputStream = new DummyInputDStream(ssc) + val mappedStream = inputStream.map { i => i + 1 } + val filteredStream = mappedStream.filter { i => i % 2 == 0 } + filteredStream.initialize(Time(0)) + + val mappedScopeBase = mappedStream.baseScope.map(RDDOperationScope.fromJson) + val mappedScope1 = mappedStream.getOrCompute(Time(1000)).get.scope + val mappedScope2 = mappedStream.getOrCompute(Time(2000)).get.scope + val mappedScope3 = mappedStream.getOrCompute(Time(3000)).get.scope + val filteredScopeBase = filteredStream.baseScope.map(RDDOperationScope.fromJson) + val filteredScope1 = filteredStream.getOrCompute(Time(1000)).get.scope + val filteredScope2 = filteredStream.getOrCompute(Time(2000)).get.scope + val filteredScope3 = filteredStream.getOrCompute(Time(3000)).get.scope + + // These streams are defined in their respective scopes "map" and "filter", so all + // RDDs created by these streams should inherit the IDs and names of their parent + // DStream's base scopes + assertDefined(mappedScopeBase, mappedScope1, mappedScope2, mappedScope3) + assertDefined(filteredScopeBase, filteredScope1, filteredScope2, filteredScope3) + assert(mappedScopeBase.get.name === "map") + assert(filteredScopeBase.get.name === "filter") + assertScopeCorrect(mappedScopeBase.get, mappedScope1.get, 1000) + assertScopeCorrect(mappedScopeBase.get, mappedScope2.get, 2000) + assertScopeCorrect(mappedScopeBase.get, mappedScope3.get, 3000) + assertScopeCorrect(filteredScopeBase.get, filteredScope1.get, 1000) + assertScopeCorrect(filteredScopeBase.get, filteredScope2.get, 2000) + assertScopeCorrect(filteredScopeBase.get, filteredScope3.get, 3000) + } + + test("scoping nested operations") { + val inputStream = new DummyInputDStream(ssc) + val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) + countStream.initialize(Time(0)) + + val countScopeBase = countStream.baseScope.map(RDDOperationScope.fromJson) + val countScope1 = countStream.getOrCompute(Time(1000)).get.scope + val countScope2 = countStream.getOrCompute(Time(2000)).get.scope + val countScope3 = countStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(countScopeBase, countScope1, countScope2, countScope3) + assert(countScopeBase.get.name === "countByWindow") + assertScopeCorrect(countScopeBase.get, countScope1.get, 1000) + assertScopeCorrect(countScopeBase.get, countScope2.get, 2000) + assertScopeCorrect(countScopeBase.get, countScope3.get, 3000) + + // All streams except the input stream should share the same scopes as `countStream` + def testStream(stream: DStream[_]): Unit = { + if (stream != inputStream) { + val myScopeBase = stream.baseScope.map(RDDOperationScope.fromJson) + val myScope1 = stream.getOrCompute(Time(1000)).get.scope + val myScope2 = stream.getOrCompute(Time(2000)).get.scope + val myScope3 = stream.getOrCompute(Time(3000)).get.scope + assertDefined(myScopeBase, myScope1, myScope2, myScope3) + assert(myScopeBase === countScopeBase) + assert(myScope1 === countScope1) + assert(myScope2 === countScope2) + assert(myScope3 === countScope3) + // Climb upwards to test the parent streams + stream.dependencies.foreach(testStream) + } + } + testStream(countStream) + } + + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ + private def assertPropertiesNotSet(): Unit = { + assert(ssc != null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY) == null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY) == null) + } + + /** Assert that the given RDD scope inherits the name and ID of the base scope correctly. */ + private def assertScopeCorrect( + baseScope: RDDOperationScope, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) + } + + /** Assert that the given RDD scope inherits the base name and ID correctly. */ + private def assertScopeCorrect( + baseScopeId: String, + baseScopeName: String, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + val formattedBatchTime = UIUtils.formatBatchTime( + batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + assert(rddScope.id === s"${baseScopeId}_$batchTime") + assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + } + + /** Assert that all the specified options are defined. */ + private def assertDefined[T](options: Option[T]*): Unit = { + options.zipWithIndex.foreach { case (o, i) => assert(o.isDefined, s"Option $i was empty!") } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 6500608bba87..e82c2fa4e72a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -17,35 +17,40 @@ package org.apache.spark.streaming -import org.apache.spark.Logging -import org.apache.spark.util.Utils - import java.io.File +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkFunSuite, Logging} +import org.apache.spark.util.Utils + /** * This testsuite tests master failures at random times while the stream is running using * the real clock. */ -class FailureSuite extends TestSuiteBase with Logging { - - val directory = Utils.createTempDir().getAbsolutePath - val numBatches = 30 +class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { - override def batchDuration = Milliseconds(1000) + private val batchDuration: Duration = Milliseconds(1000) + private val numBatches = 30 + private var directory: File = null - override def useManualClock = false + before { + directory = Utils.createTempDir() + } - override def afterFunction() { - Utils.deleteRecursively(new File(directory)) - super.afterFunction() + after { + if (directory != null) { + Utils.deleteRecursively(directory) + } + StreamingContext.getActive().foreach { _.stop() } } test("multiple failures with map") { - MasterFailureTest.testMap(directory, numBatches, batchDuration) + MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration) } test("multiple failures with updateStateByKey") { - MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) + MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index bddf51e13042..ec2852d9a020 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -17,84 +17,176 @@ package org.apache.spark.streaming -import akka.actor.Actor -import akka.actor.Props -import akka.util.ByteString - import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{InetSocketAddress, SocketException, ServerSocket} +import java.net.{Socket, SocketException, ServerSocket} import java.nio.charset.Charset -import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} import scala.language.postfixOps import com.google.common.io.Files +import org.apache.hadoop.io.{Text, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.util.Utils -import org.apache.spark.streaming.receiver.{ActorHelper, Receiver} import org.apache.spark.rdd.RDD -import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -import org.apache.hadoop.fs.Path +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.Receiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("socket input stream") { - // Start the server - val testServer = new TestServer() - testServer.start() + withTestServer(new TestServer()) { testServer => + // Start the server + testServer.start() - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.socketTextStream( - "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.addStreamingListener(ssc.progressListener) - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - val expectedOutput = input.map(_.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString + "\n") - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) + val input = Seq(1, 2, 3, 4, 5) + // Use "batchCount" to make sure we check the result after all batches finish + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val expectedOutput = input.map(_.toString) + for (i <- 0 until input.size) { + testServer.send(input(i).toString + "\n") + Thread.sleep(500) + clock.advance(batchDuration.milliseconds) + } + // Make sure we finish all batches before "stop" + if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + + // Verify all "InputInfo"s have been reported + assert(ssc.progressListener.numTotalReceivedRecords === input.size) + assert(ssc.progressListener.numTotalProcessedRecords === input.size) + + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() + } - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") + test("socket input stream - no block in a batch") { + withTestServer(new TestServer()) { testServer => + testServer.start() - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.addStreamingListener(ssc.progressListener) + + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // Make sure the first batch is finished + if (!batchCounter.waitUntilBatchesCompleted(1, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + + networkStream.generatedRDDs.foreach { case (_, rdd) => + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + } + } } } + test("binary records stream") { + val testDir: File = null + try { + val batchDuration = Seconds(2) + val testDir = Utils.createTempDir() + // Create a file that exists before the StreamingContext is created: + val existingFile = new File(testDir, "0") + Files.write("0\n", existingFile, Charset.forName("UTF-8")) + assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) + + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + // This `setTime` call ensures that the clock is past the creation time of `existingFile` + clock.setTime(existingFile.lastModified + batchDuration.milliseconds) + val batchCounter = new BatchCounter(ssc) + val fileStream = ssc.binaryRecordsStream(testDir.toString, 1) + val outputBuffer = new ArrayBuffer[Seq[Array[Byte]]] + with SynchronizedBuffer[Seq[Array[Byte]]] + val outputStream = new TestOutputStream(fileStream, outputBuffer) + outputStream.register() + ssc.start() + + // Advance the clock so that the files are created after StreamingContext starts, but + // not enough to trigger a batch + clock.advance(batchDuration.milliseconds / 2) + + val input = Seq(1, 2, 3, 4, 5) + input.foreach { i => + Thread.sleep(batchDuration.milliseconds) + val file = new File(testDir, i.toString) + Files.write(Array[Byte](i.toByte), file) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) + logInfo("Created file " + file) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.advance(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === i) + } + } + + val expectedOutput = input.map(i => i.toByte) + val obtainedOutput = outputBuffer.flatten.toList.map(i => i(0).toByte) + assert(obtainedOutput === expectedOutput) + } + } finally { + if (testDir != null) Utils.deleteRecursively(testDir) + } + } test("file input stream - newFilesOnly = true") { testFileStream(newFilesOnly = true) @@ -118,7 +210,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -128,7 +220,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && System.currentTimeMillis() - startTime < 5000) { Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping context") @@ -150,7 +242,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -158,12 +250,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping context") @@ -193,7 +285,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -205,12 +297,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Enqueue the first 3 items (one by one), they should be merged in the next batch val inputIterator = input.toIterator inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(1000) // Enqueue the remaining items (again one by one), merged in the final batch inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(1000) logInfo("Stopping context") ssc.stop() @@ -232,6 +324,34 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("test track the number of input stream") { + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} + + def stop() {} + + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } + + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } + } + def testFileStream(newFilesOnly: Boolean) { val testDir: File = null try { @@ -257,19 +377,19 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Advance the clock so that the files are created after StreamingContext starts, but // not enough to trigger a batch - clock.addToTime(batchDuration.milliseconds / 2) + clock.advance(batchDuration.milliseconds / 2) // Over time, create files in the directory val input = Seq(1, 2, 3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) Files.write(i + "\n", file, Charset.forName("UTF-8")) - assert(file.setLastModified(clock.currentTime())) - assert(file.lastModified === clock.currentTime) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === i) } @@ -297,30 +417,45 @@ class TestServer(portToBind: Int = 0) extends Logging { val serverSocket = new ServerSocket(portToBind) + private val startLatch = new CountDownLatch(1) + val servingThread = new Thread() { override def run() { try { - while(true) { + while (true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() - logInfo("New connection") - try { - clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) - - while(clientSocket.isConnected) { - val msg = queue.poll(100, TimeUnit.MILLISECONDS) - if (msg != null) { - outputStream.write(msg) - outputStream.flush() - logInfo("Message '" + msg + "' sent") + if (startLatch.getCount == 1) { + // The first connection is a test connection to implement "waitForStart", so skip it + // and send a signal + if (!clientSocket.isClosed) { + clientSocket.close() + } + startLatch.countDown() + } else { + // Real connections + logInfo("New connection") + try { + clientSocket.setTcpNoDelay(true) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) + + while (clientSocket.isConnected) { + val msg = queue.poll(100, TimeUnit.MILLISECONDS) + if (msg != null) { + outputStream.write(msg) + outputStream.flush() + logInfo("Message '" + msg + "' sent") + } + } + } catch { + case e: SocketException => logError("TestServer error", e) + } finally { + logInfo("Connection closed") + if (!clientSocket.isClosed) { + clientSocket.close() } } - } catch { - case e: SocketException => logError("TestServer error", e) - } finally { - logInfo("Connection closed") - if (!clientSocket.isClosed) clientSocket.close() } } } catch { @@ -332,13 +467,35 @@ class TestServer(portToBind: Int = 0) extends Logging { } } - def start() { servingThread.start() } + def start(): Unit = { + servingThread.start() + if (!waitForStart(10000)) { + stop() + throw new AssertionError("Timeout: TestServer cannot start in 10 seconds") + } + } + + /** + * Wait until the server starts. Return true if the server starts in "millis" milliseconds. + * Otherwise, return false to indicate it's timeout. + */ + private def waitForStart(millis: Long): Boolean = { + // We will create a test connection to the server so that we can make sure it has started. + val socket = new Socket("localhost", port) + try { + startLatch.await(millis, TimeUnit.MILLISECONDS) + } finally { + if (!socket.isClosed) { + socket.close() + } + } + } def send(msg: String) { queue.put(msg) } def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index e0f14fd95428..0e64b57e0ffd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -43,6 +43,7 @@ object MasterFailureTest extends Logging { @volatile var setupCalled = false def main(args: Array[String]) { + // scalastyle:off println if (args.size < 2) { println( "Usage: MasterFailureTest <# batches> " + @@ -60,6 +61,7 @@ object MasterFailureTest extends Logging { testUpdateStateByKey(directory, numBatches, batchDuration) println("\n\nSUCCESS\n\n") + // scalastyle:on println } def testMap(directory: String, numBatches: Int, batchDuration: Duration) { @@ -242,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. + killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) @@ -291,10 +299,12 @@ object MasterFailureTest extends Logging { } // Log the output + // scalastyle:off println println("Expected output, size = " + expectedOutput.size) println(expectedOutput.mkString("[", ",", "]")) println("Output, size = " + output.size) println(output.mkString("[", ",", "]")) + // scalastyle:on println // Match the output with the expected output output.foreach(o => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 132ff2443fc0..6c0c926755c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,30 +24,31 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} -import com.google.common.io.Files -import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -55,48 +56,45 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val serializer = new KryoSerializer(conf) val manualClock = new ManualClock val blockManagerSize = 10000000 + val blockManagerBuffer = new ArrayBuffer[BlockManager]() - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var storageLevel: StorageLevel = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, - blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr), securityMgr, 0) - blockManager.initialize("app-id") + storageLevel = StorageLevel.MEMORY_ONLY_SER + blockManager = createBlockManager(blockManagerSize, conf) - tempDirectory = Files.createTempDir() + tempDirectory = Utils.createTempDir() manualClock.setTime(0) } after { - if (blockManager != null) { - blockManager.stop() - blockManager = null + for ( blockManager <- blockManagerBuffer ) { + if (blockManager != null) { + blockManager.stop() + } } + blockManager = null + blockManagerBuffer.clear() if (blockManagerMaster != null) { blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null - if (tempDirectory != null && tempDirectory.exists()) { - FileUtils.deleteDirectory(tempDirectory) - tempDirectory = null - } + Utils.deleteRecursively(tempDirectory) } test("BlockManagerBasedBlockHandler - store blocks") { @@ -104,7 +102,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -128,7 +126,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -138,10 +136,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche "Unexpected store result type" ) // Verify the data in write ahead log files is correct - val fileSegments = storeResults.map { _.asInstanceOf[WriteAheadLogBasedStoreResult].segment} - val loggedData = fileSegments.flatMap { segment => - val reader = new WriteAheadLogRandomReader(segment.path, hadoopConf) - val bytes = reader.read(segment) + val walSegments = storeResults.map { result => + result.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle + } + val loggedData = walSegments.flatMap { walSegment => + val fileSegment = walSegment.asInstanceOf[FileBasedWriteAheadLogSegment] + val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) + val bytes = reader.read(fileSegment) reader.close() blockManager.dataDeserialize(generateBlockId(), bytes).toList } @@ -156,16 +157,16 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche } } - test("WriteAheadLogBasedBlockHandler - cleanup old blocks") { + test("WriteAheadLogBasedBlockHandler - clean old blocks") { withWriteAheadLogBasedBlockHandler { handler => val blocks = Seq.tabulate(10) { i => IteratorBlock(Iterator(1 to i)) } storeBlocks(handler, blocks) val preCleanupLogFiles = getWriteAheadLogFiles() - preCleanupLogFiles.size should be > 1 + require(preCleanupLogFiles.size > 1) // this depends on the number of blocks inserted using generateAndStoreData() - manualClock.currentTime() shouldEqual 5000L + manualClock.getTimeMillis() shouldEqual 5000L val cleanupThreshTime = 3000L handler.cleanupOldBlocks(cleanupThreshTime) @@ -175,6 +176,130 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche } } + test("Test Block - count messages") { + // Test count with BlockManagedBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(true) + // Test count with WriteAheadLogBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(false) + } + + test("Test Block - isFullyConsumed") { + val sparkConf = new SparkConf() + sparkConf.set("spark.storage.unrollMemoryThreshold", "512") + // spark.storage.unrollFraction set to 0.4 for BlockManager + sparkConf.set("spark.storage.unrollFraction", "0.4") + // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll + blockManager = createBlockManager(12000, sparkConf) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to WAL + // and hence count returns correct value. + testRecordcount(false, StorageLevel.MEMORY_ONLY, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to DISK + // and hence count returns correct value. + testRecordcount(true, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block With MEMORY_ONLY StorageLevel. + // BlockManager will not be able to unroll this block + // and hence it will not tryToPut this block, resulting the SparkException + storageLevel = StorageLevel.MEMORY_ONLY + withBlockManagerBasedBlockHandler { handler => + val thrown = intercept[SparkException] { + storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator)) + } + } + } + + private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) { + // ByteBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ByteBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ArrayBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50)) + // ArrayBufferBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75)) + // IteratorBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125)) + // IteratorBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150)) + } + + private def createBlockManager( + maxMem: Long, + conf: SparkConf, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + blockManagerBuffer += manager + manager + } + + /** + * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * and verify the correct record count + */ + private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, + sLevel: StorageLevel, + receivedBlock: ReceivedBlock, + bManager: BlockManager, + expectedNumRecords: Option[Long] + ) { + blockManager = bManager + storageLevel = sLevel + var bId: StreamBlockId = null + try { + if (isBlockManagedBasedBlockHandler) { + // test received block with BlockManager based handler + withBlockManagerBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using BlockManagerBasedBlockHandler with " + sLevel) + } + } else { + // test received block with WAL based handler + withWriteAheadLogBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel) + } + } + } finally { + // Removing the Block Id to use same blockManager for next test + blockManager.removeBlock(bId, true) + } + } + /** * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded * using the given verification function @@ -226,6 +351,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */ private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1) val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) try { @@ -243,7 +369,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val blockIds = Seq.fill(blocks.size)(generateBlockId()) val storeResults = blocks.zip(blockIds).map { case (block, id) => - manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf + manualClock.advance(500) // log rolling interval set to 1000 ms through SparkConf logDebug("Inserting block " + id) receivedBlockHandler.storeBlock(id, block) }.toList @@ -251,9 +377,21 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche (blockIds, storeResults) } + /** Store single block using a handler */ + private def storeSingleBlock( + handler: ReceivedBlockHandler, + block: ReceivedBlock + ): (StreamBlockId, ReceivedBlockStoreResult) = { + val blockId = generateBlockId + val blockStoreResult = handler.storeBlock(blockId, block) + logDebug("Done inserting") + (blockId, blockStoreResult) + } + private def getWriteAheadLogFiles(): Seq[String] = { getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) } private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index fbb7b0bfebaf..f793a12843b2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -24,22 +24,20 @@ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} import scala.util.Random -import com.google.common.io.Files -import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader} import org.apache.spark.streaming.util.WriteAheadLogSuite._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds @@ -51,20 +49,17 @@ class ReceivedBlockTrackerSuite before { conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") - checkpointDirectory = Files.createTempDir() + checkpointDirectory = Utils.createTempDir() } after { allReceivedBlockTrackers.foreach { _.stop() } - if (checkpointDirectory != null && checkpointDirectory.exists()) { - FileUtils.deleteDirectory(checkpointDirectory) - checkpointDirectory = null - } + Utils.deleteRecursively(checkpointDirectory) } test("block addition, and block to batch allocation") { val receivedBlockTracker = createTracker(setCheckpointDir = false) - receivedBlockTracker.isLogManagerEnabled should be (false) // should be disable by default + receivedBlockTracker.isWriteAheadLogEnabled should be (false) // should be disable by default receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty val blockInfos = generateBlockInfos() @@ -72,15 +67,20 @@ class ReceivedBlockTrackerSuite // Verify added blocks are unallocated blocks receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos + receivedBlockTracker.hasUnallocatedReceivedBlocks should be (true) + // Allocate the blocks to a batch and verify that all of them have been allocated receivedBlockTracker.allocateBlocksToBatch(1) receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos + receivedBlockTracker.getBlocksOfBatch(1) shouldEqual Map(streamId -> blockInfos) receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty + receivedBlockTracker.hasUnallocatedReceivedBlocks should be (false) // Allocate no blocks to another batch receivedBlockTracker.allocateBlocksToBatch(2) receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + receivedBlockTracker.getBlocksOfBatch(2) shouldEqual Map(streamId -> Seq.empty) // Verify that older batches have no operation on batch allocation, // will return the same blocks as previously allocated. @@ -93,14 +93,14 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } - test("block addition, block to batch allocation and cleanup with write ahead log") { + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates // a new log file def incrementTime() { val timeIncrementMillis = 2000L - manualClock.addToTime(timeIncrementMillis) + manualClock.advance(timeIncrementMillis) } // Generate and add blocks to the given tracker @@ -118,11 +118,13 @@ class ReceivedBlockTrackerSuite logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") } + // Set WAL configuration + conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1") + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1) + // Start tracker and add blocks - conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") - conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") val tracker1 = createTracker(clock = manualClock) - tracker1.isLogManagerEnabled should be (true) + tracker1.isWriteAheadLogEnabled should be (true) val blockInfos1 = addBlockInfos(tracker1) tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 @@ -132,19 +134,31 @@ class ReceivedBlockTrackerSuite getWrittenLogData() shouldEqual expectedWrittenData1 getWriteAheadLogFiles() should have size 1 - // Restart tracker and verify recovered list of unallocated blocks incrementTime() - val tracker2 = createTracker(clock = manualClock) - tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 + + // Recovery without recovery from WAL and verify list of unallocated blocks is empty + val tracker1_ = createTracker(clock = manualClock, recoverFromWriteAheadLog = false) + tracker1_.getUnallocatedBlocks(streamId) shouldBe empty + tracker1_.hasUnallocatedReceivedBlocks should be (false) + + // Restart tracker and verify recovered list of unallocated blocks + val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) + val unallocatedBlocks = tracker2.getUnallocatedBlocks(streamId).toList + unallocatedBlocks shouldEqual blockInfos1 + unallocatedBlocks.foreach { block => + block.isBlockIdValid() should be (false) + } + // Allocate blocks to batch and verify whether the unallocated blocks got allocated - val batchTime1 = manualClock.currentTime + val batchTime1 = manualClock.getTimeMillis() tracker2.allocateBlocksToBatch(batchTime1) tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 + tracker2.getBlocksOfBatch(batchTime1) shouldEqual Map(streamId -> blockInfos1) // Add more blocks and allocate to another batch incrementTime() - val batchTime2 = manualClock.currentTime + val batchTime2 = manualClock.getTimeMillis() val blockInfos2 = addBlockInfos(tracker2) tracker2.allocateBlocksToBatch(batchTime2) tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 @@ -158,7 +172,7 @@ class ReceivedBlockTrackerSuite // Restart tracker and verify recovered state incrementTime() - val tracker3 = createTracker(clock = manualClock) + val tracker3 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 tracker3.getUnallocatedBlocks(streamId) shouldBe empty @@ -176,52 +190,42 @@ class ReceivedBlockTrackerSuite eventually(timeout(10 seconds), interval(10 millisecond)) { getWriteAheadLogFiles() should not contain oldestLogFile } - printLogFiles("After cleanup") + printLogFiles("After clean") // Restart tracker and verify recovered state, specifically whether info about the first // batch has been removed, but not the second batch incrementTime() - val tracker4 = createTracker(clock = manualClock) + val tracker4 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) tracker4.getUnallocatedBlocks(streamId) shouldBe empty tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 } - test("enabling write ahead log but not setting checkpoint dir") { - conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") - intercept[SparkException] { - createTracker(setCheckpointDir = false) - } - } - - test("setting checkpoint dir but not enabling write ahead log") { - // When WAL config is not set, log manager should not be enabled - val tracker1 = createTracker(setCheckpointDir = true) - tracker1.isLogManagerEnabled should be (false) - - // When WAL is explicitly disabled, log manager should not be enabled - conf.set("spark.streaming.receiver.writeAheadLog.enable", "false") - val tracker2 = createTracker(setCheckpointDir = true) - tracker2.isLogManagerEnabled should be(false) + test("disable write ahead log when checkpoint directory is not set") { + // When checkpoint is disabled, then the write ahead log is disabled + val tracker1 = createTracker(setCheckpointDir = false) + tracker1.isWriteAheadLogEnabled should be (false) } /** * Create tracker object with the optional provided clock. Use fake clock if you - * want to control time by manually incrementing it to test log cleanup. + * want to control time by manually incrementing it to test log clean. */ def createTracker( setCheckpointDir: Boolean = true, + recoverFromWriteAheadLog: Boolean = false, clock: Clock = new SystemClock): ReceivedBlockTracker = { val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None - val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption) + val tracker = new ReceivedBlockTracker( + conf, hadoopConf, Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption) allReceivedBlockTrackers += tracker tracker } /** Generate blocks infos using random ids */ def generateBlockInfos(): Seq[ReceivedBlockInfo] = { - List.fill(5)(ReceivedBlockInfo(streamId, 0, - BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } /** Get all the data written in the given write ahead log file. */ @@ -233,9 +237,10 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. */ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { - file => new WriteAheadLogReader(file, hadoopConf).toSeq + file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) }.toList @@ -249,11 +254,12 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } - /** Create batch cleanup object from the given info */ + /** Create batch clean object from the given info */ def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala new file mode 100644 index 000000000000..6d388d9624d9 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.{SparkConf, SparkEnv} + +class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { + + override def afterAll(): Unit = { + StreamingContext.getActive().map { _.stop() } + } + + testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithoutWAL("createBlockRDD creates correct BlockRDD with block info") { receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = false) } + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.forall(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + testWithoutWAL("createBlockRDD filters non-existent blocks before creating BlockRDD") { + receiverStream => + val presentBlockInfos = Seq.fill(2)(createBlockInfo(withWALInfo = false, createBlock = true)) + val absentBlockInfos = Seq.fill(3)(createBlockInfo(withWALInfo = false, createBlock = false)) + val blockInfos = presentBlockInfos ++ absentBlockInfos + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.exists(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + require(blockIds.exists(blockId => !SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === presentBlockInfos.map { _.blockId}) + } + + testWithWAL("createBlockRDD creates empty WALBackedBlockRDD when no block info") { + receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithWAL( + "createBlockRDD creates correct WALBackedBlockRDD with all block info having WAL info") { + receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = true) } + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[WriteAheadLogBackedBlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) + } + + testWithWAL("createBlockRDD creates BlockRDD when some block info dont have WAL info") { + receiverStream => + val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } + val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } + val blockInfos = blockInfos1 ++ blockInfos2 + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + + private def testWithoutWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"Without WAL enabled: $msg") { + runTest(enableWAL = false, body) + } + } + + private def testWithWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"With WAL enabled: $msg") { + runTest(enableWAL = true, body) + } + } + + private def runTest(enableWAL: Boolean, body: ReceiverInputDStream[_] => Unit): Unit = { + val conf = new SparkConf() + conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite") + conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString) + require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL) + val ssc = new StreamingContext(conf, Seconds(1)) + val receiverStream = new ReceiverInputDStream[Int](ssc) { + override def getReceiver(): Receiver[Int] = null + } + withStreamingContext(ssc) { ssc => + body(receiverStream) + } + } + + /** + * Create a block info for input to the ReceiverInputDStream.createBlockRDD + * @param withWALInfo Create block with WAL info in it + * @param createBlock Actually create the block in the BlockManager + * @return + */ + private def createBlockInfo( + withWALInfo: Boolean, + createBlock: Boolean = true): ReceivedBlockInfo = { + val blockId = new StreamBlockId(0, Random.nextLong()) + if (createBlock) { + SparkEnv.get.blockManager.putSingle(blockId, 1, StorageLevel.MEMORY_ONLY, tellMaster = true) + require(SparkEnv.get.blockManager.master.contains(blockId)) + } + val storeResult = if (withWALInfo) { + new WriteAheadLogBasedStoreResult(blockId, None, new WriteAheadLogRecordHandle { }) + } else { + new BlockManagerBasedStoreResult(blockId, None) + } + new ReceivedBlockInfo(0, None, None, storeResult) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e8c34a9ee40b..01279b34f73d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,7 +24,6 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -34,6 +33,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { @@ -129,43 +129,17 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } } - test("block generator") { + ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString) - val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) - val expectedBlocks = 5 - val waitTime = expectedBlocks * blockInterval + (blockInterval / 2) - val generatedData = new ArrayBuffer[Int] - - // Generate blocks - val startTime = System.currentTimeMillis() - blockGenerator.start() - var count = 0 - while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator.addData(count) - generatedData += count - count += 1 - Thread.sleep(10) - } - blockGenerator.stop() - - val recordedData = blockGeneratorListener.arrayBuffers.flatten - assert(blockGeneratorListener.arrayBuffers.size > 0) - assert(recordedData.toSet === generatedData.toSet) - } - - test("block generator throttling") { - val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 100 - val maxRate = 100 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString). + val blockIntervalMs = 100 + val maxRate = 1001 + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms"). set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 20 - val waitTime = expectedBlocks * blockInterval + val waitTime = expectedBlocks * blockIntervalMs val expectedMessages = maxRate * waitTime / 1000 - val expectedMessagesPerBlock = maxRate * blockInterval / 1000 + val expectedMessagesPerBlock = maxRate * blockIntervalMs / 1000 val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -176,7 +150,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { blockGenerator.addData(count) generatedData += count count += 1 - Thread.sleep(1) } blockGenerator.stop() @@ -185,25 +158,31 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received") assert(recordedData.toSet === generatedData.toSet, "Received data not same") - // recordedData size should be close to the expected rate - val minExpectedMessages = expectedMessages - 3 - val maxExpectedMessages = expectedMessages + 1 + // recordedData size should be close to the expected rate; use an error margin proportional to + // the value, so that rate changes don't cause a brittle test + val minExpectedMessages = expectedMessages - 0.05 * expectedMessages + val maxExpectedMessages = expectedMessages + 0.05 * expectedMessages val numMessages = recordedData.size assert( numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages, s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages" ) - val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 - val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 + // XXX Checking every block would require an even distribution of messages across blocks, + // which throttling code does not control. Therefore, test against the average. + val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 0.05 * expectedMessagesPerBlock + val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 0.05 * expectedMessagesPerBlock val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") + + // the first and last block may be incomplete, so we slice them out + val validBlocks = recordedBlocks.drop(1).dropRight(1) + val averageBlockSize = validBlocks.map(block => block.size).sum / validBlocks.size + assert( - // the first and last block may be incomplete, so we slice them out - recordedBlocks.drop(1).dropRight(1).forall { block => - block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock - }, + averageBlockSize >= minExpectedMessagesPerBlock && + averageBlockSize <= maxExpectedMessagesPerBlock, s"# records in received blocks = [$receivedBlockSizes], not between " + - s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock" + s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock, on average" ) } @@ -220,9 +199,9 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { .setAppName(framework) .set("spark.ui.enabled", "true") .set("spark.streaming.receiver.writeAheadLog.enable", "true") - .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val batchDuration = Milliseconds(500) - val tempDirectory = Files.createTempDir() + val tempDirectory = Utils.createTempDir() val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) val allLogFiles1 = new mutable.HashSet[String]() @@ -251,9 +230,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => - tempDirectory.deleteOnExit() - val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) - val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver1 = new FakeReceiver(sendData = true) + val receiver2 = new FakeReceiver(sendData = true) val receiverStream1 = ssc.receiverStream(receiver1) val receiverStream2 = ssc.receiverStream(receiver2) receiverStream1.register() @@ -309,7 +287,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -321,30 +299,34 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } def reportError(message: String, throwable: Throwable) { errors += throwable } + + override protected def onReceiverStart(): Boolean = true + + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + null + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 9f352bdcb089..7423ef6bcb6e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,28 +17,36 @@ package org.apache.spark.streaming +import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Timeouts +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue + +import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName val batchDuration = Milliseconds(500) val sparkHome = "someDir" val envPair = "key" -> "value" + val conf = new SparkConf().setMaster(master).setAppName(appName) var sc: SparkContext = null var ssc: StreamingContext = null @@ -73,9 +81,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from existing SparkContext") { @@ -85,44 +93,101 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) - assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") + assert( + Utils.timeStringAsSeconds(cp.sparkConfPairs + .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.sparkConf.getInt("spark.cleaner.ttl", -1) === 10) + assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + } + + test("checkPoint from conf") { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.checkpointDir != null) + } + + test("state matching") { + import StreamingContextState._ + assert(INITIALIZED === INITIALIZED) + assert(INITIALIZED != ACTIVE) } test("start and stop state check") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() - assert(ssc.state === ssc.StreamingContextState.Initialized) + assert(ssc.getState() === StreamingContextState.INITIALIZED) ssc.start() - assert(ssc.state === ssc.StreamingContextState.Started) + assert(ssc.getState() === StreamingContextState.ACTIVE) ssc.stop() - assert(ssc.state === ssc.StreamingContextState.Stopped) + assert(ssc.getState() === StreamingContextState.STOPPED) + + // Make sure that the SparkContext is also stopped by default + intercept[Exception] { + ssc.sparkContext.makeRDD(1 to 10) + } + } + + test("start with non-seriazable DStream checkpoints") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir.getAbsolutePath) + addInputStream(ssc).foreachRDD { rdd => + // Refer to this.appName from inside closure so that this closure refers to + // the instance of StreamingContextSuite, and is therefore not serializable + rdd.count() + appName + } + + // Test whether start() fails early when checkpointing is enabled + val exception = intercept[NotSerializableException] { + ssc.start() + } + assert(exception.getMessage().contains("DStreams with their functions are not serializable")) + assert(ssc.getState() !== StreamingContextState.ACTIVE) + assert(StreamingContext.getActive().isEmpty) + } + + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) } + test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.start() - intercept[SparkException] { - ssc.start() - } + assert(ssc.getState() === StreamingContextState.ACTIVE) + ssc.start() + assert(ssc.getState() === StreamingContextState.ACTIVE) } test("stop multiple times") { @@ -130,13 +195,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w addInputStream(ssc).register() ssc.start() ssc.stop() + assert(ssc.getState() === StreamingContextState.STOPPED) ssc.stop() + assert(ssc.getState() === StreamingContextState.STOPPED) } test("stop before start") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() // stop before start should not throw exception + assert(ssc.getState() === StreamingContextState.STOPPED) } test("start after stop") { @@ -144,22 +212,34 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() - intercept[SparkException] { + intercept[IllegalStateException] { ssc.start() // start after stop should throw exception } + assert(ssc.getState() === StreamingContextState.STOPPED) } test("stop only streaming context") { - ssc = new StreamingContext(master, appName, batchDuration) + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Explicitly do not stop SparkContext + ssc = new StreamingContext(conf, batchDuration) sc = ssc.sparkContext addInputStream(ssc).register() ssc.start() ssc.stop(stopSparkContext = false) + assert(ssc.getState() === StreamingContextState.STOPPED) assert(sc.makeRDD(1 to 100).collect().size === 100) - ssc = new StreamingContext(sc, batchDuration) + sc.stop() + + // Implicitly do not stop SparkContext + conf.set("spark.streaming.stopSparkContextByDefault", "false") + ssc = new StreamingContext(conf, batchDuration) + sc = ssc.sparkContext addInputStream(ssc).register() ssc.start() ssc.stop() + assert(sc.makeRDD(1 to 100).collect().size === 100) + sc.stop() } test("stop(stopSparkContext=true) after stop(stopSparkContext=false)") { @@ -176,12 +256,12 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600") + conf.set("spark.cleaner.ttl", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) - var runningCount = 0 + @volatile var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) input.count().foreachRDD { rdd => @@ -190,14 +270,14 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTermination(500) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(runningCount > 0) + } ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) - assert(runningCount > 0) assert( - (TestReceiver.counter.get() == runningCount + 1) || - (TestReceiver.counter.get() == runningCount + 2), + TestReceiver.counter.get() == runningCount + 1, "Received records = " + TestReceiver.counter.get() + ", " + "processed records = " + runningCount ) @@ -205,6 +285,66 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("stop gracefully even if a receiver misses StopReceiver") { + // This is not a deterministic unit. But if this unit test is flaky, then there is definitely + // something wrong. See SPARK-5681 + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + ssc = new StreamingContext(sc, Milliseconds(100)) + val input = ssc.receiverStream(new TestReceiver) + input.foreachRDD(_ => {}) + ssc.start() + // Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver" + failAfter(30000 millis) { + ssc.stop(stopSparkContext = true, stopGracefully = true) + } + } + + test("stop slow receiver gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.streaming.gracefulStopTimeout", "20000s") + sc = new SparkContext(conf) + logInfo("==================================\n\n\n") + ssc = new StreamingContext(sc, Milliseconds(100)) + var runningCount = 0 + SlowTestReceiver.receivedAllRecords = false + // Create test receiver that sleeps in onStop() + val totalNumRecords = 15 + val recordsPerSecond = 1 + val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) + input.count().foreachRDD { rdd => + val count = rdd.first() + runningCount += count.toInt + logInfo("Count = " + count + ", Running count = " + runningCount) + } + ssc.start() + ssc.awaitTerminationOrTimeout(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + assert(runningCount > 0) + assert(runningCount == totalNumRecords) + Thread.sleep(100) + } + + test ("registering and de-registering of streamingSource") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + ssc = new StreamingContext(conf, batchDuration) + assert(ssc.getState() === StreamingContextState.INITIALIZED) + addInputStream(ssc).register() + ssc.start() + + val sources = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSource = StreamingContextSuite.getStreamingSource(ssc) + assert(sources.contains(streamingSource)) + assert(ssc.getState() === StreamingContextState.ACTIVE) + + ssc.stop() + val sourcesAfterStop = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSourceAfterStop = StreamingContextSuite.getStreamingSource(ssc) + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(!sourcesAfterStop.contains(streamingSourceAfterStop)) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -217,7 +357,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w // test whether awaitTermination() exits after give amount of time failAfter(1000 millis) { - ssc.awaitTermination(500) + ssc.awaitTerminationOrTimeout(500) } // test whether awaitTermination() does not exit if not time is given @@ -229,16 +369,22 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -262,7 +408,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val exception = intercept[Exception] { ssc.start() - ssc.awaitTermination(5000) + ssc.awaitTerminationOrTimeout(5000) } assert(exception.getMessage.contains("map task"), "Expected exception not thrown") } @@ -273,20 +419,352 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w inputStream.transform { rdd => throw new TestException("error in transform"); rdd }.register() val exception = intercept[TestException] { ssc.start() - ssc.awaitTermination(5000) + ssc.awaitTerminationOrTimeout(5000) } assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } + test("awaitTerminationOrTimeout") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register() + + ssc.start() + + // test whether awaitTerminationOrTimeout() return false after give amount of time + failAfter(1000 millis) { + assert(ssc.awaitTerminationOrTimeout(500) === false) + } + + var t: Thread = null + // test whether awaitTerminationOrTimeout() return true if context is stopped + failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown + t = new Thread() { + override def run() { + Thread.sleep(500) + ssc.stop() + } + } + t.start() + assert(ssc.awaitTerminationOrTimeout(10000) === true) + } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() + } + + test("getOrCreate") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val corruptedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corruptedCheckpointPath, creatingFunction _) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corruptedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corruptedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val checkpointPath = createValidCheckpoint() + + // getOrCreate should recover context with checkpoint path, and recover old configuration + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue", "checkpointed config not recovered") + } + + // getOrCreate should recover StreamingContext with existing SparkContext + testGetOrCreate { + sc = new SparkContext(conf) + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(!ssc.conf.contains("someKey"), "checkpointed config unexpectedly recovered") + } + } + + test("getActive and getActiveOrCreate") { + require(StreamingContext.getActive().isEmpty, "context exists from before") + sc = new SparkContext(conf) + + var newContextCreated = false + + def creatingFunc(): StreamingContext = { + newContextCreated = true + val newSsc = new StreamingContext(sc, batchDuration) + val input = addInputStream(newSsc) + input.foreachRDD { rdd => rdd.count } + newSsc + } + + def testGetActiveOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + + if (ssc != null) { + ssc.stop(stopSparkContext = false) + } + ssc = null + } + } + + // getActiveOrCreate should create new context and getActive should return it only + // after starting the context + testGetActiveOrCreate { + ssc = StreamingContext.getActiveOrCreate(creatingFunc _) + assert(ssc != null, "no context created") + assert(newContextCreated === true, "new context not created") + assert(StreamingContext.getActive().isEmpty, + "new initialized context returned before starting") + ssc.start() + assert(StreamingContext.getActive() === Some(ssc), + "active context not returned") + assert(StreamingContext.getActiveOrCreate(creatingFunc _) === ssc, + "active context not returned") + ssc.stop() + assert(StreamingContext.getActive().isEmpty, + "inactive context returned") + assert(StreamingContext.getActiveOrCreate(creatingFunc _) !== ssc, + "inactive context returned") + } + + // getActiveOrCreate and getActive should return independently created context after activating + testGetActiveOrCreate { + ssc = creatingFunc() // Create + assert(StreamingContext.getActive().isEmpty, + "new initialized context returned before starting") + ssc.start() + assert(StreamingContext.getActive() === Some(ssc), + "active context not returned") + assert(StreamingContext.getActiveOrCreate(creatingFunc _) === ssc, + "active context not returned") + ssc.stop() + assert(StreamingContext.getActive().isEmpty, + "inactive context returned") + } + } + + test("getActiveOrCreate with checkpoint") { + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetActiveOrCreate(body: => Unit): Unit = { + require(StreamingContext.getActive().isEmpty) // no active context + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + val corruptedCheckpointPath = createCorruptedCheckpoint() + val checkpointPath = createValidCheckpoint() + + // getActiveOrCreate should return the current active context if there is one + testGetActiveOrCreate { + ssc = new StreamingContext( + conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock"), batchDuration) + addInputStream(ssc).register() + ssc.start() + val returnedSsc = StreamingContext.getActiveOrCreate(checkpointPath, creatingFunction _) + assert(!newContextCreated, "new context created instead of returning") + assert(returnedSsc.eq(ssc), "returned context is not the activated context") + } + + // getActiveOrCreate should create new context with empty path + testGetActiveOrCreate { + ssc = StreamingContext.getActiveOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + // getActiveOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corruptedCheckpointPath, creatingFunction _) + } + + // getActiveOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getActiveOrCreate( + corruptedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getActiveOrCreate should create new context with fake + // checkpoint file and createOnError = true + testGetActiveOrCreate { + ssc = StreamingContext.getActiveOrCreate( + corruptedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + // getActiveOrCreate should recover context with checkpoint path, and recover old configuration + testGetActiveOrCreate { + ssc = StreamingContext.getActiveOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue") + } + } + + test("multiple streaming contexts") { + sc = new SparkContext( + conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")) + ssc = new StreamingContext(sc, Seconds(1)) + val input = addInputStream(ssc) + input.foreachRDD { rdd => rdd.count } + ssc.start() + + // Creating another streaming context should not create errors + val anotherSsc = new StreamingContext(sc, Seconds(10)) + val anotherInput = addInputStream(anotherSsc) + anotherInput.foreachRDD { rdd => rdd.count } + + val exception = intercept[IllegalStateException] { + anotherSsc.start() + } + assert(exception.getMessage.contains("StreamingContext"), "Did not get the right exception") + } + test("DStream and generated RDD creation sites") { testPackage.test() } + test("throw exception on using active or stopped context") { + val conf = new SparkConf() + .setMaster(master) + .setAppName(appName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") + ssc = new StreamingContext(conf, batchDuration) + require(ssc.getState() === StreamingContextState.INITIALIZED) + val input = addInputStream(ssc) + val transformed = input.map { x => x} + transformed.foreachRDD { rdd => rdd.count } + + def testForException(clue: String, expectedErrorMsg: String)(body: => Unit): Unit = { + withClue(clue) { + val ex = intercept[IllegalStateException] { + body + } + assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) + } + } + + ssc.start() + require(ssc.getState() === StreamingContextState.ACTIVE) + testForException("no error on adding input after start", "start") { + addInputStream(ssc) } + testForException("no error on adding transformation after start", "start") { + input.map { x => x * 2 } } + testForException("no error on adding output operation after start", "start") { + transformed.foreachRDD { rdd => rdd.collect() } } + + ssc.stop() + require(ssc.getState() === StreamingContextState.STOPPED) + testForException("no error on adding input after stop", "stop") { + addInputStream(ssc) } + testForException("no error on adding transformation after stop", "stop") { + input.map { x => x * 2 } } + testForException("no error on adding output operation after stop", "stop") { + transformed.foreachRDD { rdd => rdd.collect() } } + } + + test("queueStream doesn't support checkpointing") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(master, appName, batchDuration) + val rdd = ssc.sparkContext.parallelize(1 to 10) + ssc.queueStream[Int](Queue(rdd)).print() + ssc.checkpoint(checkpointDir.getAbsolutePath) + val e = intercept[NotSerializableException] { + ssc.start() + } + // StreamingContext.validate changes the message, so use "contains" here + assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) inputStream } + + def createValidCheckpoint(): String = { + val testDirectory = Utils.createTempDir().getAbsolutePath() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + ssc = new StreamingContext(conf.clone.set("someKey", "someValue"), batchDuration) + ssc.checkpoint(checkpointDirectory) + ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + checkpointDirectory + } + + def createCorruptedCheckpoint(): String = { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) + FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) + checkpointDirectory + } } class TestException(msg: String) extends Exception(msg) @@ -311,7 +789,8 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging } def onStop() { - // no cleanup to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own, so just wait for it. + receivingThreadOption.foreach(_.join()) } } @@ -319,6 +798,41 @@ object TestReceiver { val counter = new AtomicInteger(1) } +/** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { + + var receivingThreadOption: Option[Thread] = None + + def onStart() { + val thread = new Thread() { + override def run() { + logInfo("Receiving started") + for(i <- 1 to totalRecords) { + Thread.sleep(1000 / recordsPerSecond) + store(i) + } + SlowTestReceiver.receivedAllRecords = true + logInfo(s"Received all $totalRecords records") + } + } + receivingThreadOption = Some(thread) + thread.start() + } + + def onStop() { + // Simulate slow receiver by waiting for all records to be produced + while (!SlowTestReceiver.receivedAllRecords) { + Thread.sleep(100) + } + // no clean to be done, the receiving thread should stop on it own + } +} + +object SlowTestReceiver { + var receivedAllRecords = false +} + /** Streaming application for testing DStream and RDD creation sites */ package object testPackage extends Assertions { def test() { @@ -356,3 +870,18 @@ package object testPackage extends Assertions { } } } + +/** + * Helper methods for testing StreamingContextSuite + * This includes methods to access private methods and fields in StreamingContext and MetricsSystem + */ +private object StreamingContextSuite extends PrivateMethodTester { + private val _sources = PrivateMethod[ArrayBuffer[Source]]('sources) + private def getSources(metricsSystem: MetricsSystem): ArrayBuffer[Source] = { + metricsSystem.invokePrivate(_sources()) + } + private val _streamingSource = PrivateMethod[StreamingSource]('streamingSource) + private def getStreamingSource(streamingContext: StreamingContext): StreamingSource = { + streamingContext.invokePrivate(_streamingSource()) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index f52562b0a0f7..d840c349bbbc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global @@ -36,20 +36,67 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) - val batchInfos = collector.batchInfos - batchInfos should have size 4 - batchInfos.foreach(info => { + // SPARK-6766: batch info should be submitted + val batchInfosSubmitted = collector.batchInfosSubmitted + batchInfosSubmitted should have size 4 + + batchInfosSubmitted.foreach(info => { + info.schedulingDelay should be (None) + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + batchInfosSubmitted.foreach { info => + info.numRecords should be (1L) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) + } + + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + + // SPARK-6766: processingStartTime of batch info should not be None when starting + val batchInfosStarted = collector.batchInfosStarted + batchInfosStarted should have size 4 + + batchInfosStarted.foreach(info => { + info.schedulingDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + batchInfosStarted.foreach { info => + info.numRecords should be (1L) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) + } + + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + + // test onBatchCompleted + val batchInfosCompleted = collector.batchInfosCompleted + batchInfosCompleted should have size 4 + + batchInfosCompleted.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -58,13 +105,18 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + batchInfosCompleted.foreach { info => + info.numRecords should be (1L) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) + } + + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) @@ -73,7 +125,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 @@ -90,8 +142,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for(i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) return false + for (i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) { + return false + } } true } @@ -99,17 +153,29 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfos = new ArrayBuffer[BatchInfo] + val batchInfosCompleted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] + val batchInfosStarted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] + val batchInfosSubmitted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { + batchInfosSubmitted += batchSubmitted.batchInfo + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { + batchInfosStarted += batchStarted.batchInfo + } + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfos += batchCompleted.batchInfo + batchInfosCompleted += batchCompleted.batchInfo } } /** Listener that collects information on processed batches */ class ReceiverInfoCollector extends StreamingListener { - val startedReceiverStreamIds = new ArrayBuffer[Int] - val stoppedReceiverStreamIds = new ArrayBuffer[Int]() - val receiverErrors = new ArrayBuffer[(Int, String, String)]() + val startedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + val stoppedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + val receiverErrors = + new ArrayBuffer[(Int, String, String)] with SynchronizedBuffer[(Int, String, String)] override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { startedReceiverStreamIds += receiverStarted.receiverInfo.streamId diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 7d82c3e4aadc..0d58a7b54412 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,17 +24,34 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} -import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener} -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{ManualClock, Utils} + +/** + * A dummy stream that does absolutely nothing. + */ +private[streaming] class DummyDStream(ssc: StreamingContext) extends DStream[Int](ssc) { + override def dependencies: List[DStream[Int]] = List.empty + override def slideDuration: Duration = Seconds(1) + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + +/** + * A dummy input stream that does absolutely nothing. + */ +private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputDStream[Int](ssc) { + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -54,8 +71,13 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } + + // Report the input data's information to InputInfoTracker for testing + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -69,9 +91,11 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], * * The buffer contains a sequence of RDD's, each containing a sequence of items */ -class TestOutputStream[T: ClassTag](parent: DStream[T], - val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) - extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { +class TestOutputStream[T: ClassTag]( + parent: DStream[T], + val output: SynchronizedBuffer[Seq[T]] = + new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + ) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected }) { @@ -91,8 +115,10 @@ class TestOutputStream[T: ClassTag](parent: DStream[T], * The buffer contains a sequence of RDD's, each containing a sequence of partitions, each * containing a sequence of items. */ -class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], - val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]()) +class TestOutputStreamWithPartitions[T: ClassTag]( + parent: DStream[T], + val output: SynchronizedBuffer[Seq[Seq[T]]] = + new ArrayBuffer[Seq[Seq[T]]] with SynchronizedBuffer[Seq[Seq[T]]]) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.glom().collect().map(_.toSeq) output += collected @@ -104,8 +130,6 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], ois.defaultReadObject() output.clear() } - - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) } /** @@ -140,43 +164,77 @@ class BatchCounter(ssc: StreamingContext) { def getNumStartedBatches: Int = this.synchronized { numStartedBatches } + + /** + * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if + * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's + * timeout. + * + * @param expectedNumCompletedBatches the `expectedNumCompletedBatches` batches to wait + * @param timeout the maximum time to wait in milliseconds. + */ + def waitUntilBatchesCompleted(expectedNumCompletedBatches: Int, timeout: Long): Boolean = + waitUntilConditionBecomeTrue(numCompletedBatches >= expectedNumCompletedBatches, timeout) + + /** + * Wait until `expectedNumStartedBatches` batches are completed, or timeout. Return true if + * `expectedNumStartedBatches` batches are completed. Otherwise, return false to indicate it's + * timeout. + * + * @param expectedNumStartedBatches the `expectedNumStartedBatches` batches to wait + * @param timeout the maximum time to wait in milliseconds. + */ + def waitUntilBatchesStarted(expectedNumStartedBatches: Int, timeout: Long): Boolean = + waitUntilConditionBecomeTrue(numStartedBatches >= expectedNumStartedBatches, timeout) + + private def waitUntilConditionBecomeTrue(condition: => Boolean, timeout: Long): Boolean = { + synchronized { + var now = System.currentTimeMillis() + val timeoutTick = now + timeout + while (!condition && timeoutTick > now) { + wait(timeoutTick - now) + now = System.currentTimeMillis() + } + condition + } + } } /** * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -189,10 +247,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def beforeFunction() { if (useManualClock) { logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") } else { logInfo("Using real clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.SystemClock") } } @@ -333,23 +391,24 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.currentTime()) + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) if (actuallyWait) { for (i <- 1 to numBatches) { logInfo("Actually waiting for " + batchDuration) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } } else { - clock.addToTime(numBatches * batchDuration.milliseconds) + clock.advance(numBatches * batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.currentTime()) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - ssc.awaitTermination(50) + ssc.awaitTerminationOrTimeout(50) } val timeTaken = System.currentTimeMillis() - startTime logInfo("Output generated in " + timeTaken + " milliseconds") @@ -384,12 +443,21 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { logInfo("--------------------------------") // Match the output with the expected output - assert(output.size === expectedOutput.size, "Number of outputs do not match") for (i <- 0 until output.size) { if (useSet) { - assert(output(i).toSet === expectedOutput(i).toSet) + assert( + output(i).toSet === expectedOutput(i).toSet, + s"Set comparison failed\n" + + s"Expected output (${expectedOutput.size} items):\n${expectedOutput.mkString("\n")}\n" + + s"Generated output (${output.size} items): ${output.mkString("\n")}" + ) } else { - assert(output(i).toList === expectedOutput(i).toList) + assert( + output(i).toList === expectedOutput(i).toList, + s"Ordered list comparison failed\n" + + s"Expected output (${expectedOutput.size} items):\n${expectedOutput.mkString("\n")}\n" + + s"Generated output (${output.size} items): ${output.mkString("\n")}" + ) } } logInfo("Output verified successfully") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala index 5579ac364346..e6a01656f479 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala @@ -69,6 +69,9 @@ class TimeSuite extends TestSuiteBase { assert(new Time(1200).floor(new Duration(200)) == new Time(1200)) assert(new Time(199).floor(new Duration(200)) == new Time(0)) assert(new Time(1).floor(new Duration(1)) == new Time(1)) + assert(new Time(1350).floor(new Duration(200), new Time(50)) == new Time(1250)) + assert(new Time(1350).floor(new Duration(200), new Time(150)) == new Time(1350)) + assert(new Time(1350).floor(new Duration(200), new Time(200)) == new Time(1200)) } test("isMultipleOf") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala new file mode 100644 index 000000000000..068a6cb0e8fa --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.collection.mutable.Queue + +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.selenium.WebBrowser +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.ui.SparkUICssErrorHandler + +/** + * Selenium tests for the Spark Streaming Web UI. + */ +class UISeleniumSuite + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + + implicit var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + } + + /** + * Create a test SparkStreamingContext with the SparkUI enabled. + */ + private def newSparkStreamingContext(): StreamingContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val ssc = new StreamingContext(conf, Seconds(1)) + assert(ssc.sc.ui.isDefined, "Spark UI is not started!") + ssc + } + + private def setupStreams(ssc: StreamingContext): Unit = { + val rdds = Queue(ssc.sc.parallelize(1 to 4, 4)) + val inputStream = ssc.queueStream(rdds) + inputStream.foreachRDD { rdd => + rdd.foreach(_ => {}) + rdd.foreach(_ => {}) + } + inputStream.foreachRDD { rdd => + rdd.foreach(_ => {}) + try { + rdd.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException if e.getMessage.contains("Oops") => + } + } + } + + test("attaching and detaching a Streaming tab") { + withStreamingContext(newSparkStreamingContext()) { ssc => + setupStreams(ssc) + ssc.start() + + val sparkUI = ssc.sparkContext.ui.get + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/")) + find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check whether streaming page exists + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + val h3Text = findAll(cssSelector("h3")).map(_.text).toSeq + h3Text should contain("Streaming Statistics") + + // Check stat table + val statTableHeaders = findAll(cssSelector("#stat-table th")).map(_.text).toSeq + statTableHeaders.exists( + _.matches("Timelines \\(Last \\d+ batches, \\d+ active, \\d+ completed\\)") + ) should be (true) + statTableHeaders should contain ("Histograms") + + val statTableCells = findAll(cssSelector("#stat-table td")).map(_.text).toSeq + statTableCells.exists(_.contains("Input Rate")) should be (true) + statTableCells.exists(_.contains("Scheduling Delay")) should be (true) + statTableCells.exists(_.contains("Processing Time")) should be (true) + statTableCells.exists(_.contains("Total Delay")) should be (true) + + // Check batch tables + val h4Text = findAll(cssSelector("h4")).map(_.text).toSeq + h4Text.exists(_.matches("Active Batches \\(\\d+\\)")) should be (true) + h4Text.exists(_.matches("Completed Batches \\(last \\d+ out of \\d+\\)")) should be (true) + + findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", + "Status") + } + findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", + "Total Delay (?)") + } + + val batchLinks = + findAll(cssSelector("""#completed-batches-table a""")).flatMap(_.attribute("href")).toSeq + batchLinks.size should be >= 1 + + // Check a normal batch page + go to (batchLinks.last) // Last should be the first batch, so it will have some jobs + val summaryText = findAll(cssSelector("li strong")).map(_.text).toSeq + summaryText should contain ("Batch Duration:") + summaryText should contain ("Input data size:") + summaryText should contain ("Scheduling delay:") + summaryText should contain ("Processing time:") + summaryText should contain ("Total delay:") + + findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { + List("Output Op Id", "Description", "Duration", "Job Id", "Duration", + "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") + } + + // Check we have 2 output op ids + val outputOpIds = findAll(cssSelector(".output-op-id-cell")).toSeq + outputOpIds.map(_.attribute("rowspan")) should be (List(Some("2"), Some("2"))) + outputOpIds.map(_.text) should be (List("0", "1")) + + // Check job ids + val jobIdCells = findAll(cssSelector( """#batch-job-table a""")).toSeq + jobIdCells.map(_.text) should be (List("0", "1", "2", "3")) + + val jobLinks = jobIdCells.flatMap(_.attribute("href")) + jobLinks.size should be (4) + + // Check stage progress + findAll(cssSelector(""".stage-progress-cell""")).map(_.text).toSeq should be + (List("1/1", "1/1", "1/1", "0/1 (1 failed)")) + + // Check job progress + findAll(cssSelector(""".progress-cell""")).map(_.text).toSeq should be + (List("1/1", "1/1", "1/1", "0/1 (1 failed)")) + + // Check stacktrace + val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.text).toSeq + errorCells should have size 1 + errorCells(0) should include("java.lang.RuntimeException: Oops") + + // Check the job link in the batch page is right + go to (jobLinks(0)) + val jobDetails = findAll(cssSelector("li strong")).map(_.text).toSeq + jobDetails should contain("Status:") + jobDetails should contain("Completed Stages:") + + // Check a batch page without id + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/") + webDriver.getPageSource should include ("Missing id parameter") + + // Check a non-exist batch + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/?id=12345") + webDriver.getPageSource should include ("does not exist") + } + + ssc.stop(false) + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/")) + find(cssSelector( """ul li a[href*="streaming"]""")) should be(None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + val h3Text = findAll(cssSelector("h3")).map(_.text).toSeq + h3Text should not contain("Streaming Statistics") + } + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala deleted file mode 100644 index 8e3011826685..000000000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import scala.io.Source - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.SparkConf - -class UISuite extends FunSuite { - - // Ignored: See SPARK-1530 - ignore("streaming tab in spark UI") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - val ssc = new StreamingContext(conf, Seconds(1)) - assert(ssc.sc.ui.isDefined, "Spark UI is not started!") - val ui = ssc.sc.ui.get - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ui.appUIAddress).mkString - assert(!html.contains("random data that should not be present")) - // test if streaming tab exist - assert(html.toLowerCase.contains("streaming")) - // test if other Spark tabs still exist - assert(html.toLowerCase.contains("stages")) - } - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ui.appUIAddress.stripSuffix("/") + "/streaming").mkString - assert(html.toLowerCase.contains("batch")) - assert(html.toLowerCase.contains("network")) - } - } -} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16..c39ad05f4152 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577d..cb017b798b2a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,17 +21,20 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} -import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} +import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} + +class WriteAheadLogBackedBlockRDDSuite + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -57,24 +60,35 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w System.clearProperty("spark.driver.port") } - test("Read data available in block manager and write ahead log") { - testRDD(5, 5) + test("Read data available in both block manager and write ahead log") { + testRDD(numPartitions = 5, numPartitionsInBM = 5, numPartitionsInWAL = 5) } test("Read data available only in block manager, not in write ahead log") { - testRDD(5, 0) + testRDD(numPartitions = 5, numPartitionsInBM = 5, numPartitionsInWAL = 0) } test("Read data available only in write ahead log, not in block manager") { - testRDD(0, 5) + testRDD(numPartitions = 5, numPartitionsInBM = 0, numPartitionsInWAL = 5) } - test("Read data available only in write ahead log, and test storing in block manager") { - testRDD(0, 5, testStoreInBM = true) + test("Read data with partially available in block manager, and rest in write ahead log") { + testRDD(numPartitions = 5, numPartitionsInBM = 3, numPartitionsInWAL = 2) } - test("Read data with partially available in block manager, and rest in write ahead log") { - testRDD(3, 2) + test("Test isBlockValid skips block fetching from BlockManager") { + testRDD( + numPartitions = 5, numPartitionsInBM = 5, numPartitionsInWAL = 0, testIsBlockValid = true) + } + + test("Test whether RDD is valid after removing blocks from block manager") { + testRDD( + numPartitions = 5, numPartitionsInBM = 5, numPartitionsInWAL = 5, testBlockRemove = true) + } + + test("Test storing of blocks recovered from write ahead log back into block manager") { + testRDD( + numPartitions = 5, numPartitionsInBM = 0, numPartitionsInWAL = 5, testStoreInBM = true) } /** @@ -82,23 +96,54 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * and the rest to a write ahead log, and then reading reading it all back using the RDD. * It can also test if the partitions that were read from the log were again stored in * block manager. - * @param numPartitionsInBM Number of partitions to write to the Block Manager - * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log - * @param testStoreInBM Test whether blocks read from log are stored back into block manager + * + * + * + * @param numPartitions Number of partitions in RDD + * @param numPartitionsInBM Number of partitions to write to the BlockManager. + * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager + * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log. + * Partitions (numPartitions - 1 - numPartitionsInWAL) to + * (numPartitions - 1) will be written to WAL + * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching + * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with + * reads falling back to the WAL + * @param testStoreInBM Test whether blocks read from log are stored back into block manager + * + * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInWAL = 4 + * + * numPartitionsInBM = 3 + * |------------------| + * | | + * 0 1 2 3 4 + * | | + * |-------------------------| + * numPartitionsInWAL = 4 */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { - val numBlocks = numPartitionsInBM + numPartitionsInWAL - val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) + private def testRDD( + numPartitions: Int, + numPartitionsInBM: Int, + numPartitionsInWAL: Int, + testIsBlockValid: Boolean = false, + testBlockRemove: Boolean = false, + testStoreInBM: Boolean = false + ) { + require(numPartitionsInBM <= numPartitions, + "Can't put more partitions in BlockManager than that in RDD") + require(numPartitionsInWAL <= numPartitions, + "Can't put more partitions in write ahead log than that in RDD") + val data = Seq.fill(numPartitions, 10)(scala.util.Random.nextString(50)) // Put the necessary blocks in the block manager - val blockIds = Array.fill(numBlocks)(StreamBlockId(Random.nextInt(), Random.nextInt())) + val blockIds = Array.fill(numPartitions)(StreamBlockId(Random.nextInt(), Random.nextInt())) data.zip(blockIds).take(numPartitionsInBM).foreach { case(block, blockId) => blockManager.putIterator(blockId, block.iterator, StorageLevel.MEMORY_ONLY_SER) } - // Generate write ahead log segments - val segments = generateFakeSegments(numPartitionsInBM) ++ - writeLogSegments(data.takeRight(numPartitionsInWAL), blockIds.takeRight(numPartitionsInWAL)) + // Generate write ahead log record handles + val recordHandles = generateFakeRecordHandles(numPartitions - numPartitionsInWAL) ++ + generateWALRecordHandles(data.takeRight(numPartitionsInWAL), + blockIds.takeRight(numPartitionsInWAL)) // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not require( @@ -106,30 +151,53 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Expected blocks not in BlockManager" ) require( - blockIds.takeRight(numPartitionsInWAL).forall(blockManager.get(_).isEmpty), + blockIds.takeRight(numPartitions - numPartitionsInBM).forall(blockManager.get(_).isEmpty), "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( - segments.takeRight(numPartitionsInWAL).forall(s => + recordHandles.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), "Expected blocks not in write ahead log" ) require( - segments.take(numPartitionsInBM).forall(s => + recordHandles.take(numPartitions - numPartitionsInWAL).forall(s => !new File(s.path.stripPrefix("file://")).exists()), "Unexpected blocks in write ahead log" ) // Create the RDD and verify whether the returned data is correct val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, - segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) + recordHandles.toArray, storeInBlockManager = false) assert(rdd.collect() === data.flatten) + // Verify that the block fetching is skipped when isBlockValid is set to false. + // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // Using that RDD will throw exception, as it skips block fetching even if the blocks are in + // in BlockManager. + if (testIsBlockValid) { + require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") + require(numPartitionsInWAL === 0, "No partitions must be in WAL") + val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, + recordHandles.toArray, isBlockIdValid = Array.fill(blockIds.length)(false)) + intercept[SparkException] { + rdd2.collect() + } + } + + // Verify that the RDD is not invalid after the blocks are removed and can still read data + // from write ahead log + if (testBlockRemove) { + require(numPartitions === numPartitionsInWAL, "All partitions must be in WAL for this test") + require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test") + rdd.removeBlocks() + assert(rdd.collect() === data.flatten) + } + if (testStoreInBM) { val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, - segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) + recordHandles.toArray, storeInBlockManager = true, storageLevel = StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( blockIds.forall(blockManager.get(_).nonEmpty), @@ -138,12 +206,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } } - private def writeLogSegments( + private def generateWALRecordHandles( blockData: Seq[Seq[String]], blockIds: Seq[BlockId] - ): Seq[WriteAheadLogFileSegment] = { + ): Seq[FileBasedWriteAheadLogSegment] = { require(blockData.size === blockIds.size) - val writer = new WriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) + val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => writer.write(blockManager.dataSerialize(id, data.iterator)) } @@ -151,7 +219,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w segments } - private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + private def generateFakeRecordHandles(count: Int): Seq[FileBasedWriteAheadLogSegment] = { + Array.fill(count)(new FileBasedWriteAheadLogSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala new file mode 100644 index 000000000000..a38cc603f219 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} + +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { + + private val blockIntervalMs = 10 + private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") + @volatile private var blockGenerator: BlockGenerator = null + + after { + if (blockGenerator != null) { + blockGenerator.stop() + } + } + + test("block generation and data callbacks") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + + require(blockIntervalMs > 5) + require(listener.onAddDataCalled === false) + require(listener.onGenerateBlockCalled === false) + require(listener.onPushBlockCalled === false) + + // Verify that creating the generator does not start it + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + assert(blockGenerator.isActive() === false, "block generator active before start()") + assert(blockGenerator.isStopped() === false, "block generator stopped before start()") + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + + // Verify start marks the generator active, but does not call the callbacks + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator active after start()") + assert(blockGenerator.isStopped() === false, "block generator stopped after start()") + withClue("callbacks called before adding data") { + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + + // Verify whether addData() adds data that is present in generated blocks + val data1 = 1 to 10 + data1.foreach { blockGenerator.addData _ } + withClue("callbacks called on adding data without metadata and without block generation") { + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + clock.advance(blockIntervalMs) // advance clock to generate blocks + withClue("blocks not generated or pushed") { + eventually(timeout(1 second)) { + assert(listener.onGenerateBlockCalled === true) + assert(listener.onPushBlockCalled === true) + } + } + listener.pushedData should contain theSameElementsInOrderAs (data1) + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + + // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly + val data2 = 11 to 20 + val metadata2 = data2.map { _.toString } + data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } + assert(listener.onAddDataCalled === true) + listener.addedData should contain theSameElementsInOrderAs (data2) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + } + + // Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly + val data3 = 21 to 30 + val metadata3 = "metadata" + blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + } + + // Stop the block generator by starting the stop on a different thread and + // then advancing the manual clock for the stopping to proceed. + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + clock.advance(blockIntervalMs) + assert(blockGenerator.isStopped() === true) + } + thread.join() + + // Verify that the generator cannot be used any more + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, 1) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), 1) + } + intercept[SparkException] { + blockGenerator.start() + } + blockGenerator.stop() // Calling stop again should be fine + } + + test("stop ensures correct shutdown") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + require(listener.onGenerateBlockCalled === false) + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator") + assert(blockGenerator.isStopped() === false) + + val data = 1 to 1000 + data.foreach { blockGenerator.addData _ } + + // Verify that stop() shutdowns everything in the right order + // - First, stop receiving new data + // - Second, wait for final block with all buffered data to be generated + // - Finally, wait for all blocks to be pushed + clock.advance(1) // to make sure that the timer for another interval to complete + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(blockGenerator.isActive() === false) + } + assert(blockGenerator.isStopped() === false) + + // Verify that data cannot be added + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, null) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), null) + } + + // Verify that stop() stays blocked until another block containing all the data is generated + // This intercept always succeeds, as the body either will either throw a timeout exception + // (expected as stop() should never complete) or a SparkException (unexpected as stop() + // completed and thread terminated). + val exception = intercept[Exception] { + failAfter(200 milliseconds) { + thread.join() + throw new SparkException( + "BlockGenerator.stop() completed before generating timer was stopped") + } + } + exception should not be a [SparkException] + + + // Verify that the final data is present in the final generated block and + // pushed before complete stop + assert(blockGenerator.isStopped() === false) // generator has not stopped yet + clock.advance(blockIntervalMs) // force block generation + failAfter(1 second) { + thread.join() + } + assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped + assert(listener.pushedData === data, "All data not pushed by stop()") + } + + test("block push errors are reported") { + val listener = new TestBlockGeneratorListener { + @volatile var errorReported = false + override def onPushBlock( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + throw new SparkException("test") + } + override def onError(message: String, throwable: Throwable): Unit = { + errorReported = true + } + } + blockGenerator = new BlockGenerator(listener, 0, conf) + blockGenerator.start() + assert(listener.errorReported === false) + blockGenerator.addData(1) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(listener.errorReported === true) + } + blockGenerator.stop() + } + + /** + * Helper method to stop the block generator with manual clock in a different thread, + * so that the main thread can advance the clock that allows the stopping to proceed. + */ + private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = { + val thread = new Thread() { + override def run(): Unit = { + blockGenerator.stop() + } + } + thread.start() + thread + } + + /** A listener for BlockGenerator that records the data in the callbacks */ + private class TestBlockGeneratorListener extends BlockGeneratorListener { + val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + @volatile var onGenerateBlockCalled = false + @volatile var onAddDataCalled = false + @volatile var onPushBlockCalled = false + + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + pushedData ++= arrayBuffer + onPushBlockCalled = true + } + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = { + onGenerateBlockCalled = true + } + override def onAddData(data: Any, metadata: Any): Unit = { + addedData += data + addedMetadata += metadata + onAddDataCalled = true + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala new file mode 100644 index 000000000000..c6330eb3673f --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite + +/** Testsuite for testing the network receiver behavior */ +class RateLimiterSuite extends SparkFunSuite { + + test("rate limiter initializes even without a maxRate set") { + val conf = new SparkConf() + val rateLimiter = new RateLimiter(conf){} + rateLimiter.updateRate(105) + assert(rateLimiter.getCurrentLimit == 105) + } + + test("rate limiter updates when below maxRate") { + val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110") + val rateLimiter = new RateLimiter(conf){} + rateLimiter.updateRate(105) + assert(rateLimiter.getCurrentLimit == 105) + } + + test("rate limiter stays below maxRate despite large updates") { + val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100") + val rateLimiter = new RateLimiter(conf){} + rateLimiter.updateRate(105) + assert(rateLimiter.getCurrentLimit === 100) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala new file mode 100644 index 000000000000..f5248acf712b --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.{Time, Duration, StreamingContext} + +class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { + + private var ssc: StreamingContext = _ + + before { + val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker") + if (ssc == null) { + ssc = new StreamingContext(conf, Duration(1000)) + } + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + } + + test("test report and get InputInfo from InputInfoTracker") { + val inputInfoTracker = new InputInfoTracker(ssc) + + val streamId1 = 0 + val streamId2 = 1 + val time = Time(0L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId2, 300L) + inputInfoTracker.reportInfo(time, inputInfo1) + inputInfoTracker.reportInfo(time, inputInfo2) + + val batchTimeToInputInfos = inputInfoTracker.getInfo(time) + assert(batchTimeToInputInfos.size == 2) + assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2)) + assert(batchTimeToInputInfos(streamId1) === inputInfo1) + assert(batchTimeToInputInfos(streamId2) === inputInfo2) + assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1) + } + + test("test cleanup InputInfo from InputInfoTracker") { + val inputInfoTracker = new InputInfoTracker(ssc) + + val streamId1 = 0 + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId1, 300L) + inputInfoTracker.reportInfo(Time(0), inputInfo1) + inputInfoTracker.reportInfo(Time(1), inputInfo2) + + inputInfoTracker.cleanup(Time(0)) + assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1) + assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) + + inputInfoTracker.cleanup(Time(1)) + assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None) + assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala new file mode 100644 index 000000000000..9b6cd4bc4e31 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.util.concurrent.CountDownLatch + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming._ +import org.apache.spark.util.{ManualClock, Utils} + +class JobGeneratorSuite extends TestSuiteBase { + + // SPARK-6222 is a tricky regression bug which causes received block metadata + // to be deleted before the corresponding batch has completed. This occurs when + // the following conditions are met. + // 1. streaming checkpointing is enabled by setting streamingContext.checkpoint(dir) + // 2. input data is received through a receiver as blocks + // 3. a batch processing a set of blocks takes a long time, such that a few subsequent + // batches have been generated and submitted for processing. + // + // The JobGenerator (as of Mar 16, 2015) checkpoints twice per batch, once after generation + // of a batch, and another time after the completion of a batch. The cleanup of + // checkpoint data (including block metadata, etc.) from DStream must be done only after the + // 2nd checkpoint has completed, that is, after the batch has been completely processed. + // However, the issue is that the checkpoint data and along with it received block data is + // cleaned even in the case of the 1st checkpoint, causing pre-mature deletion of received block + // data. For example, if the 3rd batch is still being process, the 7th batch may get generated, + // and the corresponding "1st checkpoint" will delete received block metadata of batch older + // than 6th batch. That, is 3rd batch's block metadata gets deleted even before 3rd batch has + // been completely processed. + // + // This test tries to create that scenario by the following. + // 1. enable checkpointing + // 2. generate batches with received blocks + // 3. make the 3rd batch never complete + // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) + // 5. verify whether 3rd batch's block metadata still exists + // + // TODO: SPARK-7420 enable this test + ignore("SPARK-6222: Do not clear received block data too soon") { + import JobGeneratorSuite._ + val checkpointDir = Utils.createTempDir() + val testConf = conf + testConf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + testConf.set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + + withStreamingContext(new StreamingContext(testConf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val numBatches = 10 + val longBatchNumber = 3 // 3rd batch will take a long time + val longBatchTime = longBatchNumber * batchDuration.milliseconds + + val testTimeout = timeout(10 seconds) + val inputStream = ssc.receiverStream(new TestReceiver) + + inputStream.foreachRDD((rdd: RDD[Int], time: Time) => { + if (time.milliseconds == longBatchTime) { + while (waitLatch.getCount() > 0) { + waitLatch.await() + } + } + }) + + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.getAbsolutePath) + ssc.start() + + // Make sure the only 1 batch of information is to be remembered + assert(inputStream.rememberDuration === batchDuration) + val receiverTracker = ssc.scheduler.receiverTracker + + // Get the blocks belonging to a batch + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { + receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) + } + + // Wait for new blocks to be received + def waitForNewReceivedBlocks() { + eventually(testTimeout) { + assert(receiverTracker.hasUnallocatedBlocks) + } + } + + // Wait for received blocks to be allocated to a batch + def waitForBlocksToBeAllocatedToBatch(batchTime: Long) { + eventually(testTimeout) { + assert(getBlocksOfBatch(batchTime).nonEmpty) + } + } + + // Generate a large number of batches with blocks in them + for (batchNum <- 1 to numBatches) { + waitForNewReceivedBlocks() + clock.advance(batchDuration.milliseconds) + waitForBlocksToBeAllocatedToBatch(clock.getTimeMillis()) + } + + // Wait for 3rd batch to start + eventually(testTimeout) { + ssc.scheduler.getPendingTimes().contains(Time(numBatches * batchDuration.milliseconds)) + } + + // Verify that the 3rd batch's block data is still present while the 3rd batch is incomplete + assert(getBlocksOfBatch(longBatchTime).nonEmpty, "blocks of incomplete batch already deleted") + assert(batchCounter.getNumCompletedBatches < longBatchNumber) + waitLatch.countDown() + } + } +} + +object JobGeneratorSuite { + val waitLatch = new CountDownLatch(1) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala new file mode 100644 index 000000000000..1eb52b7029a2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +class RateControllerSuite extends TestSuiteBase { + + override def useManualClock: Boolean = false + + override def batchDuration: Duration = Milliseconds(50) + + test("RateController - rate controller publishes updates after batches complete") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateTestInputDStream(ssc) + dstream.register() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.publishedRates > 0) + } + } + } + + test("ReceiverRateController - published rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val estimator = new ConstantEstimator(100) + val dstream = new RateTestInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, estimator)) + } + dstream.register() + ssc.start() + + // Wait for receiver to start + eventually(timeout(5.seconds)) { + RateTestReceiver.getActive().nonEmpty + } + + // Update rate in the estimator and verify whether the rate was published to the receiver + def updateRateAndVerify(rate: Long): Unit = { + estimator.updateRate(rate) + eventually(timeout(5.seconds)) { + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate) + } + } + + // Verify multiple rate update + Seq(100, 200, 300).foreach { rate => + updateRateAndVerify(rate) + } + } + } +} + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala new file mode 100644 index 000000000000..b2a51d72bac2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite + +class ReceiverSchedulingPolicySuite extends SparkFunSuite { + + val receiverSchedulingPolicy = new ReceiverSchedulingPolicy + + test("rescheduleReceiver: empty executors") { + val scheduledExecutors = + receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) + assert(scheduledExecutors === Seq.empty) + } + + test("rescheduleReceiver: receiver preferredLocation") { + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) + assert(scheduledExecutors.toSet === Set("host1", "host2")) + } + + test("rescheduleReceiver: return all idle executors if there are any idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // host3 is idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 1, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + } + + test("rescheduleReceiver: return all executors that have minimum weight if no idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0, host4 = 0.5, host5 = 0.5 + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), + 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None), + 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, Some(Seq("host4", "host5")), None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 4, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + } + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more receivers than executors") { + val receivers = (0 until 6).map(new RateTestReceiver(_)) + val executors = (10000 until 10003).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 2 receivers running on each executor and each receiver has one executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + } + assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) + } + + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more executors than receivers") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) + val executors = (10000 until 10006).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has two executors + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 2) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + } + + test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ + (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) + val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ + (10003 until 10006).map(port => s"localhost2:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has 1 executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + // Make sure we schedule the receivers to their preferredLocations + val executorsForReceiversWithPreferredLocation = + scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + // We can simply check the executor set because we only know each receiver only has 1 executor + assert(executorsForReceiversWithPreferredLocation.toSet === + (10000 until 10003).map(port => s"localhost:${port}").toSet) + } + + test("scheduleReceivers: return empty if no receiver") { + assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + } + + test("scheduleReceivers: return empty scheduled executors if no executors") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.isEmpty) + } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala new file mode 100644 index 000000000000..dd292ba4dd94 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver._ + +/** Testsuite for receiver scheduling */ +class ReceiverTrackerSuite extends TestSuiteBase { + + test("send rate update to receivers") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + ssc.scheduler.listenerBus.start(ssc.sc) + + val newRateLimit = 100L + val inputDStream = new RateTestInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + } + + + // Verify that the rate of the block generator in the receiver get updated + val activeReceiver = RateTestReceiver.getActive().get + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + eventually(timeout(5 seconds)) { + assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit, + "default block generator did not receive rate update") + assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit, + "other block generator did not receive rate update") + } + } finally { + tracker.stop(false) + } + } + } +} + +/** An input DStream with for testing rate controlling */ +private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext) + extends ReceiverInputDStream[Int](ssc_) { + + override def getReceiver(): Receiver[Int] = new RateTestReceiver(id) + + @volatile + var publishedRates = 0 + + override val rateController: Option[RateController] = { + Some(new RateController(id, new ConstantEstimator(100)) { + override def publish(rate: Long): Unit = { + publishedRates += 1 + } + }) + } +} + +/** A receiver implementation for testing rate controlling */ +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + private lazy val customBlockGenerator = supervisor.createBlockGenerator( + new BlockGeneratorListener { + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = {} + override def onAddData(data: Any, metadata: Any): Unit = {} + } + ) + + setReceiverId(receiverId) + + override def onStart(): Unit = { + customBlockGenerator + RateTestReceiver.registerReceiver(this) + } + + override def onStop(): Unit = { + RateTestReceiver.deregisterReceiver() + } + + override def preferredLocation: Option[String] = host + + def getDefaultBlockGeneratorRateLimit(): Long = { + supervisor.getCurrentRateLimit + } + + def getCustomBlockGeneratorRateLimit(): Long = { + customBlockGenerator.getCurrentLimit + } +} + +/** + * A helper object to RateTestReceiver that give access to the currently active RateTestReceiver + * instance. + */ +private[streaming] object RateTestReceiver { + @volatile private var activeReceiver: RateTestReceiver = null + + def registerReceiver(receiver: RateTestReceiver): Unit = { + activeReceiver = receiver + } + + def deregisterReceiver(): Unit = { + activeReceiver = null + } + + def getActive(): Option[RateTestReceiver] = Option(activeReceiver) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala new file mode 100644 index 000000000000..a1af95be81c8 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import scala.util.Random + +import org.scalatest.Inspectors.forAll +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.Seconds + +class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { + + test("the right estimator is created") { + val conf = new SparkConf + conf.set("spark.streaming.backpressure.rateEstimator", "pid") + val pid = RateEstimator.create(conf, Seconds(1)) + pid.getClass should equal(classOf[PIDRateEstimator]) + } + + test("estimator checks ranges") { + intercept[IllegalArgumentException] { + new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, proportional = -1, 2, 3, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, integral = -1, 3, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, derivative = -1, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = 0) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = -10) + } + } + + test("first estimate is None") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) should equal(None) + } + + test("second estimate is not None") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + // 1000 elements / s + p.compute(10, 10, 10, 0) should equal(Some(1000)) + } + + test("no estimate when no time difference between successive calls") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(time = 10, 10, 10, 0) shouldNot equal(None) + p.compute(time = 10, 10, 10, 0) should equal(None) + } + + test("no estimate when no records in previous batch") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, numElements = 0, 10, 0) should equal(None) + p.compute(20, numElements = -10, 10, 0) should equal(None) + } + + test("no estimate when there is no processing delay") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, 10, processingDelay = 0, 0) should equal(None) + p.compute(20, 10, processingDelay = -10, 0) should equal(None) + } + + test("estimate is never less than min rate") { + val minRate = 5D + val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate) + // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing + // this might point the estimator to try and decrease the bound, but we test it never + // goes below the min rate, which would be nonsensical. + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.fill(50)(1) // no processing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(100) // strictly positive accumulation + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.fill(49)(Some(minRate))) + } + + test("with no accumulated or positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) + // prepare a series of batch updates, one every 20ms with an increasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail) + } + + test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) + // prepare a series of batch updates, one every 20ms with an decreasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term, + // asking for less and less elements + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail) + } + + test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { + val minRate = 10D + val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val rng = new Random() + val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000) + val procDelayMs = 20 + val proc = List.fill(50)(procDelayMs) // 20ms of processing + val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait + val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) + + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + forAll(List.range(1, 50)) { (n) => + res(n) should not be None + if (res(n).get > 0 && sched(n) > 0) { + res(n).get should be < speeds(n) + res(n).get should be >= minRate + } + } + } + + private def createDefaultEstimator(): PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D, 10) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala new file mode 100644 index 000000000000..995f1197ccdf --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import java.util.Properties + +import org.scalatest.Matchers + +import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ + +class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + + private def createJobStart( + batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { + val properties = new Properties() + properties.setProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, batchTime.milliseconds.toString) + properties.setProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, outputOpId.toString) + SparkListenerJobStart(jobId = jobId, + 0L, // unused + Nil, // unused + properties) + } + + override def batchDuration: Duration = Milliseconds(100) + + test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + + "onReceiverStarted, onReceiverError, onReceiverStopped") { + ssc = setupStreams(input, operation) + val listener = new StreamingJobProgressListener(ssc) + + val streamIdToInputInfo = Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) + + // onBatchSubmitted + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (0) + + // onBatchStarted + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (600) + + // onJobStart + val jobStart1 = createJobStart(Time(1000), outputOpId = 0, jobId = 0) + listener.onJobStart(jobStart1) + + val jobStart2 = createJobStart(Time(1000), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart2) + + val jobStart3 = createJobStart(Time(1000), outputOpId = 1, jobId = 0) + listener.onJobStart(jobStart3) + + val jobStart4 = createJobStart(Time(1000), outputOpId = 1, jobId = 1) + listener.onJobStart(jobStart4) + + val batchUIData = listener.getBatchUIData(Time(1000)) + batchUIData should not be None + batchUIData.get.batchTime should be (batchInfoStarted.batchTime) + batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) + batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) + batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) + batchUIData.get.streamIdToInputInfo should be (Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test")))) + batchUIData.get.numRecords should be(600) + batchUIData.get.outputOpIdSparkJobIdPairs should be + Seq(OutputOpIdAndSparkJobId(0, 0), + OutputOpIdAndSparkJobId(0, 1), + OutputOpIdAndSparkJobId(1, 0), + OutputOpIdAndSparkJobId(1, 1)) + + // onBatchCompleted + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (List(BatchUIData(batchInfoCompleted))) + listener.lastCompletedBatch should be (Some(BatchUIData(batchInfoCompleted))) + listener.numUnprocessedBatches should be (0) + listener.numTotalCompletedBatches should be (1) + listener.numTotalProcessedRecords should be (600) + listener.numTotalReceivedRecords should be (600) + + // onReceiverStarted + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") + listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (None) + + // onReceiverError + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") + listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (None) + + // onReceiverStopped + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") + listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (Some(receiverInfoStopped)) + listener.receiverInfo(3) should be (None) + } + + test("Remove the old completed batches when exceeding the limit") { + ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) + val listener = new StreamingJobProgressListener(ssc) + + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) + + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + + for(_ <- 0 until (limit + 10)) { + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + listener.retainedCompletedBatches.size should be (limit) + listener.numTotalCompletedBatches should be(limit + 10) + } + + test("out-of-order onJobStart and onBatchXXX") { + ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) + val listener = new StreamingJobProgressListener(ssc) + + // fulfill completedBatchInfos + for(i <- 0 until limit) { + val batchInfoCompleted = + BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart) + } + + // onJobStart happens before onBatchSubmitted + val jobStart = createJobStart(Time(1000 + limit * 100), outputOpId = 0, jobId = 0) + listener.onJobStart(jobStart) + + val batchInfoSubmitted = + BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + + // We still can see the info retrieved from onJobStart + val batchUIData = listener.getBatchUIData(Time(1000 + limit * 100)) + batchUIData should not be None + batchUIData.get.batchTime should be (batchInfoSubmitted.batchTime) + batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) + batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) + batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) + batchUIData.get.streamIdToInputInfo should be (Map.empty) + batchUIData.get.numRecords should be (0) + batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) + + // A lot of "onBatchCompleted"s happen before "onJobStart" + for(i <- limit + 1 to limit * 2) { + val batchInfoCompleted = + BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + for(i <- limit + 1 to limit * 2) { + val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart) + } + + // We should not leak memory + listener.batchTimeToOutputOpIdSparkJobIdPair.size() should be <= + (listener.waitingBatches.size + listener.runningBatches.size + + listener.retainedCompletedBatches.size + 10) + } + + test("detect memory leak") { + ssc = setupStreams(input, operation) + val listener = new StreamingJobProgressListener(ssc) + + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) + + for (_ <- 0 until 2 * limit) { + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) + + // onBatchSubmitted + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + + // onBatchStarted + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) + + // onJobStart + val jobStart1 = createJobStart(Time(1000), outputOpId = 0, jobId = 0) + listener.onJobStart(jobStart1) + + val jobStart2 = createJobStart(Time(1000), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart2) + + val jobStart3 = createJobStart(Time(1000), outputOpId = 1, jobId = 0) + listener.onJobStart(jobStart3) + + val jobStart4 = createJobStart(Time(1000), outputOpId = 1, jobId = 1) + listener.onJobStart(jobStart4) + + // onBatchCompleted + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + listener.waitingBatches.size should be (0) + listener.runningBatches.size should be (0) + listener.retainedCompletedBatches.size should be (limit) + listener.batchTimeToOutputOpIdSparkJobIdPair.size() should be <= + (listener.waitingBatches.size + listener.runningBatches.size + + listener.retainedCompletedBatches.size + 10) + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala new file mode 100644 index 000000000000..d3ca2b58f36c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import java.util.TimeZone +import java.util.concurrent.TimeUnit + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ + + test("shortTimeUnitString") { + assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) + assert("us" === UIUtils.shortTimeUnitString(TimeUnit.MICROSECONDS)) + assert("ms" === UIUtils.shortTimeUnitString(TimeUnit.MILLISECONDS)) + assert("sec" === UIUtils.shortTimeUnitString(TimeUnit.SECONDS)) + assert("min" === UIUtils.shortTimeUnitString(TimeUnit.MINUTES)) + assert("hrs" === UIUtils.shortTimeUnitString(TimeUnit.HOURS)) + assert("days" === UIUtils.shortTimeUnitString(TimeUnit.DAYS)) + } + + test("normalizeDuration") { + verifyNormalizedTime(900, TimeUnit.MILLISECONDS, 900) + verifyNormalizedTime(1.0, TimeUnit.SECONDS, 1000) + verifyNormalizedTime(1.0, TimeUnit.MINUTES, 60 * 1000) + verifyNormalizedTime(1.0, TimeUnit.HOURS, 60 * 60 * 1000) + verifyNormalizedTime(1.0, TimeUnit.DAYS, 24 * 60 * 60 * 1000) + } + + private def verifyNormalizedTime( + expectedTime: Double, expectedUnit: TimeUnit, input: Long): Unit = { + val (time, unit) = UIUtils.normalizeDuration(input) + time should be (expectedTime +- 1E-6) + unit should be (expectedUnit) + } + + test("convertToTimeUnit") { + verifyConvertToTimeUnit(60.0 * 1000 * 1000 * 1000, 60 * 1000, TimeUnit.NANOSECONDS) + verifyConvertToTimeUnit(60.0 * 1000 * 1000, 60 * 1000, TimeUnit.MICROSECONDS) + verifyConvertToTimeUnit(60 * 1000, 60 * 1000, TimeUnit.MILLISECONDS) + verifyConvertToTimeUnit(60, 60 * 1000, TimeUnit.SECONDS) + verifyConvertToTimeUnit(1, 60 * 1000, TimeUnit.MINUTES) + verifyConvertToTimeUnit(1.0 / 60, 60 * 1000, TimeUnit.HOURS) + verifyConvertToTimeUnit(1.0 / 60 / 24, 60 * 1000, TimeUnit.DAYS) + } + + private def verifyConvertToTimeUnit( + expectedTime: Double, milliseconds: Long, unit: TimeUnit): Unit = { + val convertedTime = UIUtils.convertToTimeUnit(milliseconds, unit) + convertedTime should be (expectedTime +- 1E-6) + } + + test("formatBatchTime") { + val tzForTest = TimeZone.getTimeZone("America/Los_Angeles") + val batchTime = 1431637480452L // Thu May 14 14:04:40 PDT 2015 + assert("2015/05/14 14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, timezone = tzForTest)) + assert("2015/05/14 14:04:40.452" === + UIUtils.formatBatchTime(batchTime, 999, timezone = tzForTest)) + assert("14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, false, timezone = tzForTest)) + assert("14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, false, timezone = tzForTest)) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f42..78fc344b0017 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7ce9499dc614..5e49fd00769a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,33 +18,39 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer +import java.util +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} +import scala.reflect.ClassTag -import WriteAheadLogSuite._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ +import org.scalatest.BeforeAndAfter -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} + +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { + + import WriteAheadLogSuite._ val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null var testFile: String = null - var manager: WriteAheadLogManager = null + var writeAheadLog: FileBasedWriteAheadLog = null before { tempDir = Utils.createTempDir() testDir = tempDir.toString testFile = new File(tempDir, "testFile").toString - if (manager != null) { - manager.stop() - manager = null + if (writeAheadLog != null) { + writeAheadLog.close() + writeAheadLog = null } } @@ -52,16 +58,60 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogWriter - writing data") { + test("WriteAheadLogUtils - log selection and creation") { + val logDir = Utils.createTempDir().getAbsolutePath() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](conf1) + assertReceiverLogClass[FileBasedWriteAheadLog](conf1) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("FileBasedWriteAheadLogWriter - writing data") { val dataToWrite = generateRandomData() val segments = writeDataUsingWriter(testFile, dataToWrite) val writtenData = readDataManually(segments) assert(writtenData === dataToWrite) } - test("WriteAheadLogWriter - syncing of data by writing and reading immediately") { + test("FileBasedWriteAheadLogWriter - syncing of data by writing and reading immediately") { val dataToWrite = generateRandomData() - val writer = new WriteAheadLogWriter(testFile, hadoopConf) + val writer = new FileBasedWriteAheadLogWriter(testFile, hadoopConf) dataToWrite.foreach { data => val segment = writer.write(stringToByteBuffer(data)) val dataRead = readDataManually(Seq(segment)).head @@ -70,10 +120,10 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { writer.close() } - test("WriteAheadLogReader - sequentially reading data") { + test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() writeDataManually(writtenData, testFile) - val reader = new WriteAheadLogReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) assert(reader.hasNext === false) @@ -83,14 +133,14 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { reader.close() } - test("WriteAheadLogReader - sequentially reading data written with writer") { + test("FileBasedWriteAheadLogReader - sequentially reading data written with writer") { val dataToWrite = generateRandomData() writeDataUsingWriter(testFile, dataToWrite) val readData = readDataUsingReader(testFile) assert(readData === dataToWrite) } - test("WriteAheadLogReader - reading data written with writer after corrupted write") { + test("FileBasedWriteAheadLogReader - reading data written with writer after corrupted write") { // Write data manually for testing the sequential reader val dataToWrite = generateRandomData() writeDataUsingWriter(testFile, dataToWrite) @@ -113,38 +163,38 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1))) } - test("WriteAheadLogRandomReader - reading data using random reader") { + test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() val segments = writeDataManually(writtenData, testFile) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten - val reader = new WriteAheadLogRandomReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf) writtenDataAndSegments.foreach { case (data, segment) => assert(data === byteBufferToString(reader.read(segment))) } reader.close() } - test("WriteAheadLogRandomReader - reading data using random reader written with writer") { + test("FileBasedWriteAheadLogRandomReader- reading data using random reader written with writer") { // Write data using writer for testing the random reader val data = generateRandomData() val segments = writeDataUsingWriter(testFile, data) // Read a random sequence of segments and verify read data val dataAndSegments = data.zip(segments).toSeq.permutations.take(10).flatten - val reader = new WriteAheadLogRandomReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf) dataAndSegments.foreach { case (data, segment) => assert(data === byteBufferToString(reader.read(segment))) } reader.close() } - test("WriteAheadLogManager - write rotating logs") { - // Write data using manager + test("FileBasedWriteAheadLog - write rotating logs") { + // Write data with rotation using WriteAheadLog class val dataToWrite = generateRandomData() - writeDataUsingManager(testDir, dataToWrite) + writeDataUsingWriteAheadLog(testDir, dataToWrite) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) @@ -153,8 +203,8 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(writtenData === dataToWrite) } - test("WriteAheadLogManager - read rotating logs") { - // Write data manually for testing reading through manager + test("FileBasedWriteAheadLog - read rotating logs") { + // Write data manually for testing reading through WriteAheadLog val writtenData = (1 to 10).map { i => val data = generateRandomData() val file = testDir + s"/log-$i-$i" @@ -167,25 +217,25 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(fileSystem.exists(logDirectoryPath) === true) // Read data using manager and verify - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(readData === writtenData) } - test("WriteAheadLogManager - recover past logs when creating new manager") { + test("FileBasedWriteAheadLog - recover past logs when creating new manager") { // Write data with manager, recover with new manager and verify val dataToWrite = generateRandomData() - writeDataUsingManager(testDir, dataToWrite) + writeDataUsingWriteAheadLog(testDir, dataToWrite) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(dataToWrite === readData) } - test("WriteAheadLogManager - cleanup old logs") { + test("FileBasedWriteAheadLog - clean old logs") { logCleanUpTest(waitForCompletion = false) } - test("WriteAheadLogManager - cleanup old logs synchronously") { + test("FileBasedWriteAheadLog - clean old logs synchronously") { logCleanUpTest(waitForCompletion = true) } @@ -193,11 +243,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // Write data with manager, recover with new manager and verify val manualClock = new ManualClock val dataToWrite = generateRandomData() - manager = writeDataUsingManager(testDir, dataToWrite, manualClock, stopManager = false) + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - manager.cleanupOldLogs(manualClock.currentTime() / 2, waitForCompletion) + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) if (waitForCompletion) { assert(getLogFilesInDirectory(testDir).size < logFiles.size) @@ -208,24 +258,24 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } } - test("WriteAheadLogManager - handling file errors while reading rotating logs") { + test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { // Generate a set of log files val manualClock = new ManualClock val dataToWrite1 = generateRandomData() - writeDataUsingManager(testDir, dataToWrite1, manualClock) + writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) val logFiles1 = getLogFilesInDirectory(testDir) assert(logFiles1.size > 1) // Recover old files and generate a second set of log files val dataToWrite2 = generateRandomData() - manualClock.addToTime(100000) - writeDataUsingManager(testDir, dataToWrite2, manualClock) + manualClock.advance(100000) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) val logFiles2 = getLogFilesInDirectory(testDir) assert(logFiles2.size > logFiles1.size) // Read the files and verify that all the written data can be read - val readData1 = readDataUsingManager(testDir) + val readData1 = readDataUsingWriteAheadLog(testDir) assert(readData1 === (dataToWrite1 ++ dataToWrite2)) // Corrupt the first set of files so that they are basically unreadable @@ -236,25 +286,51 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(readData === dataToWrite2) } + + test("FileBasedWriteAheadLog - do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile) + val wal = new FileBasedWriteAheadLog( + new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + wal.read(writtenSegment.head) + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + } } object WriteAheadLogSuite { + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() + + private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[WriteAheadLogFileSegment] = { - val segments = new ArrayBuffer[WriteAheadLogFileSegment]() + def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) data.foreach { item => val offset = writer.getPos val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) - segments += WriteAheadLogFileSegment(file, offset, bytes.size) + segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } writer.close() segments @@ -263,8 +339,11 @@ object WriteAheadLogSuite { /** * Write data to a file using the writer class and return an array of the file segments written. */ - def writeDataUsingWriter(filePath: String, data: Seq[String]): Seq[WriteAheadLogFileSegment] = { - val writer = new WriteAheadLogWriter(filePath, hadoopConf) + def writeDataUsingWriter( + filePath: String, + data: Seq[String] + ): Seq[FileBasedWriteAheadLogSegment] = { + val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) } @@ -272,27 +351,27 @@ object WriteAheadLogSuite { segments } - /** Write data to rotating files in log directory using the manager class. */ - def writeDataUsingManager( + /** Write data to rotating files in log directory using the WriteAheadLog class. */ + def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], manualClock: ManualClock = new ManualClock, - stopManager: Boolean = true - ): WriteAheadLogManager = { - if (manualClock.currentTime < 100000) manualClock.setTime(10000) - val manager = new WriteAheadLogManager(logDirectory, hadoopConf, - rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock) + closeLog: Boolean = true + ): FileBasedWriteAheadLog = { + if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => - manualClock.addToTime(500) - manager.writeToLog(item) + manualClock.advance(500) + wal.write(item, manualClock.getTimeMillis()) } - if (stopManager) manager.stop() - manager + if (closeLog) wal.close() + wal } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ - def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { + /** Read data from a segments of a log file directly and return the list of byte buffers. */ + def readDataManually(segments: Seq[FileBasedWriteAheadLogSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) try { @@ -331,18 +410,17 @@ object WriteAheadLogSuite { /** Read all the data from a log file using reader class and return the list of byte buffers. */ def readDataUsingReader(file: String): Seq[String] = { - val reader = new WriteAheadLogReader(file, hadoopConf) + val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) val readData = reader.toList.map(byteBufferToString) reader.close() readData } - /** Read all the data in the log file in a directory using the manager class. */ - def readDataUsingManager(logDirectory: String): Seq[String] = { - val manager = new WriteAheadLogManager(logDirectory, hadoopConf, - callerName = "WriteAheadLogSuite") - val data = manager.readFromLog().map(byteBufferToString).toSeq - manager.stop() + /** Read all the data in the log file in a directory using the WriteAheadLog class. */ + def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + val data = wal.readAll().asScala.map(byteBufferToString).toSeq + wal.close() data } diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74..d66750463033 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = { diff --git a/tools/pom.xml b/tools/pom.xml index e7419ed2c607..298ee2348b58 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -76,10 +76,6 @@ org.apache.maven.plugins maven-source-plugin - - org.codehaus.mojo - build-helper-maven-plugin - diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 595ded6ae67f..a0524cabff2d 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -15,13 +15,14 @@ * limitations under the License. */ +// scalastyle:off classforname package org.apache.spark.tools import java.io.File import java.util.jar.JarFile import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} import scala.util.Try @@ -92,7 +93,9 @@ object GenerateMIMAIgnore { ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { + // scalastyle:off println case _: Throwable => println("Error instrumenting class:" + className) + // scalastyle:on println } } (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) @@ -108,7 +111,9 @@ object GenerateMIMAIgnore { .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) } catch { case t: Throwable => + // scalastyle:off println println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + // scalastyle:on println Seq.empty[String] } } @@ -128,12 +133,14 @@ object GenerateMIMAIgnore { getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-class-excludes") .writeAll(previousContents + privateClasses.mkString("\n")) + // scalastyle:off println println("Created : .generated-mima-class-excludes in current directory.") val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) .getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-member-excludes").writeAll(previousMembersContents + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") + // scalastyle:on println } @@ -154,7 +161,7 @@ object GenerateMIMAIgnore { val path = packageName.replace('.', '/') val resources = classLoader.getResources(path) - val jars = resources.filter(x => x.getProtocol == "jar") + val jars = resources.asScala.filter(_.getProtocol == "jar") .map(_.getFile.split(":")(1).split("!")(0)).toSeq jars.flatMap(getClassesFromJar(_, path)) @@ -168,15 +175,18 @@ object GenerateMIMAIgnore { private def getClassesFromJar(jarPath: String, packageName: String) = { import scala.collection.mutable val jar = new JarFile(new File(jarPath)) - val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) + val enums = jar.entries().asScala.map(_.getName).filter(_.startsWith(packageName)) val classes = mutable.HashSet[Class[_]]() for (entry <- enums if entry.endsWith(".class")) { try { classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) } catch { + // scalastyle:off println case _: Throwable => println("Unable to load:" + entry) + // scalastyle:on println } } classes } } +// scalastyle:on classforname diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 8d0f09933c8d..856ea177a9a1 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.tools -import java.lang.reflect.Method +import java.lang.reflect.{Type, Method} import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -302,7 +302,7 @@ object JavaAPICompletenessChecker { private def isExcludedByInterface(method: Method): Boolean = { val excludedInterfaces = Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method) = + def toComparisionKey(method: Method): (Class[_], String, Type) = (method.getReturnType, method.getName, method.getGenericReturnType) val interfaces = method.getDeclaringClass.getInterfaces.filter { i => excludedInterfaces.contains(i.getName) @@ -323,11 +323,14 @@ object JavaAPICompletenessChecker { val missingMethods = javaEquivalents -- javaMethods for (method <- missingMethods) { + // scalastyle:off println println(method) + // scalastyle:on println } } def main(args: Array[String]) { + // scalastyle:off println println("Missing RDD methods") printMissingMethods(classOf[RDD[_]], classOf[JavaRDD[_]]) println() @@ -359,5 +362,6 @@ object JavaAPICompletenessChecker { println("Missing PairDStream methods") printMissingMethods(classOf[PairDStreamFunctions[_, _]], classOf[JavaPairDStream[_, _]]) println() + // scalastyle:on println } } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 15ee95070a3d..0dc2861253f1 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Writes simulated shuffle output from several threads and records the observed throughput. */ object StoragePerfTester { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) @@ -46,7 +46,8 @@ object StoragePerfTester { val totalRecords = dataSizeMb * 1000 val recordsPerMap = totalRecords / numMaps - val writeData = "1" * recordLength + val writeKey = "1" * (recordLength / 2) + val writeValue = "1" * (recordLength / 2) val executor = Executors.newFixedThreadPool(numMaps) val conf = new SparkConf() @@ -58,12 +59,12 @@ object StoragePerfTester { val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { + val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { - writers(i % numOutputSplits).write(writeData) + writers(i % numOutputSplits).write(writeKey, writeValue) } writers.map { w => w.commitAndClose() @@ -78,13 +79,15 @@ object StoragePerfTester { val totalBytes = new AtomicLong() for (task <- 1 to numMaps) { executor.submit(new Runnable() { - override def run() = { + override def run(): Unit = { try { writeOutputBytes(task, totalBytes) latch.countDown() } catch { case e: Exception => + // scalastyle:off println println("Exception in child thread: " + e + " " + e.getMessage) + // scalastyle:on println System.exit(1) } } @@ -96,9 +99,11 @@ object StoragePerfTester { val bytesPerSecond = totalBytes.get() / time val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + // scalastyle:off println System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + // scalastyle:on println executor.shutdown() sc.stop() diff --git a/tox.ini b/tox.ini index b568029a204c..76e3f42cde62 100644 --- a/tox.ini +++ b/tox.ini @@ -15,4 +15,4 @@ [pep8] max-line-length=100 -exclude=cloudpickle.py,heapq3.py +exclude=cloudpickle.py,heapq3.py,shared.py diff --git a/unsafe/pom.xml b/unsafe/pom.xml new file mode 100644 index 000000000000..89475ee3cf5a --- /dev/null +++ b/unsafe/pom.xml @@ -0,0 +1,112 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-unsafe_2.10 + jar + Spark Project Unsafe + http://spark.apache.org/ + + unsafe + + + + + + + com.google.code.findbugs + jsr305 + + + com.google.guava + guava + + + + + org.slf4j + slf4j-api + provided + + + + + junit + junit + test + + + com.novocode + junit-interface + test + + + org.mockito + mockito-core + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.commons + commons-lang3 + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + net.alchim31.maven + scala-maven-plugin + + + + -XDignore.symbol.file + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + -XDignore.symbol.file + + + + + + + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java new file mode 100644 index 000000000000..5c9d5d9a3831 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.io.IOException; + +public abstract class KVIterator { + + public abstract boolean next() throws IOException; + + public abstract K getKey(); + + public abstract V getValue(); + + public abstract void close(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java new file mode 100644 index 000000000000..1c16da982923 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java new file mode 100644 index 000000000000..cf42877bf9fd --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.Platform; + +public class ByteArrayMethods { + + private ByteArrayMethods() { + // Private constructor, since this class only contains static methods. + } + + /** Returns the next number greater or equal num that is power of 2. */ + public static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } + + public static int roundNumberOfBytesToNearestWord(int numBytes) { + int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + if (remainder == 0) { + return numBytes; + } else { + return numBytes + (8 - remainder); + } + } + + /** + * Optimized byte array equality check for byte arrays. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEquals( + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + int i = 0; + while (i <= length - 8) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { + return false; + } + i += 8; + } + while (i < length) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { + return false; + } + i += 1; + } + return true; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java new file mode 100644 index 000000000000..74105050e419 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * An array of long values. Compared with native JVM arrays, this: + *
      + *
    • supports using both in-heap and off-heap memory
    • + *
    • has no bound checking, and thus can crash the JVM process when assert is turned off
    • + *
    + */ +public final class LongArray { + + // This is a long so that we perform long multiplications when computing offsets. + private static final long WIDTH = 8; + + private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; + + private final long length; + + public LongArray(MemoryBlock memory) { + assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; + assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; + this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); + this.length = memory.size() / WIDTH; + } + + public MemoryBlock memoryBlock() { + return memory; + } + + /** + * Returns the number of elements this array can hold. + */ + public long size() { + return length; + } + + /** + * Sets the value at position {@code index}. + */ + public void set(int index, long value) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + } + + /** + * Returns the value at position {@code index}. + */ + public long get(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + return Platform.getLong(baseObj, baseOffset + index * WIDTH); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java new file mode 100644 index 000000000000..7c124173b0bb --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * A fixed size uncompressed bit set backed by a {@link LongArray}. + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSet { + + /** A long array for the bits. */ + private final LongArray words; + + /** Length of the long array. */ + private final int numWords; + + private final Object baseObject; + private final long baseOffset; + + /** + * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be + * multiple of 8 bytes (i.e. 64 bits). + */ + public BitSet(MemoryBlock memory) { + words = new LongArray(memory); + assert (words.size() <= Integer.MAX_VALUE); + numWords = (int) words.size(); + baseObject = words.memoryBlock().getBaseObject(); + baseOffset = words.memoryBlock().getBaseOffset(); + } + + public MemoryBlock memoryBlock() { + return words.memoryBlock(); + } + + /** + * Returns the number of bits in this {@code BitSet}. + */ + public long capacity() { + return numWords * 64; + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public void set(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.set(baseObject, baseOffset, index); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public void unset(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.unset(baseObject, baseOffset, index); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public boolean isSet(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + return BitSetMethods.isSet(baseObject, baseOffset, index); + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

    + * To iterate over the true bits in a BitSet, use the following loop: + *

    +   * 
    +   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
    +   *    // operate on index i here
    +   *  }
    +   * 
    +   * 
    + * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + public int nextSetBit(int fromIndex) { + return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); + } + + /** + * Returns {@code true} if any bit is set. + */ + public boolean anySet() { + return BitSetMethods.anySet(baseObject, baseOffset, numWords); + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java new file mode 100644 index 000000000000..7857bf66a72a --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.Platform; + +/** + * Methods for working with fixed-size uncompressed bitsets. + * + * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length). + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSetMethods { + + private static final long WORD_SIZE = 8; + + private BitSetMethods() { + // Make the default constructor private, since this only holds static methods. + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public static void set(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word | mask); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public static void unset(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word & ~mask); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public static boolean isSet(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = Platform.getLong(baseObject, wordOffset); + return (word & mask) != 0; + } + + /** + * Returns {@code true} if any bit is set. + */ + public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { + long addr = baseOffset; + for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { + if (Platform.getLong(baseObject, addr) != 0) { + return true; + } + } + return false; + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

    + * To iterate over the true bits in a BitSet, use the following loop: + *

    +   * 
    +   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
    +   *    // operate on index i here
    +   *  }
    +   * 
    +   * 
    + * + * @param fromIndex the index to start checking from (inclusive) + * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words + * @return the index of the next set bit, or -1 if there is no such bit + */ + public static int nextSetBit( + Object baseObject, + long baseOffset, + int fromIndex, + int bitsetSizeInWords) { + int wi = fromIndex >> 6; + if (wi >= bitsetSizeInWords) { + return -1; + } + + // Try to find the next set bit in the current word + final int subIndex = fromIndex & 0x3f; + long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + if (word != 0) { + return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); + } + + // Find the next set bit in the rest of the words + wi += 1; + while (wi < bitsetSizeInWords) { + word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE); + if (word != 0) { + return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); + } + wi += 1; + } + + return -1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java new file mode 100644 index 000000000000..4276f25c2165 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import org.apache.spark.unsafe.Platform; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +public final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = seed; + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = Platform.getInt(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + public int hashLong(long input) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java new file mode 100644 index 000000000000..cbbe8594627a --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import javax.annotation.concurrent.GuardedBy; + +/** + * Manages memory for an executor. Individual operators / tasks allocate memory through + * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. + */ +public class ExecutorMemoryManager { + + /** + * Allocator, exposed for enabling untracked allocations of temporary data structures. + */ + public final MemoryAllocator allocator; + + /** + * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. + */ + final boolean inHeap; + + @GuardedBy("this") + private final Map>> bufferPoolsBySize = + new HashMap>>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + + /** + * Construct a new ExecutorMemoryManager. + * + * @param allocator the allocator that will be used + */ + public ExecutorMemoryManager(MemoryAllocator allocator) { + this.inHeap = allocator instanceof HeapMemoryAllocator; + this.allocator = allocator; + } + + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling. + // At some point, we should explore supporting pooling for off-heap memory, but for now we'll + // ignore that case in the interest of simplicity. + return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError { + if (shouldPool(size)) { + synchronized (this) { + final LinkedList> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + return allocator.allocate(size); + } else { + return allocator.allocate(size); + } + } + + void free(MemoryBlock memory) { + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference(memory)); + } + } else { + allocator.free(memory); + } + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java new file mode 100644 index 000000000000..6722301df19d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. + */ +public class HeapMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + if (size % 8 != 0) { + throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); + } + long[] array = new long[(int) (size / 8)]; + return MemoryBlock.fromLongArray(array); + } + + @Override + public void free(MemoryBlock memory) { + // Do nothing + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java new file mode 100644 index 000000000000..5192f68c862c --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +public interface MemoryAllocator { + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError; + + void free(MemoryBlock memory); + + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + + MemoryAllocator HEAP = new HeapMemoryAllocator(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java new file mode 100644 index 000000000000..dd7582083437 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class MemoryBlock extends MemoryLocation { + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * MemoryManager. This is package-private and is modified by MemoryManager. + */ + int pageNumber = -1; + + public MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java new file mode 100644 index 000000000000..74ebc87dc978 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). + */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java new file mode 100644 index 000000000000..97b2c93f0dc3 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import java.util.*; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages the memory allocated by an individual task. + *

    + * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + *

    + * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + *

    + * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is + * approximately 35 terabytes of memory. + */ +public class TaskMemoryManager { + + private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** + * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page + * size is limited by the maximum amount of data that can be stored in a long[] array, which is + * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. + */ + public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both in- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an in-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** + * Bitmap for tracking free pages. + */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + /** + * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean + * up leaked memory. + */ + private final HashSet allocatedNonPageMemory = new HashSet(); + + private final ExecutorMemoryManager executorMemoryManager; + + /** + * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + private final boolean inHeap; + + /** + * Construct a new MemoryManager. + */ + public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { + this.inHeap = executorMemoryManager.inHeap; + this.executorMemoryManager = executorMemoryManager; + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of memory that will be shared between operators. + */ + public MemoryBlock allocatePage(long size) { + if (size > MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + final MemoryBlock page = executorMemoryManager.allocate(size); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + */ + public void freePage(MemoryBlock page) { + assert (page.pageNumber != -1) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + // Cannot access a page once it's freed. + executorMemoryManager.free(page); + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended + * to be used for allocating operators' internal data structures. For data pages that you want to + * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since + * that will enable intra-memory pointers (see + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's + * top-level Javadoc for more details). + */ + public MemoryBlock allocate(long size) throws OutOfMemoryError { + assert(size > 0) : "Size must be positive, but got " + size; + final MemoryBlock memory = executorMemoryManager.allocate(size); + synchronized(allocatedNonPageMemory) { + allocatedNonPageMemory.add(memory); + } + return memory; + } + + /** + * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. + */ + public void free(MemoryBlock memory) { + assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; + executorMemoryManager.free(memory); + synchronized(allocatedNonPageMemory) { + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); + } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (inHeap) { + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + assert (page.getBaseObject() != null); + return page.getBaseObject(); + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); + if (inHeap) { + return offsetInPage; + } else { + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. + */ + public long cleanUpAllAllocatedMemory() { + long freedBytes = 0; + for (MemoryBlock page : pageTable) { + if (page != null) { + freedBytes += page.size(); + freePage(page); + } + } + + synchronized (allocatedNonPageMemory) { + final Iterator iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. + executorMemoryManager.free(memory); + iter.remove(); + } + } + return freedBytes; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java new file mode 100644 index 000000000000..cda7826c8c99 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +/** + * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. + */ +public class UnsafeMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + if (size % 8 != 0) { + throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); + } + long address = Platform.allocateMemory(size); + return new MemoryBlock(null, address, size); + } + + @Override + public void free(MemoryBlock memory) { + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + Platform.freeMemory(memory.offset); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java new file mode 100644 index 000000000000..c08c9c73d239 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import org.apache.spark.unsafe.Platform; + +public class ByteArray { + + /** + * Writes the content of a byte array into a memory address, identified by an object and an + * offset. The target memory address must already been allocated, and have enough space to + * hold all the bytes in this string. + */ + public static void writeToMemory(byte[] src, Object target, long targetOffset) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java new file mode 100644 index 000000000000..30e175807636 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import java.io.Serializable; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * The internal representation of interval type. + */ +public final class CalendarInterval implements Serializable { + public static final long MICROS_PER_MILLI = 1000L; + public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; + public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; + public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60; + public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; + public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; + + /** + * A function to generate regex which matches interval string's unit part like "3 years". + * + * First, we can leave out some units in interval string, and we only care about the value of + * unit, so here we use non-capturing group to wrap the actual regex. + * At the beginning of the actual regex, we should match spaces before the unit part. + * Next is the number part, starts with an optional "-" to represent negative value. We use + * capturing group to wrap this part as we need the value later. + * Finally is the unit name, ends with an optional "s". + */ + private static String unitRegex(String unit) { + return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?"; + } + + private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") + + unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + + private static Pattern yearMonthPattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + + private static Pattern dayTimePattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); + + private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); + + private static long toLong(String s) { + if (s == null) { + return 0; + } else { + return Long.valueOf(s); + } + } + + public static CalendarInterval fromString(String s) { + if (s == null) { + return null; + } + s = s.trim(); + Matcher m = p.matcher(s); + if (!m.matches() || s.equals("interval")) { + return null; + } else { + long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); + long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; + microseconds += toLong(m.group(4)) * MICROS_PER_DAY; + microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; + microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; + microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; + microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; + microseconds += toLong(m.group(9)); + return new CalendarInterval((int) months, microseconds); + } + } + + public static long toLongWithRange(String fieldName, + String s, long minValue, long maxValue) throws IllegalArgumentException { + long result = 0; + if (s != null) { + result = Long.valueOf(s); + if (result < minValue || result > maxValue) { + throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", + fieldName, result, minValue, maxValue)); + } + } + return result; + } + + /** + * Parse YearMonth string in form: [-]YYYY-MM + * + * adapted from HiveIntervalYearMonth.valueOf + */ + public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval year-month string was null"); + } + s = s.trim(); + Matcher m = yearMonthPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match year-month format of 'y-m': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); + int months = (int) toLongWithRange("month", m.group(3), 0, 11); + result = new CalendarInterval(sign * (years * 12 + months), 0); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval year-month string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn + * + * adapted from HiveIntervalDayTime.valueOf + */ + public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval day-time string was null"); + } + s = s.trim(); + Matcher m = dayTimePattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE); + long hours = toLongWithRange("hour", m.group(3), 0, 23); + long minutes = toLongWithRange("minute", m.group(4), 0, 59); + long seconds = toLongWithRange("second", m.group(5), 0, 59); + // Hive allow nanosecond precision interval + long nanos = toLongWithRange("nanosecond", m.group(7), 0L, 999999999L); + result = new CalendarInterval(0, sign * ( + days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + + seconds * MICROS_PER_SECOND + nanos / 1000L)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval day-time string: " + e.getMessage(), e); + } + } + return result; + } + + public static CalendarInterval fromSingleUnitString(String unit, String s) + throws IllegalArgumentException { + + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); + } + s = s.trim(); + Matcher m = quoteTrimPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + if (unit.equals("year")) { + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + + } else if (unit.equals("month")) { + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + + } else if (unit.equals("day")) { + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + + } else if (unit.equals("hour")) { + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + + } else if (unit.equals("minute")) { + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + + } else if (unit.equals("second")) { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + } + } catch (Exception e) { + throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse second_nano string in ss.nnnnnnnnn format to microseconds + */ + public static long parseSecondNano(String secondNano) throws IllegalArgumentException { + String[] parts = secondNano.split("\\."); + if (parts.length == 1) { + return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, + Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; + + } else if (parts.length == 2) { + long seconds = parts[0].equals("") ? 0L : toLongWithRange("second", parts[0], + Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); + long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); + return seconds * MICROS_PER_SECOND + nanos / 1000L; + + } else { + throw new IllegalArgumentException( + "Interval string does not match second-nano format of ss.nnnnnnnnn"); + } + } + + public final int months; + public final long microseconds; + + public CalendarInterval(int months, long microseconds) { + this.months = months; + this.microseconds = microseconds; + } + + public CalendarInterval add(CalendarInterval that) { + int months = this.months + that.months; + long microseconds = this.microseconds + that.microseconds; + return new CalendarInterval(months, microseconds); + } + + public CalendarInterval subtract(CalendarInterval that) { + int months = this.months - that.months; + long microseconds = this.microseconds - that.microseconds; + return new CalendarInterval(months, microseconds); + } + + public CalendarInterval negate() { + return new CalendarInterval(-this.months, -this.microseconds); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof CalendarInterval)) return false; + + CalendarInterval o = (CalendarInterval) other; + return this.months == o.months && this.microseconds == o.microseconds; + } + + @Override + public int hashCode() { + return 31 * months + (int) microseconds; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("interval"); + + if (months != 0) { + appendUnit(sb, months / 12, "year"); + appendUnit(sb, months % 12, "month"); + } + + if (microseconds != 0) { + long rest = microseconds; + appendUnit(sb, rest / MICROS_PER_WEEK, "week"); + rest %= MICROS_PER_WEEK; + appendUnit(sb, rest / MICROS_PER_DAY, "day"); + rest %= MICROS_PER_DAY; + appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); + rest %= MICROS_PER_HOUR; + appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); + rest %= MICROS_PER_MINUTE; + appendUnit(sb, rest / MICROS_PER_SECOND, "second"); + rest %= MICROS_PER_SECOND; + appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); + rest %= MICROS_PER_MILLI; + appendUnit(sb, rest, "microsecond"); + } + + return sb.toString(); + } + + private void appendUnit(StringBuilder sb, long value, String unit) { + if (value != 0) { + sb.append(" " + value + " " + unit + "s"); + } + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java new file mode 100644 index 000000000000..cbcab958c05a --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -0,0 +1,1000 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import javax.annotation.Nonnull; +import java.io.*; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.Map; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; + +import static org.apache.spark.unsafe.Platform.*; + + +/** + * A UTF-8 String for internal Spark use. + *

    + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + *

    + * Note: This is not designed for general use cases, should not be used outside SQL. + */ +public final class UTF8String implements Comparable, Externalizable { + + // These are only updated by readExternal() + @Nonnull + private Object base; + private long offset; + private int numBytes; + + public Object getBaseObject() { return base; } + public long getBaseOffset() { return offset; } + + private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6}; + + private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + + private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); + + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + } else { + return null; + } + } + + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + } else { + return null; + } + } + + /** + * Creates an UTF8String from given address (base and offset) and length. + */ + public static UTF8String fromAddress(Object base, long offset, int numBytes) { + if (base != null) { + return new UTF8String(base, offset, numBytes); + } else { + return null; + } + } + + /** + * Creates an UTF8String from String. + */ + public static UTF8String fromString(String str) { + if (str == null) return null; + try { + return fromBytes(str.getBytes("utf-8")); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + throwException(e); + return null; + } + } + + /** + * Creates an UTF8String that contains `length` spaces. + */ + public static UTF8String blankString(int length) { + byte[] spaces = new byte[length]; + Arrays.fill(spaces, (byte) ' '); + return fromBytes(spaces); + } + + protected UTF8String(Object base, long offset, int numBytes) { + this.base = base; + this.offset = offset; + this.numBytes = numBytes; + } + + // for serialization + public UTF8String() { + this(null, 0, 0); + } + + /** + * Writes the content of this string into a memory address, identified by an object and an offset. + * The target memory address must already been allocated, and have enough space to hold all the + * bytes in this string. + */ + public void writeToMemory(Object target, long targetOffset) { + Platform.copyMemory(base, offset, target, targetOffset, numBytes); + } + + /** + * Returns the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + private static int numBytesForFirstByte(final byte b) { + final int offset = (b & 0xFF) - 192; + return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + } + + /** + * Returns the number of bytes + */ + public int numBytes() { + return numBytes; + } + + /** + * Returns the number of code points in it. + */ + public int numChars() { + int len = 0; + for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) { + len += 1; + } + return len; + } + + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. + */ + public long getPrefix() { + // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string. + // If size is 0, just return 0. + // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and + // use a getInt to fetch the prefix. + // If size is greater than 4, assume we have at least 8 bytes of data to fetch. + // After getting the data, we use a mask to mask out data that is not part of the string. + long p; + long mask = 0; + if (isLittleEndian) { + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = (long) Platform.getInt(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); + } else { + // byteOrder == ByteOrder.BIG_ENDIAN + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = ((long) Platform.getInt(base, offset)) << 32; + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + } + p &= ~mask; + return p; + } + + /** + * Returns the underline bytes, will be a copy of it if it's part of another array. + */ + public byte[] getBytes() { + // avoid copy if `base` is `byte[]` + if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] + && ((byte[]) base).length == numBytes) { + return (byte[]) base; + } else { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return bytes; + } + } + + /** + * Returns a substring of this. + * @param start the position of first code point + * @param until the position after last code point, exclusive. + */ + public UTF8String substring(final int start, final int until) { + if (until <= start || start >= numBytes) { + return EMPTY_UTF8; + } + + int i = 0; + int c = 0; + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + int j = i; + while (i < numBytes && c < until) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + if (i > j) { + byte[] bytes = new byte[i - j]; + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + return fromBytes(bytes); + } else { + return EMPTY_UTF8; + } + } + + public UTF8String substringSQL(int pos, int length) { + // Information regarding the pos calculation: + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + int len = numChars(); + int start = (pos > 0) ? pos -1 : ((pos < 0) ? len + pos : 0); + int end = (length == Integer.MAX_VALUE) ? len : start + length; + return substring(start, end); + } + + /** + * Returns whether this contains `substring` or not. + */ + public boolean contains(final UTF8String substring) { + if (substring.numBytes == 0) { + return true; + } + + byte first = substring.getByte(0); + for (int i = 0; i <= numBytes - substring.numBytes; i++) { + if (getByte(i) == first && matchAt(substring, i)) { + return true; + } + } + return false; + } + + /** + * Returns the byte at position `i`. + */ + private byte getByte(int i) { + return Platform.getByte(base, offset + i); + } + + private boolean matchAt(final UTF8String s, int pos) { + if (s.numBytes + pos > numBytes || pos < 0) { + return false; + } + return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + } + + public boolean startsWith(final UTF8String prefix) { + return matchAt(prefix, 0); + } + + public boolean endsWith(final UTF8String suffix) { + return matchAt(suffix, numBytes - suffix.numBytes); + } + + /** + * Returns the upper case of this string + */ + public UTF8String toUpperCase() { + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte[] bytes = new byte[numBytes]; + bytes[0] = (byte) Character.toTitleCase(getByte(0)); + for (int i = 0; i < numBytes; i++) { + byte b = getByte(i); + if (numBytesForFirstByte(b) != 1) { + // fallback + return toUpperCaseSlow(); + } + int upper = Character.toUpperCase((int) b); + if (upper > 127) { + // fallback + return toUpperCaseSlow(); + } + bytes[i] = (byte) upper; + } + return fromBytes(bytes); + } + + private UTF8String toUpperCaseSlow() { + return fromString(toString().toUpperCase()); + } + + /** + * Returns the lower case of this string + */ + public UTF8String toLowerCase() { + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte[] bytes = new byte[numBytes]; + bytes[0] = (byte) Character.toTitleCase(getByte(0)); + for (int i = 0; i < numBytes; i++) { + byte b = getByte(i); + if (numBytesForFirstByte(b) != 1) { + // fallback + return toLowerCaseSlow(); + } + int lower = Character.toLowerCase((int) b); + if (lower > 127) { + // fallback + return toLowerCaseSlow(); + } + bytes[i] = (byte) lower; + } + return fromBytes(bytes); + } + + private UTF8String toLowerCaseSlow() { + return fromString(toString().toLowerCase()); + } + + /** + * Returns the title case of this string, that could be used as title. + */ + public UTF8String toTitleCase() { + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte[] bytes = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) { + byte b = getByte(i); + if (i == 0 || getByte(i - 1) == ' ') { + if (numBytesForFirstByte(b) != 1) { + // fallback + return toTitleCaseSlow(); + } + int upper = Character.toTitleCase(b); + if (upper > 127) { + // fallback + return toTitleCaseSlow(); + } + bytes[i] = (byte) upper; + } else { + bytes[i] = b; + } + } + return fromBytes(bytes); + } + + private UTF8String toTitleCaseSlow() { + StringBuffer sb = new StringBuffer(); + String s = toString(); + sb.append(s); + sb.setCharAt(0, Character.toTitleCase(sb.charAt(0))); + for (int i = 1; i < s.length(); i++) { + if (sb.charAt(i - 1) == ' ') { + sb.setCharAt(i, Character.toTitleCase(sb.charAt(i))); + } + } + return fromString(sb.toString()); + } + + /* + * Returns the index of the string `match` in this String. This string has to be a comma separated + * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String, + * 0 will be returned, else the index of match (1-based index) + */ + public int findInSet(UTF8String match) { + if (match.contains(COMMA_UTF8)) { + return 0; + } + + int n = 1, lastComma = -1; + for (int i = 0; i < numBytes; i++) { + if (getByte(i) == (byte) ',') { + if (i - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + lastComma = i; + n++; + } + } + if (numBytes - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + return 0; + } + + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] result = new byte[this.numBytes]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <= 0) { + return EMPTY_UTF8; + } + + byte[] newBytes = new byte[numBytes * times]; + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. + */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while (i < numBytes); + + return -1; + } + + /** + * Find the `str` from left to right. + */ + private int find(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start <= numBytes - str.numBytes) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; + } + start += 1; + } + return -1; + } + + /** + * Find the `str` from right to left. + */ + private int rfind(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start >= 0) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; + } + start -= 1; + } + return -1; + } + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. subStringIndex performs a case-sensitive match when searching for delim. + */ + public UTF8String subStringIndex(UTF8String delim, int count) { + if (delim.numBytes == 0 || count == 0) { + return EMPTY_UTF8; + } + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = find(delim, idx + 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx == 0) { + return EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + return fromBytes(bytes); + + } else { + int idx = numBytes - delim.numBytes + 1; + count = -count; + while (count > 0) { + idx = rfind(delim, idx - 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx + delim.numBytes == numBytes) { + return EMPTY_UTF8; + } + int size = numBytes - delim.numBytes - idx; + byte[] bytes = new byte[size]; + copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + return fromBytes(bytes); + } + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' + * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0 || pad.numBytes() == 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + int offset = this.numBytes; + int idx = 0; + while (idx < count) { + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + ++ idx; + offset += pad.numBytes; + } + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0 || pad.numBytes() == 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + while (idx < count) { + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + ++ idx; + offset += pad.numBytes; + } + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + offset += remain.numBytes; + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + + /** + * Concatenates input strings together into a single string. Returns null if any input is null. + */ + public static UTF8String concat(UTF8String... inputs) { + // Compute the total length of the result. + int totalLength = 0; + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + totalLength += inputs[i].numBytes; + } else { + return null; + } + } + + // Allocate a new byte array, and copy the inputs one by one into it. + final byte[] result = new byte[totalLength]; + int offset = 0; + for (int i = 0; i < inputs.length; i++) { + int len = inputs[i].numBytes; + copyMemory( + inputs[i].base, inputs[i].offset, + result, BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + } + return fromBytes(result); + } + + /** + * Concatenates input strings together into a single string using the separator. + * A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c". + */ + public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { + if (separator == null) { + return null; + } + + int numInputBytes = 0; // total number of bytes from the inputs + int numInputs = 0; // number of non-null inputs + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + numInputBytes += inputs[i].numBytes; + numInputs++; + } + } + + if (numInputs == 0) { + // Return an empty string if there is no input, or all the inputs are null. + return fromBytes(new byte[0]); + } + + // Allocate a new byte array, and copy the inputs one by one into it. + // The size of the new array is the size of all inputs, plus the separators. + final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes]; + int offset = 0; + + for (int i = 0, j = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + int len = inputs[i].numBytes; + copyMemory( + inputs[i].base, inputs[i].offset, + result, BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + + j++; + // Add separator if this is not the last input. + if (j < numInputs) { + copyMemory( + separator.base, separator.offset, + result, BYTE_ARRAY_OFFSET + offset, + separator.numBytes); + offset += separator.numBytes; + } + } + } + return fromBytes(result); + } + + public UTF8String[] split(UTF8String pattern, int limit) { + String[] splits = toString().split(pattern.toString(), limit); + UTF8String[] res = new UTF8String[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = fromString(splits[i]); + } + return res; + } + + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes + public UTF8String translate(Map dict) { + String srcStr = this.toString(); + + StringBuilder sb = new StringBuilder(); + for(int k = 0; k< srcStr.length(); k++) { + if (null == dict.get(srcStr.charAt(k))) { + sb.append(srcStr.charAt(k)); + } else if ('\0' != dict.get(srcStr.charAt(k))){ + sb.append(dict.get(srcStr.charAt(k))); + } + } + return fromString(sb.toString()); + } + + @Override + public String toString() { + try { + return new String(getBytes(), "utf-8"); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + throwException(e); + return "unknown"; // we will never reach here. + } + } + + @Override + public UTF8String clone() { + return fromBytes(getBytes()); + } + + @Override + public int compareTo(@Nonnull final UTF8String other) { + int len = Math.min(numBytes, other.numBytes); + // TODO: compare 8 bytes as unsigned long + for (int i = 0; i < len; i ++) { + // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. + int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF); + if (res != 0) { + return res; + } + } + return numBytes - other.numBytes; + } + + public int compare(final UTF8String other) { + return compareTo(other); + } + + @Override + public boolean equals(final Object other) { + if (other instanceof UTF8String) { + UTF8String o = (UTF8String) other; + if (numBytes != o.numBytes) { + return false; + } + return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + } else { + return false; + } + } + + /** + * Levenshtein distance is a metric for measuring the distance of two strings. The distance is + * defined by the minimum number of single-character edits (i.e. insertions, deletions or + * substitutions) that are required to change one of the strings into the other. + */ + public int levenshteinDistance(UTF8String other) { + // Implementation adopted from org.apache.common.lang3.StringUtils.getLevenshteinDistance + + int n = numChars(); + int m = other.numChars(); + + if (n == 0) { + return m; + } else if (m == 0) { + return n; + } + + UTF8String s, t; + + if (n <= m) { + s = this; + t = other; + } else { + s = other; + t = this; + int swap; + swap = n; + n = m; + m = swap; + } + + int p[] = new int[n + 1]; + int d[] = new int[n + 1]; + int swap[]; + + int i, i_bytes, j, j_bytes, num_bytes_j, cost; + + for (i = 0; i <= n; i++) { + p[i] = i; + } + + for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { + num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); + d[0] = j + 1; + + for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) { + if (s.getByte(i_bytes) != t.getByte(j_bytes) || + num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { + cost = 1; + } else { + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + } + d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); + } + + swap = p; + p = d; + d = swap; + } + + return p[n]; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < numBytes; i ++) { + result = 31 * result + getByte(i); + } + return result; + } + + /** + * Soundex mapping table + */ + private static final byte[] US_ENGLISH_MAPPING = {'0', '1', '2', '3', '0', '1', '2', '7', + '0', '2', '2', '4', '5', '5', '0', '1', '2', '6', '2', '3', '0', '1', '7', '2', '0', '2'}; + + /** + * Encodes a string into a Soundex value. Soundex is an encoding used to relate similar names, + * but can also be used as a general purpose scheme to find word with similar phonemes. + * https://en.wikipedia.org/wiki/Soundex + */ + public UTF8String soundex() { + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte b = getByte(0); + if ('a' <= b && b <= 'z') { + b -= 32; + } else if (b < 'A' || 'Z' < b) { + // first character must be a letter + return this; + } + byte sx[] = {'0', '0', '0', '0'}; + sx[0] = b; + int sxi = 1; + int idx = b - 'A'; + byte lastCode = US_ENGLISH_MAPPING[idx]; + + for (int i = 1; i < numBytes; i++) { + b = getByte(i); + if ('a' <= b && b <= 'z') { + b -= 32; + } else if (b < 'A' || 'Z' < b) { + // not a letter, skip it + lastCode = '0'; + continue; + } + idx = b - 'A'; + byte code = US_ENGLISH_MAPPING[idx]; + if (code == '7') { + // ignore it + } else { + if (code != '0' && code != lastCode) { + sx[sxi++] = code; + if (sxi > 3) break; + } + lastCode = code; + } + } + return UTF8String.fromBytes(sx); + } + + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; + numBytes = in.readInt(); + base = new byte[numBytes]; + in.readFully((byte[]) base); + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java new file mode 100644 index 000000000000..5974cf91ff99 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class LongArraySuite { + + @Test + public void basicTest() { + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + arr.set(0, 1L); + arr.set(1, 2L); + arr.set(1, 3L); + Assert.assertEquals(2, arr.size()); + Assert.assertEquals(1L, arr.get(0)); + Assert.assertEquals(3L, arr.get(1)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java new file mode 100644 index 000000000000..a93fc0ee297c --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import junit.framework.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class BitSetSuite { + + private static BitSet createBitSet(int capacity) { + assert capacity % 64 == 0; + return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); + } + + @Test + public void basicOps() { + BitSet bs = createBitSet(64); + Assert.assertEquals(64, bs.capacity()); + + // Make sure the bit set starts empty. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertFalse(bs.isSet(i)); + } + // another form of asserting that the bit set is empty + Assert.assertFalse(bs.anySet()); + + // Set every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + bs.set(i); + Assert.assertTrue(bs.isSet(i)); + } + + // Unset every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertTrue(bs.isSet(i)); + bs.unset(i); + Assert.assertFalse(bs.isSet(i)); + } + + // Make sure anySet() can detect any set bit + bs = createBitSet(256); + bs.set(64); + Assert.assertTrue(bs.anySet()); + } + + @Test + public void traversal() { + BitSet bs = createBitSet(256); + + Assert.assertEquals(-1, bs.nextSetBit(0)); + Assert.assertEquals(-1, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(64)); + + bs.set(10); + Assert.assertEquals(10, bs.nextSetBit(0)); + Assert.assertEquals(10, bs.nextSetBit(1)); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(11)); + + bs.set(11); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(11, bs.nextSetBit(11)); + + // Skip a whole word and find it + bs.set(190); + Assert.assertEquals(190, bs.nextSetBit(12)); + + Assert.assertEquals(-1, bs.nextSetBit(191)); + Assert.assertEquals(-1, bs.nextSetBit(256)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java new file mode 100644 index 000000000000..2f8cb132ac8b --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import junit.framework.Assert; +import org.apache.spark.unsafe.Platform; +import org.junit.Test; + +/** + * Test file based on Guava's Murmur3Hash32Test. + */ +public class Murmur3_x86_32Suite { + + private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0); + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(593689054, hasher.hashInt(0)); + Assert.assertEquals(-189366624, hasher.hashInt(-42)); + Assert.assertEquals(-1134849565, hasher.hashInt(42)); + Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE)); + Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(1669671676, hasher.hashLong(0L)); + Assert.assertEquals(-846261623, hasher.hashLong(-42L)); + Assert.assertEquals(1871679806, hasher.hashLong(42L)); + Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = ("" + i).getBytes(); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java new file mode 100644 index 000000000000..06fb08118365 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.junit.Assert; +import org.junit.Test; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedNonPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocate(1024); // leak memory + Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocatePage(4096); // leak memory + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void encodePageNumberAndOffsetOffHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock dataPage = manager.allocatePage(256); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java new file mode 100644 index 000000000000..80d4982c4b57 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -0,0 +1,240 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.unsafe.types; + +import org.junit.Test; + +import static junit.framework.Assert.*; +import static org.apache.spark.unsafe.types.CalendarInterval.*; + +public class CalendarIntervalSuite { + + @Test + public void equalsTest() { + CalendarInterval i1 = new CalendarInterval(3, 123); + CalendarInterval i2 = new CalendarInterval(3, 321); + CalendarInterval i3 = new CalendarInterval(1, 123); + CalendarInterval i4 = new CalendarInterval(3, 123); + + assertNotSame(i1, i2); + assertNotSame(i1, i3); + assertNotSame(i2, i3); + assertEquals(i1, i4); + } + + @Test + public void toStringTest() { + CalendarInterval i; + + i = new CalendarInterval(34, 0); + assertEquals(i.toString(), "interval 2 years 10 months"); + + i = new CalendarInterval(-34, 0); + assertEquals(i.toString(), "interval -2 years -10 months"); + + i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + + i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + + i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + } + + @Test + public void fromStringTest() { + testSingleUnit("year", 3, 36, 0); + testSingleUnit("month", 3, 3, 0); + testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK); + testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY); + testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR); + testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE); + testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND); + testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI); + testSingleUnit("microsecond", 3, 0, 3); + + String input; + + input = "interval -5 years 23 month"; + CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); + assertEquals(CalendarInterval.fromString(input), result); + + input = "interval -5 years 23 month "; + assertEquals(CalendarInterval.fromString(input), result); + + input = " interval -5 years 23 month "; + assertEquals(CalendarInterval.fromString(input), result); + + // Error cases + input = "interval 3month 1 hour"; + assertEquals(CalendarInterval.fromString(input), null); + + input = "interval 3 moth 1 hour"; + assertEquals(CalendarInterval.fromString(input), null); + + input = "interval"; + assertEquals(CalendarInterval.fromString(input), null); + + input = "int"; + assertEquals(CalendarInterval.fromString(input), null); + + input = ""; + assertEquals(CalendarInterval.fromString(input), null); + + input = null; + assertEquals(CalendarInterval.fromString(input), null); + } + + @Test + public void fromYearMonthStringTest() { + String input; + CalendarInterval i; + + input = "99-10"; + i = new CalendarInterval(99 * 12 + 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + input = "-8-10"; + i = new CalendarInterval(-8 * 12 - 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + try { + input = "99-15"; + CalendarInterval.fromYearMonthString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("month 15 outside range")); + } + } + + @Test + public void fromDayTimeStringTest() { + String input; + CalendarInterval i; + + input = "5 12:40:30.999999999"; + i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + + 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "10 0:12:0.888"; + i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "-3 0:0:0"; + i = new CalendarInterval(0, -3 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + try { + input = "5 30:12:20"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("hour 30 outside range")); + } + + try { + input = "5 30-12"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not match day-time format")); + } + } + + @Test + public void fromSingleUnitStringTest() { + String input; + CalendarInterval i; + + input = "12"; + i = new CalendarInterval(12 * 12, 0L); + assertEquals(CalendarInterval.fromSingleUnitString("year", input), i); + + input = "100"; + i = new CalendarInterval(0, 100 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromSingleUnitString("day", input), i); + + input = "1999.38888"; + i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); + assertEquals(CalendarInterval.fromSingleUnitString("second", input), i); + + try { + input = String.valueOf(Integer.MAX_VALUE); + CalendarInterval.fromSingleUnitString("year", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + + try { + input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); + CalendarInterval.fromSingleUnitString("hour", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + } + + @Test + public void addTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); + } + + @Test + public void subtractTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); + } + + private void testSingleUnit(String unit, int number, int months, long microseconds) { + String input1 = "interval " + number + " " + unit; + String input2 = "interval " + number + " " + unit + "s"; + CalendarInterval result = new CalendarInterval(months, microseconds); + assertEquals(CalendarInterval.fromString(input1), result); + assertEquals(CalendarInterval.fromString(input2), result); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java new file mode 100644 index 000000000000..98aa8a2469a7 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -0,0 +1,492 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.unsafe.types; + +import java.io.UnsupportedEncodingException; +import java.util.Arrays; +import java.util.HashMap; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static junit.framework.Assert.*; + +import static org.apache.spark.unsafe.types.UTF8String.*; + +public class UTF8StringSuite { + + private void checkBasic(String str, int len) throws UnsupportedEncodingException { + UTF8String s1 = fromString(str); + UTF8String s2 = fromBytes(str.getBytes("utf8")); + assertEquals(s1.numChars(), len); + assertEquals(s2.numChars(), len); + + assertEquals(s1.toString(), str); + assertEquals(s2.toString(), str); + assertEquals(s1, s2); + + assertEquals(s1.hashCode(), s2.hashCode()); + + assertEquals(s1.compareTo(s2), 0); + + assertEquals(s1.contains(s2), true); + assertEquals(s2.contains(s1), true); + assertEquals(s1.startsWith(s1), true); + assertEquals(s1.endsWith(s1), true); + } + + @Test + public void basicTest() throws UnsupportedEncodingException { + checkBasic("", 0); + checkBasic("hello", 5); + checkBasic("大 千 世 界", 7); + } + + @Test + public void emptyStringTest() { + assertEquals(fromString(""), EMPTY_UTF8); + assertEquals(fromBytes(new byte[0]), EMPTY_UTF8); + assertEquals(0, EMPTY_UTF8.numChars()); + assertEquals(0, EMPTY_UTF8.numBytes()); + } + + @Test + public void prefix() { + assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue( + fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); + assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); + assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + + byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + byte[] buf2 = {1, 2, 3}; + UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); + UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); + UTF8String str3 = UTF8String.fromBytes(buf2); + assertTrue(str1.getPrefix() - str2.getPrefix() < 0); + assertEquals(str1.getPrefix(), str3.getPrefix()); + } + + @Test + public void compareTo() { + assertTrue(fromString("").compareTo(fromString("a")) < 0); + assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); + assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); + assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); + assertTrue(fromString("aBcabcabc").compareTo(fromString("Abcabcabc")) > 0); + assertTrue(fromString("Abcabcabc").compareTo(fromString("abcabcabC")) < 0); + assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabC")) > 0); + + assertTrue(fromString("abc").compareTo(fromString("世界")) < 0); + assertTrue(fromString("你好").compareTo(fromString("世界")) > 0); + assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0); + } + + protected void testUpperandLower(String upper, String lower) { + UTF8String us = fromString(upper); + UTF8String ls = fromString(lower); + assertEquals(ls, us.toLowerCase()); + assertEquals(us, ls.toUpperCase()); + assertEquals(us, us.toUpperCase()); + assertEquals(ls, ls.toLowerCase()); + } + + @Test + public void upperAndLower() { + testUpperandLower("", ""); + testUpperandLower("0123456", "0123456"); + testUpperandLower("ABCXYZ", "abcxyz"); + testUpperandLower("ЀЁЂѺΏỀ", "ѐёђѻώề"); + testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头"); + } + + @Test + public void titleCase() { + assertEquals(fromString(""), fromString("").toTitleCase()); + assertEquals(fromString("Ab Bc Cd"), fromString("ab bc cd").toTitleCase()); + assertEquals(fromString("Ѐ Ё Ђ Ѻ Ώ Ề"), fromString("ѐ ё ђ ѻ ώ ề").toTitleCase()); + assertEquals(fromString("大千世界 数据砖头"), fromString("大千世界 数据砖头").toTitleCase()); + } + + @Test + public void concatTest() { + assertEquals(EMPTY_UTF8, concat()); + assertEquals(null, concat((UTF8String) null)); + assertEquals(EMPTY_UTF8, concat(EMPTY_UTF8)); + assertEquals(fromString("ab"), concat(fromString("ab"))); + assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); + assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); + assertEquals(null, concat(fromString("a"), null, fromString("c"))); + assertEquals(null, concat(fromString("a"), null, null)); + assertEquals(null, concat(null, null, null)); + assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头"))); + } + + @Test + public void concatWsTest() { + // Returns null if the separator is null + assertEquals(null, concatWs(null, (UTF8String)null)); + assertEquals(null, concatWs(null, fromString("a"))); + + // If separator is null, concatWs should skip all null inputs and never return null. + UTF8String sep = fromString("哈哈"); + assertEquals( + EMPTY_UTF8, + concatWs(sep, EMPTY_UTF8)); + assertEquals( + fromString("ab"), + concatWs(sep, fromString("ab"))); + assertEquals( + fromString("a哈哈b"), + concatWs(sep, fromString("a"), fromString("b"))); + assertEquals( + fromString("a哈哈b哈哈c"), + concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); + assertEquals( + fromString("a哈哈c"), + concatWs(sep, fromString("a"), null, fromString("c"))); + assertEquals( + fromString("a"), + concatWs(sep, fromString("a"), null, null)); + assertEquals( + EMPTY_UTF8, + concatWs(sep, null, null, null)); + assertEquals( + fromString("数据哈哈砖头"), + concatWs(sep, fromString("数据"), fromString("砖头"))); + } + + @Test + public void contains() { + assertTrue(EMPTY_UTF8.contains(EMPTY_UTF8)); + assertTrue(fromString("hello").contains(fromString("ello"))); + assertFalse(fromString("hello").contains(fromString("vello"))); + assertFalse(fromString("hello").contains(fromString("hellooo"))); + assertTrue(fromString("大千世界").contains(fromString("千世界"))); + assertFalse(fromString("大千世界").contains(fromString("世千"))); + assertFalse(fromString("大千世界").contains(fromString("大千世界好"))); + } + + @Test + public void startsWith() { + assertTrue(EMPTY_UTF8.startsWith(EMPTY_UTF8)); + assertTrue(fromString("hello").startsWith(fromString("hell"))); + assertFalse(fromString("hello").startsWith(fromString("ell"))); + assertFalse(fromString("hello").startsWith(fromString("hellooo"))); + assertTrue(fromString("数据砖头").startsWith(fromString("数据"))); + assertFalse(fromString("大千世界").startsWith(fromString("千"))); + assertFalse(fromString("大千世界").startsWith(fromString("大千世界好"))); + } + + @Test + public void endsWith() { + assertTrue(EMPTY_UTF8.endsWith(EMPTY_UTF8)); + assertTrue(fromString("hello").endsWith(fromString("ello"))); + assertFalse(fromString("hello").endsWith(fromString("ellov"))); + assertFalse(fromString("hello").endsWith(fromString("hhhello"))); + assertTrue(fromString("大千世界").endsWith(fromString("世界"))); + assertFalse(fromString("大千世界").endsWith(fromString("世"))); + assertFalse(fromString("数据砖头").endsWith(fromString("我的数据砖头"))); + } + + @Test + public void substring() { + assertEquals(EMPTY_UTF8, fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(EMPTY_UTF8, fromString(" ").trim()); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)); + assertEquals(-1, EMPTY_UTF8.indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(EMPTY_UTF8, 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void substring_index() { + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), 3)); + assertEquals(fromString("www.apache"), + fromString("www.apache.org").subStringIndex(fromString("."), 2)); + assertEquals(fromString("www"), + fromString("www.apache.org").subStringIndex(fromString("."), 1)); + assertEquals(fromString(""), + fromString("www.apache.org").subStringIndex(fromString("."), 0)); + assertEquals(fromString("org"), + fromString("www.apache.org").subStringIndex(fromString("."), -1)); + assertEquals(fromString("apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), -2)); + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), -3)); + // str is empty string + assertEquals(fromString(""), + fromString("").subStringIndex(fromString("."), 1)); + // empty string delim + assertEquals(fromString(""), + fromString("www.apache.org").subStringIndex(fromString(""), 1)); + // delim does not exist in str + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("#"), 2)); + // delim is 2 chars + assertEquals(fromString("www||apache"), + fromString("www||apache||org").subStringIndex(fromString("||"), 2)); + assertEquals(fromString("apache||org"), + fromString("www||apache||org").subStringIndex(fromString("||"), -2)); + // non ascii chars + assertEquals(fromString("大千世界大"), + fromString("大千世界大千世界").subStringIndex(fromString("千"), 2)); + // overlapped delim + assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3)); + assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(EMPTY_UTF8, fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); + assertEquals( + fromString("孙行者孙行者孙行数据砖头"), + fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals( + fromString("数据砖头孙行者孙行者孙行"), + fromString("数据砖头").rpad(12, fromString("孙行者"))); + + assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, fromString("孙行者"))); + assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, EMPTY_UTF8)); + assertEquals(fromString("数据砖头"), fromString("数据砖头").lpad(5, EMPTY_UTF8)); + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, EMPTY_UTF8)); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.lpad(3, EMPTY_UTF8)); + + assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, fromString("孙行者"))); + assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, EMPTY_UTF8)); + assertEquals(fromString("数据砖头"), fromString("数据砖头").rpad(5, EMPTY_UTF8)); + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, EMPTY_UTF8)); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.rpad(3, EMPTY_UTF8)); + } + + @Test + public void substringSQL() { + UTF8String e = fromString("example"); + assertEquals(e.substringSQL(0, 2), fromString("ex")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 7), fromString("example")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 100), fromString("example")); + assertEquals(e.substringSQL(1, 100), fromString("example")); + assertEquals(e.substringSQL(2, 2), fromString("xa")); + assertEquals(e.substringSQL(1, 6), fromString("exampl")); + assertEquals(e.substringSQL(2, 100), fromString("xample")); + assertEquals(e.substringSQL(0, 0), fromString("")); + assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); + assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + } + + @Test + public void split() { + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + } + + @Test + public void levenshteinDistance() { + assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0); + assertEquals(EMPTY_UTF8.levenshteinDistance(fromString("a")), 1); + assertEquals(fromString("aaapppp").levenshteinDistance(EMPTY_UTF8), 7); + assertEquals(fromString("frog").levenshteinDistance(fromString("fog")), 1); + assertEquals(fromString("fly").levenshteinDistance(fromString("ant")),3); + assertEquals(fromString("elephant").levenshteinDistance(fromString("hippo")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("elephant")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("zzzzzzzz")), 8); + assertEquals(fromString("hello").levenshteinDistance(fromString("hallo")),1); + assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); + } + + @Test + public void translate() { + assertEquals( + fromString("1a2s3ae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '1', + 'n', '2', + 'l', '3', + 't', '\0' + ))); + assertEquals( + fromString("translate"), + fromString("translate").translate(new HashMap())); + assertEquals( + fromString("asae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '\0', + 'n', '\0', + 'l', '\0', + 't', '\0' + ))); + assertEquals( + fromString("aa世b"), + fromString("花花世界").translate(ImmutableMap.of( + '花', 'a', + '界', 'b' + ))); + } + + @Test + public void createBlankString() { + assertEquals(fromString(" "), blankString(1)); + assertEquals(fromString(" "), blankString(2)); + assertEquals(fromString(" "), blankString(3)); + assertEquals(fromString(""), blankString(0)); + } + + @Test + public void findInSet() { + assertEquals(fromString("ab").findInSet(fromString("ab")), 1); + assertEquals(fromString("a,b").findInSet(fromString("b")), 2); + assertEquals(fromString("abc,b,ab,c,def").findInSet(fromString("ab")), 3); + assertEquals(fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab")), 1); + assertEquals(fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab")), 4); + assertEquals(fromString(",ab,abc,b,ab,c,def").findInSet(fromString("")), 1); + assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("ab")), 4); + assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("def")), 6); + } + + @Test + public void soundex() { + assertEquals(fromString("Robert").soundex(), fromString("R163")); + assertEquals(fromString("Rupert").soundex(), fromString("R163")); + assertEquals(fromString("Rubin").soundex(), fromString("R150")); + assertEquals(fromString("Ashcraft").soundex(), fromString("A261")); + assertEquals(fromString("Ashcroft").soundex(), fromString("A261")); + assertEquals(fromString("Burroughs").soundex(), fromString("B620")); + assertEquals(fromString("Burrows").soundex(), fromString("B620")); + assertEquals(fromString("Ekzampul").soundex(), fromString("E251")); + assertEquals(fromString("Example").soundex(), fromString("E251")); + assertEquals(fromString("Ellery").soundex(), fromString("E460")); + assertEquals(fromString("Euler").soundex(), fromString("E460")); + assertEquals(fromString("Ghosh").soundex(), fromString("G200")); + assertEquals(fromString("Gauss").soundex(), fromString("G200")); + assertEquals(fromString("Gutierrez").soundex(), fromString("G362")); + assertEquals(fromString("Heilbronn").soundex(), fromString("H416")); + assertEquals(fromString("Hilbert").soundex(), fromString("H416")); + assertEquals(fromString("Jackson").soundex(), fromString("J250")); + assertEquals(fromString("Kant").soundex(), fromString("K530")); + assertEquals(fromString("Knuth").soundex(), fromString("K530")); + assertEquals(fromString("Lee").soundex(), fromString("L000")); + assertEquals(fromString("Lukasiewicz").soundex(), fromString("L222")); + assertEquals(fromString("Lissajous").soundex(), fromString("L222")); + assertEquals(fromString("Ladd").soundex(), fromString("L300")); + assertEquals(fromString("Lloyd").soundex(), fromString("L300")); + assertEquals(fromString("Moses").soundex(), fromString("M220")); + assertEquals(fromString("O'Hara").soundex(), fromString("O600")); + assertEquals(fromString("Pfister").soundex(), fromString("P236")); + assertEquals(fromString("Rubin").soundex(), fromString("R150")); + assertEquals(fromString("Robert").soundex(), fromString("R163")); + assertEquals(fromString("Rupert").soundex(), fromString("R163")); + assertEquals(fromString("Soundex").soundex(), fromString("S532")); + assertEquals(fromString("Sownteks").soundex(), fromString("S532")); + assertEquals(fromString("Tymczak").soundex(), fromString("T522")); + assertEquals(fromString("VanDeusen").soundex(), fromString("V532")); + assertEquals(fromString("Washington").soundex(), fromString("W252")); + assertEquals(fromString("Wheaton").soundex(), fromString("W350")); + + assertEquals(fromString("a").soundex(), fromString("A000")); + assertEquals(fromString("ab").soundex(), fromString("A100")); + assertEquals(fromString("abc").soundex(), fromString("A120")); + assertEquals(fromString("abcd").soundex(), fromString("A123")); + assertEquals(fromString("").soundex(), fromString("")); + assertEquals(fromString("123").soundex(), fromString("123")); + assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); + } +} diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala new file mode 100644 index 000000000000..12a002befa0a --- /dev/null +++ b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types + +import org.apache.commons.lang3.StringUtils + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.prop.GeneratorDrivenPropertyChecks +// scalastyle:off +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8} + +/** + * This TestSuite utilize ScalaCheck to generate randomized inputs for UTF8String testing. + */ +class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenPropertyChecks with Matchers { +// scalastyle:on + + test("toString") { + forAll { (s: String) => + assert(toUTF8(s).toString() === s) + } + } + + test("numChars") { + forAll { (s: String) => + assert(toUTF8(s).numChars() === s.length) + } + } + + test("startsWith") { + forAll { (s: String) => + val utf8 = toUTF8(s) + assert(utf8.startsWith(utf8)) + for (i <- 1 to s.length) { + assert(utf8.startsWith(toUTF8(s.dropRight(i)))) + } + } + } + + test("endsWith") { + forAll { (s: String) => + val utf8 = toUTF8(s) + assert(utf8.endsWith(utf8)) + for (i <- 1 to s.length) { + assert(utf8.endsWith(toUTF8(s.drop(i)))) + } + } + } + + test("toUpperCase") { + forAll { (s: String) => + assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase)) + } + } + + test("toLowerCase") { + forAll { (s: String) => + assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase)) + } + } + + test("compare") { + forAll { (s1: String, s2: String) => + assert(Math.signum(toUTF8(s1).compareTo(toUTF8(s2))) === Math.signum(s1.compareTo(s2))) + } + } + + test("substring") { + forAll { (s: String) => + for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) { + assert(toUTF8(s).substring(start, end).toString === s.substring(start, end)) + } + } + } + + test("contains") { + forAll { (s: String) => + for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) { + val substring = s.substring(start, end) + assert(toUTF8(s).contains(toUTF8(substring)) === s.contains(substring)) + } + } + } + + val whitespaceChar: Gen[Char] = Gen.choose(0x00, 0x20).map(_.toChar) + val whitespaceString: Gen[String] = Gen.listOf(whitespaceChar).map(_.mkString) + val randomString: Gen[String] = Arbitrary.arbString.arbitrary + + test("trim, trimLeft, trimRight") { + // lTrim and rTrim are both modified from java.lang.String.trim + def lTrim(s: String): String = { + var st = 0 + val array: Array[Char] = s.toCharArray + while ((st < s.length) && (array(st) <= ' ')) { + st += 1 + } + if (st > 0) s.substring(st, s.length) else s + } + def rTrim(s: String): String = { + var len = s.length + val array: Array[Char] = s.toCharArray + while ((len > 0) && (array(len - 1) <= ' ')) { + len -= 1 + } + if (len < s.length) s.substring(0, len) else s + } + + forAll( + whitespaceString, + randomString, + whitespaceString + ) { (start: String, middle: String, end: String) => + val s = start + middle + end + assert(toUTF8(s).trim() === toUTF8(s.trim())) + assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s))) + assert(toUTF8(s).trimRight() === toUTF8(rTrim(s))) + } + } + + test("reverse") { + forAll { (s: String) => + assert(toUTF8(s).reverse === toUTF8(s.reverse)) + } + } + + test("indexOf") { + forAll { (s: String) => + for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) { + val substring = s.substring(start, end) + assert(toUTF8(s).indexOf(toUTF8(substring), 0) === s.indexOf(substring)) + } + } + } + + val randomInt = Gen.choose(-100, 100) + + test("repeat") { + def repeat(str: String, times: Int): String = { + if (times > 0) str * times else "" + } + // ScalaCheck always generating too large repeat times which might hang the test forever. + forAll(randomString, randomInt) { (s: String, times: Int) => + assert(toUTF8(s).repeat(times) === toUTF8(repeat(s, times))) + } + } + + test("lpad, rpad") { + def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { + if (length <= 0) return "" + if (length <= origin.length) { + if (length <= 0) "" else origin.substring(0, length) + } else { + if (pad.length == 0) return origin + val toPad = length - origin.length + val partPad = if (toPad % pad.length == 0) "" else pad.substring(0, toPad % pad.length) + if (isLPad) { + pad * (toPad / pad.length) + partPad + origin + } else { + origin + pad * (toPad / pad.length) + partPad + } + } + } + + forAll ( + randomString, + randomString, + randomInt + ) { (s: String, pad: String, length: Int) => + assert(toUTF8(s).lpad(length, toUTF8(pad)) === + toUTF8(padding(s, pad, length, true))) + assert(toUTF8(s).rpad(length, toUTF8(pad)) === + toUTF8(padding(s, pad, length, false))) + } + } + + val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString)) + + test("concat") { + def concat(orgin: Seq[String]): String = + if (orgin.exists(_ == null)) null else orgin.mkString + + forAll { (inputs: Seq[String]) => + assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString)) + } + forAll (nullalbeSeq) { (inputs: Seq[String]) => + assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(concat(inputs))) + } + } + + test("concatWs") { + def concatWs(sep: String, inputs: Seq[String]): String = { + if (sep == null) return null + inputs.filter(_ != null).mkString(sep) + } + + forAll { (sep: String, inputs: Seq[String]) => + assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) === + toUTF8(inputs.mkString(sep))) + } + forAll(randomString, nullalbeSeq) {(sep: String, inputs: Seq[String]) => + assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) === + toUTF8(concatWs(sep, inputs))) + } + } + + // TODO: enable this when we find a proper way to generate valid patterns + ignore("split") { + forAll { (s: String, pattern: String, limit: Int) => + assert(toUTF8(s).split(toUTF8(pattern), limit) === + s.split(pattern, limit).map(toUTF8(_))) + } + } + + test("levenshteinDistance") { + forAll { (one: String, another: String) => + assert(toUTF8(one).levenshteinDistance(toUTF8(another)) === + StringUtils.getLevenshteinDistance(one, another)) + } + } + + test("hashCode") { + forAll { (s: String) => + assert(toUTF8(s).hashCode() === toUTF8(s).hashCode()) + } + } + + test("equals") { + forAll { (one: String, another: String) => + assert(toUTF8(one).equals(toUTF8(another)) === one.equals(another)) + } + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 65344aa8738e..f6737695307a 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.5.0-SNAPSHOT ../pom.xml @@ -38,6 +38,19 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-network-yarn_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api @@ -86,6 +99,27 @@ + + + org.eclipse.jetty.orbit + javax.servlet.jsp + 2.2.0.v201112011158 + test + + + org.eclipse.jetty.orbit + javax.servlet.jsp.jstl + 1.2.0.v201105211821 + test + + + + org.apache.hadoop hadoop-yarn-server-tests @@ -94,62 +128,38 @@ org.mockito - mockito-all + mockito-core test + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + test + + + com.sun.jersey + jersey-json + test + + + com.sun.jersey + jersey-server + test + - - - - hadoop-2.2 - - 1.9 - - - - org.mortbay.jetty - jetty - 6.1.26 - - - org.mortbay.jetty - servlet-api - - - test - - - com.sun.jersey - jersey-core - ${jersey.version} - test - - - com.sun.jersey - jersey-json - ${jersey.version} - test - - - stax - stax-api - - - - - com.sun.jersey - jersey-server - ${jersey.version} - test - - - - - target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala new file mode 100644 index 000000000000..56e4741b9387 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.security.PrivilegedExceptionAction +import java.util.concurrent.{Executors, TimeUnit} + +import scala.language.postfixOps + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.deploy.SparkHadoopUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.ThreadUtils + +/* + * The following methods are primarily meant to make sure long-running apps like Spark + * Streaming apps can run without interruption while writing to secure HDFS. The + * scheduleLoginFromKeytab method is called on the driver when the + * CoarseGrainedScheduledBackend starts up. This method wakes up a thread that logs into the KDC + * once 75% of the renewal interval of the original delegation tokens used for the container + * has elapsed. It then creates new delegation tokens and writes them to HDFS in a + * pre-specified location - the prefix of which is specified in the sparkConf by + * spark.yarn.credentials.file (so the file(s) would be named c-1, c-2 etc. - each update goes + * to a new file, with a monotonically increasing suffix). After this, the credentials are + * updated once 75% of the new tokens renewal interval has elapsed. + * + * On the executor side, the updateCredentialsIfRequired method is called once 80% of the + * validity of the original tokens has elapsed. At that time the executor finds the + * credentials file with the latest timestamp and checks if it has read those credentials + * before (by keeping track of the suffix of the last file it read). If a new file has + * appeared, it will read the credentials and update the currently running UGI with it. This + * process happens again once 80% of the validity of this has expired. + */ +private[yarn] class AMDelegationTokenRenewer( + sparkConf: SparkConf, + hadoopConf: Configuration) extends Logging { + + private var lastCredentialsFileSuffix = 0 + + private val delegationTokenRenewer = + Executors.newSingleThreadScheduledExecutor( + ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread")) + + private val hadoopUtil = YarnSparkHadoopUtil.get + + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val freshHadoopConf = + hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) + + /** + * Schedule a login from the keytab and principal set using the --principal and --keytab + * arguments to spark-submit. This login happens only when the credentials of the current user + * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from + * SparkConf to do the login. This method is a no-op in non-YARN mode. + * + */ + private[spark] def scheduleLoginFromKeytab(): Unit = { + val principal = sparkConf.get("spark.yarn.principal") + val keytab = sparkConf.get("spark.yarn.keytab") + + /** + * Schedule re-login and creation of new tokens. If tokens have already expired, this method + * will synchronously create new ones. + */ + def scheduleRenewal(runnable: Runnable): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials + val renewalInterval = hadoopUtil.getTimeFromNowToRenewal(sparkConf, 0.75, credentials) + // Run now! + if (renewalInterval <= 0) { + logInfo("HDFS tokens have expired, creating new tokens now.") + runnable.run() + } else { + logInfo(s"Scheduling login from keytab in $renewalInterval millis.") + delegationTokenRenewer.schedule(runnable, renewalInterval, TimeUnit.MILLISECONDS) + } + } + + // This thread periodically runs on the driver to update the delegation tokens on HDFS. + val driverTokenRenewerRunnable = + new Runnable { + override def run(): Unit = { + try { + writeNewTokensToHDFS(principal, keytab) + cleanupOldFiles() + } catch { + case e: Exception => + // Log the error and try to write new tokens back in an hour + logWarning("Failed to write out new credentials to HDFS, will try again in an " + + "hour! If this happens too often tasks will fail.", e) + delegationTokenRenewer.schedule(this, 1, TimeUnit.HOURS) + return + } + scheduleRenewal(this) + } + } + // Schedule update of credentials. This handles the case of updating the tokens right now + // as well, since the renenwal interval will be 0, and the thread will get scheduled + // immediately. + scheduleRenewal(driverTokenRenewerRunnable) + } + + // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At + // least numFilesToKeep files are kept for safety + private def cleanupOldFiles(): Unit = { + import scala.concurrent.duration._ + try { + val remoteFs = FileSystem.get(freshHadoopConf) + val credentialsPath = new Path(credentialsFile) + val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis + hadoopUtil.listFilesSorted( + remoteFs, credentialsPath.getParent, + credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .dropRight(numFilesToKeep) + .takeWhile(_.getModificationTime < thresholdTime) + .foreach(x => remoteFs.delete(x.getPath, true)) + } catch { + // Such errors are not fatal, so don't throw. Make sure they are logged though + case e: Exception => + logWarning("Error while attempting to cleanup old tokens. If you are seeing many such " + + "warnings there may be an issue with your HDFS cluster.", e) + } + } + + private def writeNewTokensToHDFS(principal: String, keytab: String): Unit = { + // Keytab is copied by YARN to the working directory of the AM, so full path is + // not needed. + + // HACK: + // HDFS will not issue new delegation tokens, if the Credentials object + // passed in already has tokens for that FS even if the tokens are expired (it really only + // checks if there are tokens for the service, and not if they are valid). So the only real + // way to get new tokens is to make sure a different Credentials object is used each time to + // get new tokens and then the new tokens are copied over the the current user's Credentials. + // So: + // - we login as a different user and get the UGI + // - use that UGI to get the tokens (see doAs block below) + // - copy the tokens over to the current user's credentials (this will overwrite the tokens + // in the current user's Credentials object for this FS). + // The login to KDC happens each time new tokens are required, but this is rare enough to not + // have to worry about (like once every day or so). This makes this code clearer than having + // to login and then relogin every time (the HDFS API may not relogin since we don't use this + // UGI directly for HDFS communication. + logInfo(s"Attempting to login to KDC using principal: $principal") + val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + val tempCreds = keytabLoggedInUGI.getCredentials + val credentialsPath = new Path(credentialsFile) + val dst = credentialsPath.getParent + keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { + // Get a copy of the credentials + override def run(): Void = { + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst + hadoopUtil.obtainTokensForNamenodes(nns, freshHadoopConf, tempCreds) + null + } + }) + // Add the temp credentials back to the original ones. + UserGroupInformation.getCurrentUser.addCredentials(tempCreds) + val remoteFs = FileSystem.get(freshHadoopConf) + // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM + // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file + // and update the lastCredentialsFileSuffix. + if (lastCredentialsFileSuffix == 0) { + hadoopUtil.listFilesSorted( + remoteFs, credentialsPath.getParent, + credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .lastOption.foreach { status => + lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) + } + } + val nextSuffix = lastCredentialsFileSuffix + 1 + val tokenPathStr = + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + val tokenPath = new Path(tokenPathStr) + val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + logInfo("Writing out delegation tokens to " + tempTokenPath.toString) + val credentials = UserGroupInformation.getCurrentUser.getCredentials + credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) + logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") + remoteFs.rename(tempTokenPath, tokenPath) + logInfo("Delegation token file rename complete.") + lastCredentialsFileSuffix = nextSuffix + } + + def stop(): Unit = { + delegationTokenRenewer.shutdown() + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 37e98e01fddf..991b5cec00bd 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,26 +19,24 @@ package org.apache.spark.deploy.yarn import scala.util.control.NonFatal -import java.io.IOException +import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException -import java.net.Socket +import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} -import org.apache.spark.SparkException -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.rpc._ +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, + SparkException, SparkUserAppException} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -48,17 +46,26 @@ private[spark] class ApplicationMaster( client: YarnRMClient) extends Logging { + // Load the properties file with the Spark configuration and set entries as system properties, + // so that user code run inside the AM also has access to them. + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sys.props(k) = v + } + } + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. private val sparkConf = new SparkConf() private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf) .asInstanceOf[YarnConfiguration] - private val isDriver = args.userClass != null + private val isClusterMode = args.userClass != null // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + sparkConf.getInt("spark.yarn.max.worker.failures", + math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -67,21 +74,24 @@ private[spark] class ApplicationMaster( @volatile private var finalMsg: String = "" @volatile private var userClassThread: Thread = _ - private var reporterThread: Thread = _ - private var allocator: YarnAllocator = _ + @volatile private var reporterThread: Thread = _ + @volatile private var allocator: YarnAllocator = _ + private val allocatorLock = new Object() // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) + private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + final def run(): Int = { try { val appAttemptId = client.getAttemptId() - if (isDriver) { + if (isClusterMode) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -91,55 +101,57 @@ private[spark] class ApplicationMaster( // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + + // Propagate the attempt if, so that in case of event logging, + // different attempt's logs gets created in different directory + System.setProperty("spark.yarn.app.attemptId", appAttemptId.getAttemptId().toString()) } logInfo("ApplicationAttemptId: " + appAttemptId) val fs = FileSystem.get(yarnConf) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } - val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - - if (!finished) { - // This happens when the user application calls System.exit(). We have the choice - // of either failing or succeeding at this point. We report success to avoid - // retrying applications that have succeeded (System.exit(0)), which means that - // applications that explicitly exit with a non-zero status will also show up as - // succeeded in the RM UI. - finish(finalStatus, - ApplicationMaster.EXIT_SUCCESS, - "Shutdown hook called before final status was reported.") - } - if (!unregistered) { - // we only want to unregister if we don't want the RM to retry - if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { - unregister(finalStatus, finalMsg) - cleanupStagingDir(fs) - } + // This shutdown hook should run *after* the SparkContext is shut down. + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => + val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // This happens when the user application calls System.exit(). We have the choice + // of either failing or succeeding at this point. We report success to avoid + // retrying applications that have succeeded (System.exit(0)), which means that + // applications that explicitly exit with a non-zero status will also show up as + // succeeded in the RM UI. + finish(finalStatus, + ApplicationMaster.EXIT_SUCCESS, + "Shutdown hook called before final status was reported.") + } + + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir(fs) } } } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the // Hadoop UGI. This has to happen before the startUserApplication which does a // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) - if (isDriver) { + // If the credentials file config is present, we must periodically renew tokens. So create + // a new AMDelegationTokenRenewer + if (sparkConf.contains("spark.yarn.credentials.file")) { + delegationTokenRenewerOption = Some(new AMDelegationTokenRenewer(sparkConf, yarnConf)) + // If a principal and keytab have been set, use that to create new credentials for executors + // periodically + delegationTokenRenewerOption.foreach(_.scheduleLoginFromKeytab()) + } + + if (isClusterMode) { runDriver(securityMgr) } else { runExecutorLauncher(securityMgr) @@ -150,7 +162,7 @@ private[spark] class ApplicationMaster( logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e.getMessage()) + "Uncaught exception: " + e) } exitCode } @@ -161,8 +173,8 @@ private[spark] class ApplicationMaster( * status to SUCCEEDED in cluster mode to handle if the user calls System.exit * from the application code. */ - final def getDefaultFinalStatus() = { - if (isDriver) { + final def getDefaultFinalStatus(): FinalApplicationStatus = { + if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { FinalApplicationStatus.UNDEFINED @@ -174,31 +186,36 @@ private[spark] class ApplicationMaster( * This means the ResourceManager will not retry the application attempt on your behalf if * a failure occurred. */ - final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { - if (!unregistered) { - logInfo(s"Unregistering ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - unregistered = true - client.unregister(status, Option(diagnostics).getOrElse("")) + final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { + synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) + } } } - final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { - if (!finished) { - val inShutdown = Utils.inShutdown() - logInfo(s"Final app status: ${status}, exitCode: ${code}" + - Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status - finalMsg = msg - finished = true - if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { - logDebug("shutting down reporter thread") - reporterThread.interrupt() - } - if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { - logDebug("shutting down user thread") - userClassThread.interrupt() + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { + synchronized { + if (!finished) { + val inShutdown = ShutdownHookManager.inShutdown() + logInfo(s"Final app status: $status, exitCode: $code" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code + finalStatus = status + finalMsg = msg + finished = true + if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() + } + if (!inShutdown) delegationTokenRenewerOption.foreach(_.stop()) } } } @@ -214,17 +231,30 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM( + _rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + uiAddress: String, + securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() + val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = sparkConf.getOption("spark.yarn.historyServer.address") - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") - allocator = client.register(yarnConf, - if (sc != null) sc.getConf else sparkConf, + val _sparkConf = if (sc != null) sc.getConf else sparkConf + val driverUrl = _rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + allocator = client.register(driverUrl, + driverRef, + yarnConf, + _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, @@ -235,23 +265,24 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. + * + * @return A reference to the driver's RPC endpoint. */ - private def runAMActor( + private def runAMEndpoint( host: String, port: String, - isDriver: Boolean): Unit = { - - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), + isClusterMode: Boolean): RpcEndpointRef = { + val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, - host, - port, - YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isDriver)), name = "YarnAM") + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) + driverEndpoint } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -268,22 +299,22 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { - actorSystem = sc.env.actorSystem - runAMActor( + rpcEnv = sc.env.rpcEnv + val driverRef = runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), - isDriver = true) - registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + isClusterMode = true) + registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 - waitForSparkDriver() + val port = sparkConf.getInt("spark.yarn.am.port", 0) + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) + val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -294,11 +325,14 @@ private[spark] class ApplicationMaster( val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + val heartbeatInterval = math.max(0, math.min(expiryInterval / 2, + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) + + // we want to check more frequently for pending containers + val initialAllocationInterval = math.min(heartbeatInterval, + sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) - // must be <= expiryInterval / 2. - val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + var nextAllocationInterval = initialAllocationInterval // The number of failures in a row until Reporter thread give up val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) @@ -324,15 +358,29 @@ private[spark] class ApplicationMaster( if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + - s"${failureCount} time(s) from Reporter thread.") - + s"$failureCount time(s) from Reporter thread.") } else { - logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) + logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } } } try { - Thread.sleep(interval) + val numPendingAllocate = allocator.getNumPendingAllocate + val sleepInterval = + if (numPendingAllocate > 0) { + val currentAllocationInterval = + math.min(heartbeatInterval, nextAllocationInterval) + nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow + currentAllocationInterval + } else { + nextAllocationInterval = initialAllocationInterval + heartbeatInterval + } + logDebug(s"Number of pending allocations is $numPendingAllocate. " + + s"Sleeping for $sleepInterval.") + allocatorLock.synchronized { + allocatorLock.wait(sleepInterval) + } } catch { case e: InterruptedException => } @@ -343,7 +391,8 @@ private[spark] class ApplicationMaster( t.setDaemon(true) t.setName("Reporter") t.start() - logInfo("Started progress reporter thread - sleep time : " + interval) + logInfo(s"Started progress reporter thread with (heartbeat : $heartbeatInterval, " + + s"initial allocation : $initialAllocationInterval) intervals") t } @@ -372,13 +421,7 @@ private[spark] class ApplicationMaster( private def waitForSparkContextInitialized(): SparkContext = { logInfo("Waiting for spark context initialization") sparkContextRef.synchronized { - val waitTries = sparkConf.getOption("spark.yarn.applicationMaster.waitTries") - .map(_.toLong * 10000L) - if (waitTries.isDefined) { - logWarning( - "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime") - } - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", waitTries.getOrElse(100000L)) + val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { @@ -395,7 +438,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkDriver(): Unit = { + private def waitForSparkDriver(): RpcEndpointRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) @@ -403,8 +446,8 @@ private[spark] class ApplicationMaster( // Spark driver should already be up since it launched us, but we don't want to // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", 100000L) - val deadline = System.currentTimeMillis + totalWaitTime + val totalWaitTimeMs = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") + val deadline = System.currentTimeMillis + totalWaitTimeMs while (!driverUp && !finished && System.currentTimeMillis < deadline) { try { @@ -427,7 +470,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isDriver = false) + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -435,11 +478,11 @@ private[spark] class ApplicationMaster( val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" val params = client.getAmIpFilterParams(yarnConf, proxyBase) - if (isDriver) { + if (isClusterMode) { System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -452,20 +495,34 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) + + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + val userClassLoader = + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } + + var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - System.setProperty("spark.submit.pyFiles", - PythonRunner.formatPaths(args.pyFiles).mkString(",")) + // When running pyspark, the app is run using PythonRunner. The second argument is the list + // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty. + userArgs = Seq(args.primaryPyFile, "") ++ userArgs + } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here } - val mainMethod = Class.forName(args.userClass, false, - Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) + val mainMethod = userClassLoader.loadClass(args.userClass) + .getMethod("main", classOf[Array[String]]) val userThread = new Thread { override def run() { try { - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) + mainMethod.invoke(null, userArgs.toArray) finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) logDebug("Done running users class") } catch { @@ -473,61 +530,58 @@ private[spark] class ApplicationMaster( e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class - case e: Exception => + case SparkUserAppException(exitCode) => + val msg = s"User application exited with status $exitCode" + logError(msg) + finish(FinalApplicationStatus.FAILED, exitCode, msg) + case cause: Throwable => + logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + e.getMessage) - // re-throw to get it logged - throw e + "User class threw exception: " + cause) } } } } + userThread.setContextClassLoader(userClassLoader) userThread.setName("Driver") userThread.start() userThread } /** - * An actor that communicates with the driver's scheduler backend. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isDriver: Boolean) extends Actor { - var driver: ActorSelection = _ - - override def preStart() = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - // In cluster mode, the AM can directly monitor the driver status instead - // of trying to deduce it from the lifecycle of the driver's actor - if (!isDriver) { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isDriver) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) + + } + override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } - case RequestExecutors(requestedTotal) => - logInfo(s"Driver requested a total number of $requestedTotal executor(s).") + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => Option(allocator) match { - case Some(a) => a.requestTotalExecutors(requestedTotal) - case None => logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => + allocatorLock.synchronized { + if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, + localityAwareTasks, hostToLocalTaskCount)) { + allocatorLock.notifyAll() + } + } + + case None => + logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -535,7 +589,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } @@ -543,8 +606,6 @@ private[spark] class ApplicationMaster( object ApplicationMaster extends Logging { - val SHUTDOWN_HOOK_PRIORITY: Int = 30 - // exit codes for different causes, no reason behind the values private val EXIT_SUCCESS = 0 private val EXIT_UNCAUGHT_EXCEPTION = 10 @@ -556,7 +617,7 @@ object ApplicationMaster extends Logging { private var master: ApplicationMaster = _ - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => @@ -565,11 +626,11 @@ object ApplicationMaster extends Logging { } } - private[spark] def sparkContextInitialized(sc: SparkContext) = { + private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { master.sparkContextInitialized(sc) } - private[spark] def sparkContextStopped(sc: SparkContext) = { + private[spark] def sparkContextStopped(sc: SparkContext): Boolean = { master.sparkContextStopped(sc) } @@ -581,7 +642,7 @@ object ApplicationMaster extends Logging { */ object ExecutorLauncher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { ApplicationMaster.main(args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3aae..b08412414aa1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -25,11 +25,11 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var primaryPyFile: String = null - var pyFiles: String = null - var userArgs: Seq[String] = Seq[String]() + var primaryRFile: String = null + var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS + var propertiesFile: String = null parseArgs(args.toList) @@ -54,18 +54,14 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryPyFile = value args = tail - case ("--py-files") :: value :: tail => - pyFiles = value + case ("--primary-r-file") :: value :: tail => + primaryRFile = value args = tail case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail @@ -74,15 +70,27 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--properties-file") :: value :: tail => + propertiesFile = value + args = tail + case _ => printUsageAndExit(1, args) } } + if (primaryPyFile != null && primaryRFile != null) { + // scalastyle:off println + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + // scalastyle:on println + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + // scalastyle:off println if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) } @@ -92,6 +100,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file + | --primary-r-file A main R file | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. @@ -100,6 +109,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) """.stripMargin) + // scalastyle:on println System.exit(exitCode) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 91e8574e94e2..e9a02baafd28 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,22 +17,33 @@ package org.apache.spark.deploy.yarn +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, + OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer +import java.security.PrivilegedExceptionAction +import java.util.{Properties, UUID} +import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map} +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} +import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import scala.util.control.NonFatal +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects +import com.google.common.io.Files import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission -import org.apache.hadoop.mapred.Master +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.{TokenIdentifier, Token} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -40,10 +51,11 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.util.Utils private[spark] class Client( @@ -61,20 +73,20 @@ private[spark] class Client( private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) - private val credentials = UserGroupInformation.getCurrentUser.getCredentials + private var credentials: Credentials = null private val amMemoryOverhead = args.amMemoryOverhead // MB private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + private var loginFromKeytab = false + private var principal: String = null + private var keytab: String = null - def stop(): Unit = yarnClient.stop() + private val fireAndForget = isClusterMode && + !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) - /* ------------------------------------------------------------------------------------- * - | The following methods have much in common in the stable and alpha versions of Client, | - | but cannot be implemented in the parent trait due to subtle API differences across | - | hadoop versions. | - * ------------------------------------------------------------------------------------- */ + def stop(): Unit = yarnClient.stop() /** * Submit an application running our ApplicationMaster to the ResourceManager. @@ -84,28 +96,59 @@ private[spark] class Client( * available in the alpha API. */ def submitApplication(): ApplicationId = { - yarnClient.init(yarnConf) - yarnClient.start() - - logInfo("Requesting a new application from cluster with %d NodeManagers" - .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) - - // Get a new application from our RM - val newApp = yarnClient.createApplication() - val newAppResponse = newApp.getNewApplicationResponse() - val appId = newAppResponse.getApplicationId() - - // Verify whether the cluster has enough resources for our AM - verifyClusterResources(newAppResponse) - - // Set up the appropriate contexts to launch our AM - val containerContext = createContainerLaunchContext(newAppResponse) - val appContext = createApplicationSubmissionContext(newApp, containerContext) + var appId: ApplicationId = null + try { + // Setup the credentials before doing anything else, + // so we have don't have issues at any point. + setupCredentials() + yarnClient.init(yarnConf) + yarnClient.start() + + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) + + // Get a new application from our RM + val newApp = yarnClient.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + appId = newAppResponse.getApplicationId() + + // Verify whether the cluster has enough resources for our AM + verifyClusterResources(newAppResponse) + + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(newApp, containerContext) + + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + yarnClient.submitApplication(appContext) + appId + } catch { + case e: Throwable => + if (appId != null) { + cleanupStagingDir(appId) + } + throw e + } + } - // Finally, submit and monitor the application - logInfo(s"Submitting application ${appId.getId} to ResourceManager") - yarnClient.submitApplication(appContext) - appId + /** + * Cleanup application staging directory. + */ + private def cleanupStagingDir(appId: ApplicationId): Unit = { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } } /** @@ -120,6 +163,23 @@ private[spark] class Client( appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(containerContext) appContext.setApplicationType("SPARK") + sparkConf.getOption(CONF_SPARK_YARN_APPLICATION_TAGS) + .map(StringUtils.getTrimmedStringCollection(_)) + .filter(!_.isEmpty()) + .foreach { tagCollection => + try { + // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use + // reflection to set it, printing a warning if a tag was specified but the YARN version + // doesn't support it. + val method = appContext.getClass().getMethod( + "setApplicationTags", classOf[java.util.Set[String]]) + method.invoke(appContext, new java.util.HashSet[String](tagCollection)) + } catch { + case e: NoSuchMethodException => + logWarning(s"Ignoring $CONF_SPARK_YARN_APPLICATION_TAGS because this version of " + + "YARN does not support it") + } + } sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) match { case Some(v) => appContext.setMaxAppAttempts(v) case None => logDebug("spark.yarn.maxAppAttempts is not set. " + @@ -160,12 +220,14 @@ private[spark] class Client( val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + - s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, @@ -183,8 +245,7 @@ private[spark] class Client( private[yarn] def copyFileToRemote( destDir: Path, srcPath: Path, - replication: Short, - setPerms: Boolean = false): Path = { + replication: Short): Path = { val destFs = destDir.getFileSystem(hadoopConf) val srcFs = srcPath.getFileSystem(hadoopConf) var destPath = srcPath @@ -193,9 +254,7 @@ private[spark] class Client( logInfo(s"Uploading resource $srcPath -> $destPath") FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf) destFs.setReplication(destPath, replication) - if (setPerms) { - destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) - } + destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) } else { logInfo(s"Source and destination file systems are the same. Not copying $srcPath") } @@ -212,14 +271,22 @@ private[spark] class Client( * This is used for setting up a container launch context for our ApplicationMaster. * Exposed for testing. */ - def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources( + appStagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, LocalResource] = { logInfo("Preparing resources for our AM container") // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val nns = getNameNodesToAccess(sparkConf) + dst - obtainTokensForNamenodes(nns, hadoopConf, credentials) + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst + YarnSparkHadoopUtil.get.obtainTokensForNamenodes(nns, hadoopConf, credentials) + // Used to keep track of URIs added to the distributed cache. If the same URI is added + // multiple times, YARN will fail to launch containers for the app with an internal + // error. + val distributedUris = new HashSet[String] + obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials) + obtainTokenForHBase(sparkConf, hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -236,33 +303,89 @@ private[spark] class Client( "for alternatives.") } + def addDistributedUri(uri: URI): Boolean = { + val uriStr = uri.toString() + if (distributedUris.contains(uriStr)) { + logWarning(s"Resource $uri added multiple times to distributed cache.") + false + } else { + distributedUris += uriStr + true + } + } + + /** + * Distribute a file to the cluster. + * + * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied + * to HDFS (if not already there) and added to the application's distributed cache. + * + * @param path URI of the file to distribute. + * @param resType Type of resource being distributed. + * @param destName Name of the file in the distributed cache. + * @param targetDir Subdirectory where to place the file. + * @param appMasterOnly Whether to distribute only to the AM. + * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the + * localized path for non-local paths, or the input `path` for local paths. + * The localized path will be null if the URI has already been added to the cache. + */ + def distribute( + path: String, + resType: LocalResourceType = LocalResourceType.FILE, + destName: Option[String] = None, + targetDir: Option[String] = None, + appMasterOnly: Boolean = false): (Boolean, String) = { + val localURI = new URI(path.trim()) + if (localURI.getScheme != LOCAL_SCHEME) { + if (addDistributedUri(localURI)) { + val localPath = getQualifiedLocalPath(localURI, hadoopConf) + val linkname = targetDir.map(_ + "/").getOrElse("") + + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource( + destFs, hadoopConf, destPath, localResources, resType, linkname, statCache, + appMasterOnly = appMasterOnly) + (false, linkname) + } else { + (false, null) + } + } else { + (true, path.trim()) + } + } + + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. + if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val (_, localizedPath) = distribute(keytab, + destName = Some(sparkConf.get("spark.yarn.keytab")), + appMasterOnly = true) + require(localizedPath != null, "Keytab file already distributed.") + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. - * Each resource is represented by a 4-tuple of: + * Each resource is represented by a 3-tuple of: * (1) destination resource name, * (2) local path to the resource, - * (3) Spark property key to set if the scheme is not local, and - * (4) whether to set permissions for this resource + * (3) Spark property key to set if the scheme is not local */ List( - (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false), - (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true), - ("log4j.properties", oldLog4jConf.orNull, null, false) - ).foreach { case (destName, _localPath, confKey, setPermissions) => - val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (!localPath.isEmpty()) { - val localURI = new URI(localPath) - if (localURI.getScheme != LOCAL_SCHEME) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication, setPermissions) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) - } else if (confKey != null) { + (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), + (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), + ("log4j.properties", oldLog4jConf.orNull, null) + ).foreach { case (destName, path, confKey) => + if (path != null && !path.trim().isEmpty()) { + val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) + if (isLocal && confKey != null) { + require(localizedPath != null, s"Path $path already distributed.") // If the resource is intended for local use only, handle this downstream // by setting the appropriate property - sparkConf.set(confKey, localPath) + sparkConf.set(confKey, localizedPath) } } } @@ -282,19 +405,10 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { flist.split(',').foreach { file => - val localURI = new URI(file.trim()) - if (localURI.getScheme != LOCAL_SCHEME) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname - } - } else if (addToClasspath) { - // Resource is intended for local use only and should be added to the class path - cachedSecondaryJarLinks += file.trim() + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -303,24 +417,134 @@ private[spark] class Client( sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) } + if (isClusterMode && args.primaryPyFile != null) { + distribute(args.primaryPyFile, appMasterOnly = true) + } + + pySparkArchives.foreach { f => distribute(f) } + + // The python files list needs to be treated especially. All files that are not an + // archive need to be placed in a subdirectory that will be added to PYTHONPATH. + args.pyFiles.foreach { f => + val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None + distribute(f, targetDir = targetDir) + } + + // Distribute an archive with Hadoop and Spark configuration for the AM. + val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(), + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_CONF_DIR), + appMasterOnly = true) + require(confLocalizedPath != null) + localResources } + /** + * Create an archive with the config files for distribution. + * + * These are only used by the AM, since executors will use the configuration object broadcast by + * the driver. The files are zipped and added to the job as an archive, so that YARN will explode + * it when distributing to the AM. This directory is then added to the classpath of the AM + * process, just to make sure that everybody is using the same default config. + * + * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR + * shows up in the classpath before YARN_CONF_DIR. + * + * Currently this makes a shallow copy of the conf directory. If there are cases where a + * Hadoop config directory contains subdirectories, this code will have to be fixed. + * + * The archive also contains some Spark configuration. Namely, it saves the contents of + * SparkConf in a file to be loaded by the AM process. + */ + private def createConfArchive(): File = { + val hadoopConfFiles = new HashMap[String, File]() + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + sys.env.get(envKey).foreach { path => + val dir = new File(path) + if (dir.isDirectory()) { + dir.listFiles().foreach { file => + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { + hadoopConfFiles(file.getName()) = file + } + } + } + } + } + + val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + val confStream = new ZipOutputStream(new FileOutputStream(confArchive)) + + try { + confStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + if (file.canRead()) { + confStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, confStream) + confStream.closeEntry() + } + } + + // Save Spark configuration to a file in the archive. + val props = new Properties() + sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) + val writer = new OutputStreamWriter(confStream, UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + confStream.closeEntry() + } finally { + confStream.close() + } + confArchive + } + + /** + * Get the renewal interval for tokens. + */ + private def getTokenRenewalInterval(stagingDirPath: Path): Long = { + // We cannot use the tokens generated above since those have renewer yarn. Trying to renew + // those will fail with an access control issue. So create new tokens with the logged in + // user as renewer. + val creds = new Credentials() + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath + YarnSparkHadoopUtil.get.obtainTokensForNamenodes( + nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) + val t = creds.getAllTokens.asScala + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + .head + val newExpiration = t.renew(hadoopConf) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal Interval set to $interval") + interval + } + /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + private def setupLaunchEnv( + stagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - populateClasspath(args, yarnConf, sparkConf, env, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, true, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - - // Set the environment variables to be passed on to the executors. - distCacheMgr.setDistFilesEnv(env) - distCacheMgr.setDistArchivesEnv(env) + if (loginFromKeytab) { + val remoteFs = FileSystem.get(hadoopConf) + val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir) + val credentialsFile = "credentials-" + UUID.randomUUID().toString + sparkConf.set( + "spark.yarn.credentials.file", new Path(stagingDirPath, credentialsFile).toString) + logInfo(s"Credentials file set to: $credentialsFile") + val renewalInterval = getTokenRenewalInterval(stagingDirPath) + sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) + } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -337,6 +561,34 @@ private[spark] class Client( env("SPARK_YARN_USER_ENV") = userEnvs } + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH + // of the container processes too. Add all non-.py files directly to PYTHONPATH. + // + // NOTE: the code currently does not handle .py files defined with a "local:" scheme. + val pythonPath = new ListBuffer[String]() + val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + if (pyFiles.nonEmpty) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_PYTHON_DIR) + } + (pySparkArchives ++ pyArchives).foreach { path => + val uri = new URI(path) + if (uri.getScheme != LOCAL_SCHEME) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + new Path(path).getName()) + } else { + pythonPath += uri.getPath() + } + } + + // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. + if (pythonPath.nonEmpty) { + val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + .mkString(YarnSparkHadoopUtil.getClassPathSeparator) + env("PYTHONPATH") = pythonPathStr + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) + } + // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's // SparkContext will not let that set spark* system properties, which is expected behavior for @@ -382,14 +634,24 @@ private[spark] class Client( private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) : ContainerLaunchContext = { logInfo("Setting up container launch context for our AM") - val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(appStagingDir) + val pySparkArchives = + if (sparkConf.getBoolean("spark.yarn.isPython", false)) { + findPySparkArchives() + } else { + Nil + } + val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives) + val localResources = prepareLocalResources(appStagingDir, pySparkArchives) + + // Set the environment variables to be passed on to the executors. + distCacheMgr.setDistFilesEnv(launchEnv) + distCacheMgr.setDistArchivesEnv(launchEnv) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - amContainer.setEnvironment(launchEnv) + amContainer.setLocalResources(localResources.asJava) + amContainer.setEnvironment(launchEnv.asJava) val javaOpts = ListBuffer[String]() @@ -418,29 +680,25 @@ private[spark] class Client( // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:MaxTenuringThreshold=31" + javaOpts += "-XX:SurvivorRatio=8" javaOpts += "-XX:+CMSIncrementalMode" javaOpts += "-XX:+CMSIncrementalPacing" javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // Forward the Spark configuration to the application master / executors. - // TODO: it might be nicer to pass these as an internal environment variable rather than - // as Java options, due to complications with string parsing of nested quotes. - for ((k, v) <- sparkConf.getAll) { - javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") - } - // Include driver-specific java options if we are launching a driver if (isClusterMode) { - sparkConf.getOption("spark.driver.extraJavaOptions") + val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) - .map(Utils.splitCommandString).getOrElse(Seq.empty) - .foreach(opts => javaOpts += opts) + driverOpts.foreach { opts => + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + } val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -458,7 +716,11 @@ private[spark] class Client( val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts) + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + } + + sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -478,35 +740,36 @@ private[spark] class Client( Nil } val primaryPyFile = - if (args.primaryPyFile != null) { - Seq("--primary-py-file", args.primaryPyFile) + if (isClusterMode && args.primaryPyFile != null) { + Seq("--primary-py-file", new Path(args.primaryPyFile).getName()) } else { Nil } - val pyFiles = - if (args.pyFiles != null) { - Seq("--py-files", args.pyFiles) + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) } else { Nil } val amClass = if (isClusterMode) { - Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName + Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName + Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } - if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( @@ -519,10 +782,10 @@ private[spark] class Client( // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList - amContainer.setCommands(printableCommands) + amContainer.setCommands(printableCommands.asJava) logDebug("===============================================================================") - logDebug("Yarn AM launch context:") + logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } @@ -534,13 +797,40 @@ private[spark] class Client( // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) setupSecurityToken(amContainer) UserGroupInformation.getCurrentUser().addCredentials(credentials) amContainer } + def setupCredentials(): Unit = { + loginFromKeytab = args.principal != null || sparkConf.contains("spark.yarn.principal") + if (loginFromKeytab) { + principal = + if (args.principal != null) args.principal else sparkConf.get("spark.yarn.principal") + keytab = { + if (args.keytab != null) { + args.keytab + } else { + sparkConf.getOption("spark.yarn.keytab").orNull + } + } + + require(keytab != null, "Keytab must be specified when principal is specified.") + logInfo("Attempting to login to the Kerberos" + + s" using principal: $principal and keytab: $keytab") + val f = new File(keytab) + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + val keytabFileName = f.getName + "-" + UUID.randomUUID().toString + sparkConf.set("spark.yarn.keytab", keytabFileName) + sparkConf.set("spark.yarn.principal", principal) + } + credentials = UserGroupInformation.getCurrentUser.getCredentials + } + /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, @@ -560,42 +850,35 @@ private[spark] class Client( var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) - val report = getApplicationReport(appId) + val report: ApplicationReport = + try { + getApplicationReport(appId) + } catch { + case e: ApplicationNotFoundException => + logError(s"Application $appId not found.") + return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + case NonFatal(e) => + logError(s"Failed to contact YARN for application $appId.", e) + return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) + } val state = report.getYarnApplicationState if (logApplicationReport) { logInfo(s"Application report for $appId (state: $state)") - val details = Seq[(String, String)]( - ("client token", getClientToken(report)), - ("diagnostics", report.getDiagnostics), - ("ApplicationMaster host", report.getHost), - ("ApplicationMaster RPC port", report.getRpcPort.toString), - ("queue", report.getQueue), - ("start time", report.getStartTime.toString), - ("final status", report.getFinalApplicationStatus.toString), - ("tracking URL", report.getTrackingUrl), - ("user", report.getUser) - ) - - // Use more loggable format if value is null or empty - val formattedDetails = details - .map { case (k, v) => - val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") - s"\n\t $k: $newValue" } - .mkString("") // If DEBUG is enabled, log report details every iteration // Otherwise, log them every time the application changes state if (log.isDebugEnabled) { - logDebug(formattedDetails) + logDebug(formatReportDetails(report)) } else if (lastState != state) { - logInfo(formattedDetails) + logInfo(formatReportDetails(report)) } } if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + cleanupStagingDir(appId) return (state, report.getFinalApplicationStatus) } @@ -610,32 +893,81 @@ private[spark] class Client( throw new SparkException("While loop is depleted! This should never happen...") } + private def formatReportDetails(report: ApplicationReport): String = { + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + details.map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" + }.mkString("") + } + /** - * Submit an application to the ResourceManager and monitor its state. - * This continues until the application has exited for any reason. + * Submit an application to the ResourceManager. + * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive + * reporting the application's status until the application has exited for any reason. + * Otherwise, the client process will exit after submission. * If the application finishes with a failed, killed, or undefined status, * throw an appropriate SparkException. */ def run(): Unit = { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication()) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { - throw new SparkException("Application finished with failed status") - } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { - throw new SparkException("Application is killed") - } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { - throw new SparkException("The final status of application is undefined") + val appId = submitApplication() + if (fireAndForget) { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + logInfo(s"Application report for $appId (state: $state)") + logInfo(formatReportDetails(report)) + if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + throw new SparkException(s"Application $appId finished with status: $state") + } + } else { + val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) + if (yarnApplicationState == YarnApplicationState.FAILED || + finalApplicationStatus == FinalApplicationStatus.FAILED) { + throw new SparkException(s"Application $appId finished with failed status") + } + if (yarnApplicationState == YarnApplicationState.KILLED || + finalApplicationStatus == FinalApplicationStatus.KILLED) { + throw new SparkException(s"Application $appId is killed") + } + if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + throw new SparkException(s"The final status of application $appId is undefined") + } } } + + private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + require(py4jFile.exists(), + "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) + } + } + } object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { - println("WARNING: This client is deprecated and will be removed in a " + + logWarning("WARNING: This client is deprecated and will be removed in a " + "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") } @@ -645,6 +977,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } @@ -669,6 +1005,10 @@ object Client extends Logging { // of the executors val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" + // Comma-separated list of strings to pass through as YARN application tags appearing + // in YARN ApplicationReports, which can be used for filtering when querying YARN. + val CONF_SPARK_YARN_APPLICATION_TAGS = "spark.yarn.tags" + // Staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) @@ -680,6 +1020,15 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" + // Subdirectory where the user's Spark and Hadoop config files will be placed. + val LOCALIZED_CONF_DIR = "__spark_conf__" + + // Name of the file in the conf archive containing Spark configuration. + val SPARK_CONF_FILE = "__spark_conf__.properties" + + // Subdirectory where the user's python files (not archives) will be placed. + val LOCALIZED_PYTHON_DIR = "__pyfiles__" + /** * Find the user-defined Spark jar if configured, or return the jar containing this * class if not. @@ -704,7 +1053,7 @@ object Client extends Logging { * Return the path to the given application's staging directory. */ private def getAppStagingDir(appId: ApplicationId): String = { - SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + buildPath(SPARK_STAGING, appId.toString()) } /** @@ -750,20 +1099,10 @@ object Client extends Logging { triedDefault.toOption } - /** - * In Hadoop 0.23, the MR application classpath comes with the YARN application - * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. - * So we need to use reflection to retrieve it. - */ private[yarn] def getDefaultMRApplicationClasspath: Option[Seq[String]] = { val triedDefault = Try[Seq[String]] { val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH") - val value = if (field.getType == classOf[String]) { - StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray - } else { - field.get(null).asInstanceOf[Array[String]] - } - value.toSeq + StringUtils.getStrings(field.get(null).asInstanceOf[String]).toSeq } recoverWith { case e: NoSuchFieldException => Success(Seq.empty[String]) } @@ -780,60 +1119,66 @@ object Client extends Logging { /** * Populate the classpath entry in the given environment map. - * This includes the user jar, Spark jar, and any extra application jars. + * + * User jars are generally not added to the JVM's system classpath; those are handled by the AM + * and executor backend. When the deprecated `spark.yarn.user.classpath.first` is used, user jars + * are included in the system classpath, though. The extra class path and other uploaded files are + * always made available through the system class path. + * + * @param args Client arguments (when starting the AM) or null (when starting executors). */ private[yarn] def populateClasspath( args: ClientArguments, conf: Configuration, sparkConf: SparkConf, env: HashMap[String, String], + isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) - // Normally the users app.jar is last in case conflicts with spark jars - if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { - addUserClasspath(args, sparkConf, env) - addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - populateHadoopClasspath(conf, env) - } else { - addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - populateHadoopClasspath(conf, env) - addUserClasspath(args, sparkConf, env) + if (isAM) { + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + + LOCALIZED_CONF_DIR, env) } - // Append all jar files under the working directory to the classpath. - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + "*", env - ) + if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { + val userClassPath = + if (args != null) { + getUserClasspath(Option(args.userJar), Option(args.addJars)) + } else { + getUserClasspath(sparkConf) + } + userClassPath.foreach { x => + addFileToClasspath(sparkConf, x, null, env) + } + } + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) + populateHadoopClasspath(conf, env) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** - * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly - * to the classpath. + * Returns a list of URIs representing the user classpath. + * + * @param conf Spark configuration. */ - private def addUserClasspath( - args: ClientArguments, - conf: SparkConf, - env: HashMap[String, String]): Unit = { - - // If `args` is not null, we are launching an AM container. - // Otherwise, we are launching executor containers. - val (mainJar, secondaryJars) = - if (args != null) { - (args.userJar, args.addJars) - } else { - (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null)) - } + def getUserClasspath(conf: SparkConf): Array[URI] = { + getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR), + conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + } - addFileToClasspath(mainJar, APP_JAR, env) - if (secondaryJars != null) { - secondaryJars.split(",").filter(_.nonEmpty).foreach { jar => - addFileToClasspath(jar, null, env) - } - } + private def getUserClasspath( + mainJar: Option[String], + secondaryJars: Option[String]): Array[URI] = { + val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_)) + val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) + (mainUri ++ secondaryUris).toArray } /** @@ -844,27 +1189,21 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * - * @param path Path to add to classpath (optional). + * @param conf Spark configuration. + * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( - path: String, + conf: SparkConf, + uri: URI, fileName: String, env: HashMap[String, String]): Unit = { - if (path != null) { - scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { - val uri = new URI(path) - if (uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) - return - } - } - } - if (fileName != null) { - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + fileName, env - ) + if (uri != null && uri.getScheme == LOCAL_SCHEME) { + addClasspathEntry(getClusterPath(conf, uri.getPath), env) + } else if (fileName != null) { + addClasspathEntry(buildPath( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) } } @@ -876,41 +1215,124 @@ object Client extends Logging { YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) /** - * Get the list of namenodes the user may access. + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. */ - private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { - sparkConf.get("spark.yarn.access.namenodes", "") - .split(",") - .map(_.trim()) - .filter(!_.isEmpty) - .map(new Path(_)) - .toSet + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } } - private[yarn] def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) + /** + * Obtains token for the Hive metastore and adds them to the credentials. + */ + private def obtainTokenForHiveMetastore( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials) { + if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null) + + val hiveConf = hiveClass.getMethod("getConf").invoke(hive) + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + + val hiveConfGet = (param: String) => Option(hiveConfClass + .getMethod("get", classOf[java.lang.String]) + .invoke(hiveConf, param)) + + val metastore_uri = hiveConfGet("hive.metastore.uris") + + // Check for local metastore + if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val metastore_kerberos_principal_conf_var = mirror.classLoader + .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") + .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString + + val principal = hiveConfGet(metastore_kerberos_principal_conf_var) + + val username = Option(UserGroupInformation.getCurrentUser().getUserName) + if (principal != None && username != None) { + val tokenStr = hiveClass.getMethod("getDelegationToken", + classOf[java.lang.String], classOf[java.lang.String]) + .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] + + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + credentials.addToken(new Text("hive.server2.delegation.token"), hive2Token) + logDebug("Added hive.Server2.delegation.token to conf.") + hiveClass.getMethod("closeCurrent").invoke(null) + } else { + logError("Username or principal == NULL") + logError(s"""username=${username.getOrElse("(NULL)")}""") + logError(s"""principal=${principal.getOrElse("(NULL)")}""") + throw new IllegalArgumentException("username and/or principal is equal to null!") + } + } else { + logDebug("HiveMetaStore configured in localmode") + } + } catch { + case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e: Exception => { logError("Unexpected Exception " + e) + throw new RuntimeException("Unexpected exception", e) + } + } } - delegTokenRenewer } /** - * Obtains tokens for the namenodes passed in and adds them to the credentials. + * Obtain security token for HBase. */ - private def obtainTokensForNamenodes( - paths: Set[Path], + def obtainTokenForHBase( + sparkConf: SparkConf, conf: Configuration, - creds: Credentials): Unit = { - if (UserGroupInformation.isSecurityEnabled()) { - val delegTokenRenewer = getTokenRenewer(conf) - paths.foreach { dst => - val dstFs = dst.getFileSystem(conf) - logDebug("getting token for namenode: " + dst) - dstFs.addDelegationTokens(delegTokenRenewer, creds) + credentials: Credentials): Unit = { + if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + + logDebug("Attempting to fetch HBase security token.") + + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") + } + } catch { + case e: java.lang.NoSuchMethodException => + logInfo("HBase Method not found: " + e) + case e: java.lang.ClassNotFoundException => + logDebug("HBase Class not found: " + e) + case e: java.lang.NoClassDefFoundError => + logDebug("HBase Class not found: " + e) + case e: Exception => + logError("Exception when obtaining HBase security token: " + e) } } } @@ -960,4 +1382,32 @@ object Client extends Logging { new Path(qualifiedURI) } + /** + * Whether to consider jars provided by the user to have precedence over the Spark jars when + * loading user classes. + */ + def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = { + if (isDriver) { + conf.getBoolean("spark.driver.userClassPathFirst", false) + } else { + conf.getBoolean("spark.executor.userClassPathFirst", false) + } + } + + /** + * Joins all the path components using Path.SEPARATOR. + */ + def buildPath(components: String*): String = { + components.mkString(Path.SEPARATOR) + } + + /** + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. + */ + def shouldGetTokens(conf: SparkConf, service: String): Boolean = { + conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3bc7eb1abf34..4f42ffefa77f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,8 +30,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: String = null + var pyFiles: Seq[String] = Nil var primaryPyFile: String = null + var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 @@ -41,17 +42,18 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var amCores: Int = 1 var appName: String = "Spark" var priority = 0 + var principal: String = null + var keytab: String = null def isClusterMode: Boolean = userClass != null - private var driverMemory: Int = 512 // MB + private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB private var driverCores: Int = 1 private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" private val amMemKey = "spark.yarn.am.memory" private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -95,6 +97,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** @@ -102,14 +110,19 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) * This is intended to be called only after the provided arguments have been parsed. */ private def validateArgs(): Unit = { - if (numExecutors <= 0) { + if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) { throw new IllegalArgumentException( - "You must specify at least 1 executor!\n" + getUsageMessage()) + s""" + |Number of executors was $numExecutors, but must be at least 1 + |(or 0 if dynamic executor allocation is enabled). + |${getUsageMessage()} + """.stripMargin) } if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { throw new SparkException("Executor cores must not be less than " + "spark.task.cpus.") } + // scalastyle:off println if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -131,11 +144,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .map(_.toInt) .foreach { cores => amCores = cores } } + // scalastyle:on println } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs + // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -150,6 +165,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. Use --arg instead.") @@ -176,11 +195,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail @@ -211,7 +225,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) args = tail case ("--py-files") :: value :: tail => - pyFiles = value + pyFiles = value.split(",") args = tail case ("--files") :: value :: tail => @@ -222,29 +236,45 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) archives = value args = tail + case ("--principal") :: value :: tail => + principal = value + args = tail + + case ("--keytab") :: value :: tail => + keytab = value + args = tail + case Nil => case _ => throw new IllegalArgumentException(getUsageMessage(args)) } } + // scalastyle:on println + + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + - """ + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] |Options: | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster | mode) | --class CLASS_NAME Name of your application's main class (required) | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --name NAME The name of your application (Default: Spark) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index c592ecfdfce0..3d3a966960e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. - * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. * * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -95,13 +95,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -112,13 +112,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -160,7 +160,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def ancestorsHaveExecutePermissions( fs: FileSystem, path: Path, - statCache: Map[URI, FileStatus]): Boolean = { + statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { // the subdirs in the path should have execute permissions for others @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala new file mode 100644 index 000000000000..94feb6393fd6 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.concurrent.{Executors, TimeUnit} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{ThreadUtils, Utils} + +import scala.util.control.NonFatal + +private[spark] class ExecutorDelegationTokenUpdater( + sparkConf: SparkConf, + hadoopConf: Configuration) extends Logging { + + @volatile private var lastCredentialsFileSuffix = 0 + + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val freshHadoopConf = + SparkHadoopUtil.get.getConfBypassingFSCache( + hadoopConf, new Path(credentialsFile).toUri.getScheme) + + private val delegationTokenRenewer = + Executors.newSingleThreadScheduledExecutor( + ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread")) + + // On the executor, this thread wakes up and picks up new tokens from HDFS, if any. + private val executorUpdaterRunnable = + new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) + } + + def updateCredentialsIfRequired(): Unit = { + try { + val credentialsFilePath = new Path(credentialsFile) + val remoteFs = FileSystem.get(freshHadoopConf) + SparkHadoopUtil.get.listFilesSorted( + remoteFs, credentialsFilePath.getParent, + credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .lastOption.foreach { credentialsStatus => + val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) + if (suffix > lastCredentialsFileSuffix) { + logInfo("Reading new delegation tokens from " + credentialsStatus.getPath) + val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) + lastCredentialsFileSuffix = suffix + UserGroupInformation.getCurrentUser.addCredentials(newCredentials) + logInfo("Tokens updated from credentials file.") + } else { + // Check every hour to see if new credentials arrived. + logInfo("Updated delegation tokens were expected, but the driver has not updated the " + + "tokens yet, will check again in an hour.") + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS) + return + } + } + val timeFromNowToRenewal = + SparkHadoopUtil.get.getTimeFromNowToRenewal( + sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials) + if (timeFromNowToRenewal <= 0) { + executorUpdaterRunnable.run() + } else { + logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.") + delegationTokenRenewer.schedule( + executorUpdaterRunnable, timeFromNowToRenewal, TimeUnit.MILLISECONDS) + } + } catch { + // Since the file may get deleted while we are reading it, catch the Exception and come + // back in an hour to try again + case NonFatal(e) => + logWarning("Error while trying to update credentials, will try again in 1 hour", e) + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS) + } + } + + private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { + val stream = remoteFs.open(tokenPath) + try { + val newCredentials = new Credentials() + newCredentials.readTokenStorageStream(stream) + newCredentials + } finally { + stream.close() + } + } + + def stop(): Unit = { + delegationTokenRenewer.shutdown() + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ee2002a35f52..9abd09b3cc7a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -17,16 +17,16 @@ package org.apache.spark.deploy.yarn +import java.io.File import java.net.URI import java.nio.ByteBuffer +import java.util.Collections -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation @@ -37,8 +37,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils class ExecutorRunnable( container: Container, @@ -56,26 +57,26 @@ class ExecutorRunnable( var rpc: YarnRPC = YarnRPC.create(conf) var nmClient: NMClient = _ val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - lazy val env = prepareEnvironment - - def run = { + lazy val env = prepareEnvironment(container) + + override def run(): Unit = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() - startContainer + startContainer() } - def startContainer = { + def startContainer(): java.util.Map[String, ByteBuffer] = { logInfo("Setting up ContainerLaunchContext") val ctx = Records.newRecord(classOf[ContainerLaunchContext]) .asInstanceOf[ContainerLaunchContext] val localResources = prepareLocalResources - ctx.setLocalResources(localResources) + ctx.setLocalResources(localResources.asJava) - ctx.setEnvironment(env) + ctx.setEnvironment(env.asJava) val credentials = UserGroupInformation.getCurrentUser().getCredentials() val dob = new DataOutputBuffer() @@ -85,11 +86,19 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, appId, localResources) - logInfo(s"Setting up executor with environment: $env") - logInfo("Setting up executor with commands: " + commands) - ctx.setCommands(commands) + logInfo(s""" + |=============================================================================== + |YARN executor launch context: + | env: + |${env.map { case (k, v) => s" $k -> $v\n" }.mkString} + | command: + | ${commands.mkString(" ")} + |=============================================================================== + """.stripMargin) - ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + ctx.setCommands(commands.asJava) + ctx.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava) // If external shuffle service is enabled, register with the Yarn shuffle service already // started on the NodeManager and, if authentication is enabled, provide it with our secret @@ -104,11 +113,17 @@ class ExecutorRunnable( // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) } - ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes)) } // Send the start request to the ContainerManager - nmClient.startContainer(container, ctx) + try { + nmClient.startContainer(container, ctx) + } catch { + case ex: Exception => + throw new SparkException(s"Exception while starting container ${container.getId}" + + s" on host $hostname", ex) + } } private def prepareCommand( @@ -128,17 +143,18 @@ class ExecutorRunnable( // Set the JVM memory val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + javaOpts += "-Xms" + executorMemoryString + javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -173,17 +189,27 @@ class ExecutorRunnable( // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use // %20the%20Concurrent%20Low%20Pause%20Collector|outline - javaOpts += " -XX:+UseConcMarkSweepGC " - javaOpts += " -XX:+CMSIncrementalMode " - javaOpts += " -XX:+CMSIncrementalPacing " - javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " - javaOpts += " -XX:CMSIncrementalDutyCycle=10 " + javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:+CMSIncrementalMode" + javaOpts += "-XX:+CMSIncrementalPacing" + javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" + javaOpts += "-XX:CMSIncrementalDutyCycle=10" } */ // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => + val absPath = + if (new File(uri.getPath()).isAbsolute()) { + Client.getClusterPath(sparkConf, uri.getPath()) + } else { + Client.buildPath(Environment.PWD.$(), uri.getPath()) + } + Seq("--user-class-path", "file:" + absPath) + }.toSeq + val commands = prefixEnv ++ Seq( YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server", @@ -192,14 +218,16 @@ class ExecutorRunnable( // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? - "-XX:OnOutOfMemoryError='kill %p'") ++ + YarnSparkHadoopUtil.getOutOfMemoryErrorArgument) ++ javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", - masterAddress.toString, - slaveId.toString, - hostname.toString, - executorCores.toString, - appId, + "--driver-url", masterAddress.toString, + "--executor-id", slaveId.toString, + "--hostname", hostname.toString, + "--cores", executorCores.toString, + "--app-id", appId) ++ + userClassPath ++ + Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") @@ -254,10 +282,10 @@ class ExecutorRunnable( localResources } - private def prepareEnvironment: HashMap[String, String] = { + private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - Client.populateClasspath(null, yarnConf, sparkConf, env, extraCp) + Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp) sparkConf.getExecutorEnv.foreach { case (key, value) => // This assumes each executor environment variable set here is a path @@ -270,7 +298,25 @@ class ExecutorRunnable( YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + + // Add log urls + sys.env.get("SPARK_USER").foreach { user => + val containerId = ConverterUtils.toString(container.getId) + val address = container.getNodeHttpAddress + val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" + + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096" + env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" + } + + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } env } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala new file mode 100644 index 000000000000..081780204e42 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.mutable.{ArrayBuffer, HashMap, Set} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} +import org.apache.hadoop.yarn.util.RackResolver + +import org.apache.spark.SparkConf + +private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String]) + +/** + * This strategy is calculating the optimal locality preferences of YARN containers by considering + * the node ratio of pending tasks, number of required cores/containers and and locality of current + * existing containers. The target of this algorithm is to maximize the number of tasks that + * would run locally. + * + * Consider a situation in which we have 20 tasks that require (host1, host2, host3) + * and 10 tasks that require (host1, host2, host4), besides each container has 2 cores + * and cpus per task is 1, so the required container number is 15, + * and host ratio is (host1: 30, host2: 30, host3: 20, host4: 10). + * + * 1. If requested container number (18) is more than the required container number (15): + * + * requests for 5 containers with nodes: (host1, host2, host3, host4) + * requests for 5 containers with nodes: (host1, host2, host3) + * requests for 5 containers with nodes: (host1, host2) + * requests for 3 containers with no locality preferences. + * + * The placement ratio is 3 : 3 : 2 : 1, and set the additional containers with no locality + * preferences. + * + * 2. If requested container number (10) is less than or equal to the required container number + * (15): + * + * requests for 4 containers with nodes: (host1, host2, host3, host4) + * requests for 3 containers with nodes: (host1, host2, host3) + * requests for 3 containers with nodes: (host1, host2) + * + * The placement ratio is 10 : 10 : 7 : 4, close to expected ratio (3 : 3 : 2 : 1) + * + * 3. If containers exist but none of them can match the requested localities, + * follow the method of 1 and 2. + * + * 4. If containers exist and some of them can match the requested localities. + * For example if we have 1 containers on each node (host1: 1, host2: 1: host3: 1, host4: 1), + * and the expected containers on each node would be (host1: 5, host2: 5, host3: 4, host4: 2), + * so the newly requested containers on each node would be updated to (host1: 4, host2: 4, + * host3: 3, host4: 1), 12 containers by total. + * + * 4.1 If requested container number (18) is more than newly required containers (12). Follow + * method 1 with updated ratio 4 : 4 : 3 : 1. + * + * 4.2 If request container number (10) is more than newly required containers (12). Follow + * method 2 with updated ratio 4 : 4 : 3 : 1. + * + * 5. If containers exist and existing localities can fully cover the requested localities. + * For example if we have 5 containers on each node (host1: 5, host2: 5, host3: 5, host4: 5), + * which could cover the current requested localities. This algorithm will allocate all the + * requested containers with no localities. + */ +private[yarn] class LocalityPreferredContainerPlacementStrategy( + val sparkConf: SparkConf, + val yarnConf: Configuration, + val resource: Resource) { + + // Number of CPUs per task + private val CPUS_PER_TASK = sparkConf.getInt("spark.task.cpus", 1) + + /** + * Calculate each container's node locality and rack locality + * @param numContainer number of containers to calculate + * @param numLocalityAwareTasks number of locality required tasks + * @param hostToLocalTaskCount a map to store the preferred hostname and possible task + * numbers running on it, used as hints for container allocation + * @return node localities and rack localities, each locality is an array of string, + * the length of localities is the same as number of containers + */ + def localityOfRequestedContainers( + numContainer: Int, + numLocalityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Array[ContainerLocalityPreferences] = { + val updatedHostToContainerCount = expectedHostToContainerCount( + numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap) + val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum + + // The number of containers to allocate, divided into two groups, one with preferred locality, + // and the other without locality preference. + val requiredLocalityFreeContainerNum = + math.max(0, numContainer - updatedLocalityAwareContainerNum) + val requiredLocalityAwareContainerNum = numContainer - requiredLocalityFreeContainerNum + + val containerLocalityPreferences = ArrayBuffer[ContainerLocalityPreferences]() + if (requiredLocalityFreeContainerNum > 0) { + for (i <- 0 until requiredLocalityFreeContainerNum) { + containerLocalityPreferences += ContainerLocalityPreferences( + null.asInstanceOf[Array[String]], null.asInstanceOf[Array[String]]) + } + } + + if (requiredLocalityAwareContainerNum > 0) { + val largestRatio = updatedHostToContainerCount.values.max + // Round the ratio of preferred locality to the number of locality required container + // number, which is used for locality preferred host calculating. + var preferredLocalityRatio = updatedHostToContainerCount.mapValues { ratio => + val adjustedRatio = ratio.toDouble * requiredLocalityAwareContainerNum / largestRatio + adjustedRatio.ceil.toInt + } + + for (i <- 0 until requiredLocalityAwareContainerNum) { + // Only filter out the ratio which is larger than 0, which means the current host can + // still be allocated with new container request. + val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray + val racks = hosts.map { h => + RackResolver.resolve(yarnConf, h).getNetworkLocation + }.toSet + containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) + + // Minus 1 each time when the host is used. When the current ratio is 0, + // which means all the required ratio is satisfied, this host will not be allocated again. + preferredLocalityRatio = preferredLocalityRatio.mapValues(_ - 1) + } + } + + containerLocalityPreferences.toArray + } + + /** + * Calculate the number of executors need to satisfy the given number of pending tasks. + */ + private def numExecutorsPending(numTasksPending: Int): Int = { + val coresPerExecutor = resource.getVirtualCores + (numTasksPending * CPUS_PER_TASK + coresPerExecutor - 1) / coresPerExecutor + } + + /** + * Calculate the expected host to number of containers by considering with allocated containers. + * @param localityAwareTasks number of locality aware tasks + * @param hostToLocalTaskCount a map to store the preferred hostname and possible task + * numbers running on it, used as hints for container allocation + * @return a map with hostname as key and required number of containers on this host as value + */ + private def expectedHostToContainerCount( + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Map[String, Int] = { + val totalLocalTaskNum = hostToLocalTaskCount.values.sum + hostToLocalTaskCount.map { case (host, count) => + val expectedCount = + count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum + val existedCount = allocatedHostToContainersMap.get(host) + .map(_.size) + .getOrElse(0) + + // If existing container can not fully satisfy the expected number of container, + // the required container number is expected count minus existed count. Otherwise the + // required container number is 0. + (host, math.max(0, (expectedCount - existedCount).ceil.toInt)) + } + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 0dbb6154b303..5f897cbcb4e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,7 +21,7 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -34,10 +34,11 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.AkkaUtils +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.util.Utils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -53,6 +54,8 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + driverUrl: String, + driverRef: RpcEndpointRef, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -69,8 +72,7 @@ private[yarn] class YarnAllocator( } // Visible for testing. - val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]] + val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] // Containers that we no longer care about. We've either already told the RM to release them or @@ -84,10 +86,19 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var maxExecutors = args.numExecutors + @volatile private var targetNumExecutors = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) + } else { + sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) + } // Keep track of which container is running which executor to remove the executors later - private val executorIdToContainer = new HashMap[String, Container] + // Visible for testing. + private[yarn] val executorIdToContainer = new HashMap[String, Container] + + private var numUnexpectedContainerRelease = 0L + private val containerIdToExecutorId = new HashMap[ContainerId, String] // Executor memory in MB. protected val executorMemory = args.executorMemory @@ -97,7 +108,7 @@ private[yarn] class YarnAllocator( // Number of cores per executor. protected val executorCores = args.executorCores // Resource capability requested for each executors - private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -107,16 +118,37 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) + private val labelExpression = sparkConf.getOption("spark.yarn.executor.nodeLabelExpression") + + // ContainerRequest constructor that can take a node label expression. We grab it through + // reflection because it's only available in later versions of YARN. + private val nodeLabelConstructor = labelExpression.flatMap { expr => + try { + Some(classOf[ContainerRequest].getConstructor(classOf[Resource], + classOf[Array[String]], classOf[Array[String]], classOf[Priority], classOf[Boolean], + classOf[String])) + } catch { + case e: NoSuchMethodException => { + logWarning(s"Node label expression $expr will be ignored because YARN version on" + + " classpath does not support it.") + None + } + } + } + + // A map to store preferred hostname and possible task numbers running on it. + private var hostToLocalTaskCounts: Map[String, Int] = Map.empty + + // Number of tasks that have locality preferences in active stages + private var numLocalityAwareTasks: Int = 0 + + // A container placement strategy based on pending tasks' locality preference + private[yarn] val containerPlacementStrategy = + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + def getNumExecutorsRunning: Int = numExecutorsRunning def getNumExecutorsFailed: Int = numExecutorsFailed @@ -130,13 +162,32 @@ private[yarn] class YarnAllocator( * Number of container requests at the given location that have not yet been fulfilled. */ private def getNumPendingAtLocation(location: String): Int = - amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala.map(_.size).sum /** - * Request as many executors from the ResourceManager as needed to reach the desired total. + * Request as many executors from the ResourceManager as needed to reach the desired total. If + * the requested total is smaller than the current number of running executors, no executors will + * be killed. + * @param requestedTotal total number of containers requested + * @param localityAwareTasks number of locality aware tasks to be used as container placement hint + * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as + * container placement hint. + * @return Whether the new requested total is different than the old value. */ - def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { - maxExecutors = requestedTotal + def requestTotalExecutorsWithPreferredLocalities( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized { + this.numLocalityAwareTasks = localityAwareTasks + this.hostToLocalTaskCounts = hostToLocalTaskCount + + if (requestedTotal != targetNumExecutors) { + logInfo(s"Driver requested a total number of $requestedTotal executor(s).") + targetNumExecutors = requestedTotal + true + } else { + false + } } /** @@ -145,10 +196,9 @@ private[yarn] class YarnAllocator( def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get + containerIdToExecutorId.remove(container.getId) internalReleaseContainer(container) numExecutorsRunning -= 1 - maxExecutors -= 1 - assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!") } else { logWarning(s"Attempted to kill unknown executor $executorId!") } @@ -163,15 +213,8 @@ private[yarn] class YarnAllocator( * This must be synchronized because variables read in this method are mutated by other methods. */ def allocateResources(): Unit = synchronized { - val numPendingAllocate = getNumPendingAllocate - val missing = maxExecutors - numPendingAllocate - numExecutorsRunning + updateResourceRequests() - if (missing > 0) { - logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + - s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") - } - - addResourceRequests(missing) val progressIndicator = 0.1f // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container // requests. @@ -186,14 +229,14 @@ private[yarn] class YarnAllocator( numExecutorsRunning, allocateResponse.getAvailableResources)) - handleAllocatedContainers(allocatedContainers) + handleAllocatedContainers(allocatedContainers.asScala) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - processCompletedContainers(completedContainers) + processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." .format(completedContainers.size, numExecutorsRunning)) @@ -201,18 +244,62 @@ private[yarn] class YarnAllocator( } /** - * Request numExecutors additional containers from YARN. Visible for testing. + * Update the set of container requests that we will sync with the RM based on the number of + * executors we have currently running and our target number of executors. + * + * Visible for testing. */ - def addResourceRequests(numExecutors: Int): Unit = { - for (i <- 0 until numExecutors) { - val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) - amClient.addContainerRequest(request) - val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last - logInfo("Container request (host: %s, capability: %s".format(hostStr, resource)) + def updateResourceRequests(): Unit = { + val numPendingAllocate = getNumPendingAllocate + val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning + + // TODO. Consider locality preferences of pending container requests. + // Since the last time we made container requests, stages have completed and been submitted, + // and that the localities at which we requested our pending executors + // no longer apply to our current needs. We should consider to remove all outstanding + // container requests and add requests anew each time to avoid this. + if (missing > 0) { + logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") + + val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( + missing, numLocalityAwareTasks, hostToLocalTaskCounts, allocatedHostToContainersMap) + + for (locality <- containerLocalityPreferences) { + val request = createContainerRequest(resource, locality.nodes, locality.racks) + amClient.addContainerRequest(request) + val nodes = request.getNodes + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last + logInfo(s"Container request (host: $hostStr, capability: $resource)") + } + } else if (missing < 0) { + val numToCancel = math.min(numPendingAllocate, -missing) + logInfo(s"Canceling requests for $numToCancel executor containers") + + val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) + if (!matchingRequests.isEmpty) { + matchingRequests.iterator().next().asScala + .take(numToCancel).foreach(amClient.removeContainerRequest) + } else { + logWarning("Expected to find pending requests, but found none.") + } } } + /** + * Creates a container request, handling the reflection required to use YARN features that were + * added in recent versions. + */ + protected def createContainerRequest( + resource: Resource, + nodes: Array[String], + racks: Array[String]): ContainerRequest = { + nodeLabelConstructor.map { constructor => + constructor.newInstance(resource, nodes, racks, RM_REQUEST_PRIORITY, true: java.lang.Boolean, + labelExpression.orNull) + }.getOrElse(new ContainerRequest(resource, nodes, racks, RM_REQUEST_PRIORITY)) + } + /** * Handle containers granted by the RM by launching executors on them. * @@ -266,7 +353,7 @@ private[yarn] class YarnAllocator( * containersToUse or remaining. * * @param allocatedContainer container that was given to us by YARN - * @location resource name, either a node, rack, or * + * @param location resource name, either a node, rack, or * * @param containersToUse list of containers that will be used * @param remaining list of containers that will not be used */ @@ -275,8 +362,14 @@ private[yarn] class YarnAllocator( location: String, containersToUse: ArrayBuffer[Container], remaining: ArrayBuffer[Container]): Unit = { + // SPARK-6050: certain Yarn configurations return a virtual core count that doesn't match the + // request; for example, capacity scheduler + DefaultResourceCalculator. So match on requested + // memory, but use the asked vcore count for matching, effectively disabling matching on vcore + // count. + val matchingResource = Resource.newInstance(allocatedContainer.getResource.getMemory, + resource.getVirtualCores) val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location, - allocatedContainer.getResource) + matchingResource) // Match the allocation to a request if (!matchingRequests.isEmpty) { @@ -294,7 +387,7 @@ private[yarn] class YarnAllocator( private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = { for (container <- containersToUse) { numExecutorsRunning += 1 - assert(numExecutorsRunning <= maxExecutors) + assert(numExecutorsRunning <= targetNumExecutors) val executorHostname = container.getNodeId.getHost val containerId = container.getId executorIdCounter += 1 @@ -303,7 +396,8 @@ private[yarn] class YarnAllocator( assert(container.getResource.getMemory >= resource.getMemory) logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) - executorIdToContainer(executorId) = container + executorIdToContainer(executorId) = container + containerIdToExecutorId(container.getId) = executorId val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -330,15 +424,12 @@ private[yarn] class YarnAllocator( } } - private def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { + // Visible for testing. + private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - - if (releasedContainers.contains(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { + val alreadyReleased = releasedContainers.remove(containerId) + if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 @@ -349,7 +440,9 @@ private[yarn] class YarnAllocator( // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == -103) { // vmem limit exceeded + if (completedContainer.getExitStatus == ContainerExitStatus.PREEMPTED) { + logInfo("Container preempted: " + containerId) + } else if (completedContainer.getExitStatus == -103) { // vmem limit exceeded logWarning(memLimitExceededLogMessage( completedContainer.getDiagnostics, VMEM_EXCEEDED_PATTERN)) @@ -365,7 +458,7 @@ private[yarn] class YarnAllocator( } } - if (allocatedContainerToHostMap.containsKey(containerId)) { + if (allocatedContainerToHostMap.contains(containerId)) { val host = allocatedContainerToHostMap.get(containerId).get val containerSet = allocatedHostToContainersMap.get(host).get @@ -378,6 +471,18 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.remove(containerId) } + + containerIdToExecutorId.remove(containerId).foreach { eid => + executorIdToContainer.remove(eid) + + if (!alreadyReleased) { + // The executor could have gone away (like no route to host, node failure, etc) + // Notify backend about the failure of the executor + numUnexpectedContainerRelease += 1 + driverRef.send(RemoveExecutor(eid, + s"Yarn deallocated the executor $eid (container $containerId)")) + } + } } } @@ -386,6 +491,8 @@ private[yarn] class YarnAllocator( amClient.releaseAssignedContainer(container.getId()) } + private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + } private object YarnAllocator { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b13475136652..df042bf291de 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,20 +19,19 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -55,6 +54,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. */ def register( + driverUrl: String, + driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, + securityMgr) } /** @@ -89,9 +91,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId = { - val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - containerId.getApplicationAttemptId() + YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId() } /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ @@ -106,8 +106,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", classOf[Configuration]) val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] - val hosts = proxies.map { proxy => proxy.split(":")(0) } - val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + val hosts = proxies.asScala.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) } catch { case e: NoSuchMethodException => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 146b2c0f1a30..445d3dcd266d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -24,18 +24,21 @@ import java.util.regex.Pattern import scala.collection.mutable.HashMap import scala.util.Try +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} -import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} +import org.apache.hadoop.yarn.util.ConverterUtils -import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.launcher.YarnCommandBuilderUtils +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils /** @@ -43,6 +46,8 @@ import org.apache.spark.util.Utils */ class YarnSparkHadoopUtil extends SparkHadoopUtil { + private var tokenRenewer: Option[ExecutorDelegationTokenUpdater] = None + override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { dest.addCredentials(source.getCredentials()) } @@ -82,14 +87,69 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { if (credentials != null) credentials.getSecretKey(new Text(key)) else null } + /** + * Get the list of namenodes the user may access. + */ + def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "") + .split(",") + .map(_.trim()) + .filter(!_.isEmpty) + .map(new Path(_)) + .toSet + } + + def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + def obtainTokensForNamenodes( + paths: Set[Path], + conf: Configuration, + creds: Credentials, + renewer: Option[String] = None + ): Unit = { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = renewer.getOrElse(getTokenRenewer(conf)) + paths.foreach { dst => + val dstFs = dst.getFileSystem(conf) + logInfo("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + + private[spark] override def startExecutorDelegationTokenRenewer(sparkConf: SparkConf): Unit = { + tokenRenewer = Some(new ExecutorDelegationTokenUpdater(sparkConf, conf)) + tokenRenewer.get.updateCredentialsIfRequired() + } + + private[spark] override def stopExecutorDelegationTokenRenewer(): Unit = { + tokenRenewer.foreach(_.stop()) + } + + private[spark] def getContainerId: ContainerId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + ConverterUtils.toContainerId(containerIdString) + } } object YarnSparkHadoopUtil { - // Additional memory overhead - // 7% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // Additional memory overhead + // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering + // the common cases. Memory overhead tends to grow with container size. - val MEMORY_OVERHEAD_FACTOR = 0.07 + val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 val ANY_HOST = "*" @@ -100,6 +160,14 @@ object YarnSparkHadoopUtil { // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) + def get: YarnSparkHadoopUtil = { + val yarnMode = java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (!yarnMode) { + throw new SparkException("YarnSparkHadoopUtil is not available in non-YARN mode!") + } + SparkHadoopUtil.get.asInstanceOf[YarnSparkHadoopUtil] + } /** * Add a path variable to the given environment map. * If the map already contains this key, append the value to the existing value instead. @@ -152,26 +220,61 @@ object YarnSparkHadoopUtil { } } + /** + * The handler if an OOM Exception is thrown by the JVM must be configured on Windows + * differently: the 'taskkill' command should be used, whereas Unix-based systems use 'kill'. + * + * As the JVM interprets both %p and %%p as the same, we can use either of them. However, + * some tests on Windows computers suggest, that the JVM only accepts '%%p'. + * + * Furthermore, the behavior of the character '%' on the Windows command line differs from + * the behavior of '%' in a .cmd file: it gets interpreted as an incomplete environment + * variable. Windows .cmd files escape a '%' by '%%'. Thus, the correct way of writing + * '%%p' in an escaped way is '%%%%p'. + * + * @return The correct OOM Error handler JVM option, platform dependent. + */ + def getOutOfMemoryErrorArgument : String = { + if (Utils.isWindows) { + escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p") + } else { + "-XX:OnOutOfMemoryError='kill %p'" + } + } + /** * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands - * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The - * argument is enclosed in single quotes and some key characters are escaped. + * using either + * + * (Unix-based) `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. + * The argument is enclosed in single quotes and some key characters are escaped. + * + * (Windows-based) part of a .cmd file in which case windows escaping for each argument must be + * applied. Windows is quite lenient, however it is usually Java that causes trouble, needing to + * distinguish between arguments starting with '-' and class names. If arguments are surrounded + * by ' java takes the following string as is, hence an argument is mistakenly taken as a class + * name which happens to start with a '-'. The way to avoid this, is to surround nothing with + * a ', but instead with a ". * * @param arg A single argument. * @return Argument quoted for execution via Yarn's generated shell script. */ def escapeForShell(arg: String): String = { if (arg != null) { - val escaped = new StringBuilder("'") - for (i <- 0 to arg.length() - 1) { - arg.charAt(i) match { - case '$' => escaped.append("\\$") - case '"' => escaped.append("\\\"") - case '\'' => escaped.append("'\\''") - case c => escaped.append(c) + if (Utils.isWindows) { + YarnCommandBuilderUtils.quoteForBatchScript(arg) + } else { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } } + escaped.append("'").toString() } - escaped.append("'").toString() } else { arg } @@ -212,3 +315,4 @@ object YarnSparkHadoopUtil { classPathSeparatorField.get(null).asInstanceOf[String] } } + diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala new file mode 100644 index 000000000000..3ac36ef0a1c3 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +/** + * Exposes needed methods + */ +private[spark] object YarnCommandBuilderUtils { + def quoteForBatchScript(arg: String) : String = { + CommandBuilderUtils.quoteForBatchScript(arg) + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 690f927e938c..d06d95140438 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -33,14 +33,13 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - @volatile private var stopping: Boolean = false + private var monitorThread: MonitorThread = null /** * Create a Yarn client to submit an application to the ResourceManager. * This waits until the application is running. */ override def start() { - super.start() val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort @@ -55,8 +54,22 @@ private[spark] class YarnClientSchedulerBackend( totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.submitApplication() + + // SPARK-8687: Ensure all necessary properties have already been set before + // we initialize our driver scheduler backend, which serves these properties + // to the executors + super.start() + waitForApplication() - asyncMonitorApplication() + + // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver + // reads the credentials from HDFS, just like the executors and updates its own credentials + // cache. + if (conf.contains("spark.yarn.credentials.file")) { + YarnSparkHadoopUtil.get.startExecutorDelegationTokenRenewer(conf) + } + monitorThread = asyncMonitorApplication() + monitorThread.start() } /** @@ -68,29 +81,21 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--py-files", null, "spark.submit.pyFiles") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_MASTER_MEMORY" -> "SPARK_DRIVER_MEMORY or --driver-memory through spark-submit", - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") - // Do the same for deprecated properties: property -> suggestion - val deprecatedProps = Map("spark.master.memory" -> "--driver-memory through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - if (deprecatedProps.contains(sparkProp)) { - logWarning(s"NOTE: $sparkProp is deprecated. Use ${deprecatedProps(sparkProp)} instead.") - } - } else if (System.getenv(envVar) != null) { + } else if (envVar != null && System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") @@ -123,33 +128,45 @@ private[spark] class YarnClientSchedulerBackend( } } + /** + * We create this class for SPARK-9519. Basically when we interrupt the monitor thread it's + * because the SparkContext is being shut down(sc.stop() called by user code), but if + * monitorApplication return, it means the Yarn application finished before sc.stop() was called, + * which means we should call sc.stop() here, and we don't allow the monitor to be interrupted + * before SparkContext stops successfully. + */ + private class MonitorThread extends Thread { + private var allowInterrupt = true + + override def run() { + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + allowInterrupt = false + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") + } + } + + def stopMonitor(): Unit = { + if (allowInterrupt) { + this.interrupt() + } + } + } + /** * Monitor the application state in a separate thread. * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. */ - private def asyncMonitorApplication(): Unit = { + private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId != null, "Application has not been submitted yet!") - val t = new Thread { - override def run() { - while (!stopping) { - val report = client.getApplicationReport(appId) - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.KILLED || - state == YarnApplicationState.FAILED) { - logError(s"Yarn application has already exited with state $state!") - sc.stop() - stopping = true - } - Thread.sleep(1000L) - } - Thread.currentThread().interrupt() - } - } + val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) - t.start() + t } /** @@ -157,9 +174,12 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - stopping = true + if (monitorThread != null) { + monitorThread.stopMonitor() + } super.stop() client.stop() + YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() logInfo("Stopped") } @@ -169,5 +189,4 @@ private[spark] class YarnClientSchedulerBackend( super.applicationId } } - } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index b1de81e6a8b0..1aed5a167507 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -17,10 +17,21 @@ package org.apache.spark.scheduler.cluster +import java.net.NetworkInterface + +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.yarn.api.records.NodeState +import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.conf.YarnConfiguration + import org.apache.spark.SparkContext +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.IntParam +import org.apache.spark.util.{IntParam, Utils} private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -39,12 +50,46 @@ private[spark] class YarnClusterSchedulerBackend( } override def applicationId(): String = - // In YARN Cluster mode, spark.yarn.app.id is expect to be set - // before user application is launched. - // So, if spark.yarn.app.id is not set, it is something wrong. + // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's + // not found. sc.getConf.getOption("spark.yarn.app.id").getOrElse { logError("Application ID is not set.") super.applicationId } + override def applicationAttemptId(): Option[String] = + // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's + // not found. + sc.getConf.getOption("spark.yarn.app.attemptId").orElse { + logError("Application attempt ID is not set.") + super.applicationAttemptId + } + + override def getDriverLogUrls: Option[Map[String, String]] = { + var driverLogs: Option[Map[String, String]] = None + try { + val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) + val containerId = YarnSparkHadoopUtil.get.getContainerId + + val httpAddress = System.getenv(Environment.NM_HOST.name()) + + ":" + System.getenv(Environment.NM_HTTP_PORT.name()) + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) + } catch { + case e: Exception => + logInfo("Error while building AM log links, so AM" + + " logs link will not appear in application UI", e) + } + driverLogs + } } diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 287c8e356350..6b8a5dbf6373 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +log4j.rootCategory=DEBUG, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.apache.hadoop=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala new file mode 100644 index 000000000000..b4f8049bff57 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, Matchers} + +import org.apache.spark._ +import org.apache.spark.util.Utils + +abstract class BaseYarnClusterSuite + extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { + + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. + protected val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + protected var tempDir: File = _ + private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ + private var logConfDir: File = _ + + + def yarnConfig: YarnConfiguration + + override def beforeAll() { + super.beforeAll() + + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, UTF_8) + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(yarnConfig) + yarnCluster.start() + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). + val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") + super.afterAll() + } + + protected def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val childClasspath = logConfDir.getAbsolutePath() + + File.pathSeparator + + sys.props("java.class.path") + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator) + props.setProperty("spark.driver.extraClassPath", childClasspath) + props.setProperty("spark.executor.extraClassPath", childClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig.asScala.foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + protected def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + protected def checkResult(result: File, expected: String): Unit = { + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3..804dfecde786 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val env2 = new HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 2bb3dcffd61d..e7f2501e7899 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,26 +20,36 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap => MutableHashMap} +import scala.reflect.ClassTag +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.YarnClientApplication import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.Records import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.FunSuite -import org.scalatest.Matchers - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } -import scala.reflect.ClassTag -import scala.util.Try +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + System.setProperty("SPARK_YARN_MODE", "true") + } + + override def afterAll(): Unit = { + System.clearProperty("SPARK_YARN_MODE") + } test("default Yarn application classpath") { Client.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) @@ -82,10 +92,11 @@ class ClientSuite extends FunSuite with Matchers { test("Local jar URIs") { val conf = new Configuration() val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK) + .set("spark.yarn.user.classpath.first", "true") val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - Client.populateClasspath(args, conf, sparkConf, env) + Client.populateClasspath(args, conf, sparkConf, env, true) val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => @@ -96,16 +107,16 @@ class ClientSuite extends FunSuite with Matchers { cp should not contain (uri.getPath()) } }) - if (classOf[Environment].getMethods().exists(_.getName == "$$")) { - cp should contain("{{PWD}}") - cp should contain(s"{{PWD}}${Path.SEPARATOR}*") - } else if (Utils.isWindows) { - cp should contain("%PWD%") - cp should contain(s"%PWD%${Path.SEPARATOR}*") - } else { - cp should contain(Environment.PWD.$()) - cp should contain(s"${Environment.PWD.$()}${File.separator}*") - } + val pwdVar = + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + "{{PWD}}" + } else if (Utils.isWindows) { + "%PWD%" + } else { + Environment.PWD.$() + } + cp should contain(pwdVar) + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -117,11 +128,11 @@ class ClientSuite extends FunSuite with Matchers { val client = spy(new Client(args, conf, sparkConf)) doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), - any(classOf[Path]), anyShort(), anyBoolean()) + any(classOf[Path]), anyShort()) val tempDir = Utils.createTempDir() try { - client.prepareLocalResources(tempDir.getAbsolutePath()) + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's @@ -143,55 +154,56 @@ class ClientSuite extends FunSuite with Matchers { } } - test("check access nns empty") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set()) - } - - test("check access nns unset") { + test("Cluster path translation") { + val conf = new Configuration() val sparkConf = new SparkConf() - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set()) - } + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", "/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") - test("check access nns") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"))) - } + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") - test("check access nns space") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"))) + val env = new MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") } - test("check access two nns") { + test("configuration and args propagate through createApplicationSubmissionContext") { + val conf = new Configuration() + // When parsing tags, duplicates and leading/trailing whitespace should be removed. + // Spaces between non-comma strings should be preserved as single tags. Empty strings may or + // may not be removed depending on the version of Hadoop being used. val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) - } - - test("check token renewer") { - val hadoopConf = new Configuration() - hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") - hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") - val renewer = Client.getTokenRenewer(hadoopConf) - renewer should be ("yarn/myrm:8032@SPARKTEST.COM") - } - - test("check token renewer default") { - val hadoopConf = new Configuration() - val caught = - intercept[SparkException] { - Client.getTokenRenewer(hadoopConf) - } - assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + .set(Client.CONF_SPARK_YARN_APPLICATION_TAGS, ",tag1, dup,tag2 , ,multi word , dup") + .set("spark.yarn.maxAppAttempts", "42") + val args = new ClientArguments(Array( + "--name", "foo-test-app", + "--queue", "staging-queue"), sparkConf) + + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) + val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext]) + + val client = new Client(args, conf, sparkConf) + client.createApplicationSubmissionContext( + new YarnClientApplication(getNewApplicationResponse, appContext), + containerLaunchContext) + + appContext.getApplicationName should be ("foo-test-app") + appContext.getQueue should be ("staging-queue") + appContext.getAMContainerSpec should be (containerLaunchContext) + appContext.getApplicationType should be ("SPARK") + appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => + val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] + tags should contain allOf ("tag1", "dup", "tag2", "multi word") + tags.asScala.filter(_.nonEmpty).size should be (4) + } + appContext.getMaxAppAttempts should be (42) } object Fixtures { @@ -227,19 +239,26 @@ class ClientSuite extends FunSuite with Matchers { testCode(conf) } - def newEnv = MutableHashMap[String, String]() + def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]() - def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;|") + def classpath(env: MutableHashMap[String, String]): Array[String] = + env(Environment.CLASSPATH.name).split(":|;|") - def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray + def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] = + (a ++ b).flatten.toArray - def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = - Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = { + Try(clazz.getField(field)) + .map(_.get(null).asInstanceOf[A]) + .toOption + .map(mapTo) + .getOrElse(defaults) + } def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala new file mode 100644 index 000000000000..b7fe4ccc67a3 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite + +class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + + private val yarnAllocatorSuite = new YarnAllocatorSuite + import yarnAllocatorSuite._ + + override def beforeEach() { + yarnAllocatorSuite.beforeEach() + } + + override def afterEach() { + yarnAllocatorSuite.afterEach() + } + + test("allocate locality preferred containers with enough resource and no matched existed " + + "containers") { + // 1. All the locations of current containers cannot satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. + + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array( + Array("host3", "host4", "host5"), + Array("host3", "host4", "host5"), + Array("host3", "host4"))) + } + + test("allocate locality preferred containers with enough resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === + Array(null, Array("host2", "host3"), Array("host2", "host3"))) + } + + test("allocate locality preferred containers with limited resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number cannot fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(Array("host2", "host3"))) + } + + test("allocate locality preferred containers with fully matched containers") { + // Current containers' locations can fully satisfy the new requirements + + val handler = createAllocator(5) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2"), + createContainer("host2"), + createContainer("host3") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null, null, null)) + } + + test("allocate containers with no locality preference") { + // Request new container without locality preference + + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 0, Map.empty, handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null)) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 024b25f9d336..5d05f514adde 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -25,15 +25,18 @@ import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.apache.spark.SecurityManager +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.mockito.Mockito._ + +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} - class MockResolver extends DNSToSwitchMapping { override def resolve(names: JList[String]): JList[String] = { @@ -46,7 +49,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, @@ -79,19 +82,22 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach } class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { - override def equals(other: Any) = false + override def equals(other: Any): Boolean = false } def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( + "not used", + mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), @@ -107,8 +113,8 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("single container allocated") { // request a single container and receive it - val handler = createAllocator() - handler.addResourceRequests(1) + val handler = createAllocator(1) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (1) @@ -118,13 +124,15 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0) + + val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size + size should be (0) } test("some containers allocated") { // request a few containers and receive some of them - val handler = createAllocator() - handler.addResourceRequests(4) + val handler = createAllocator(4) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) @@ -144,7 +152,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("receive more containers than requested") { val handler = createAllocator(2) - handler.addResourceRequests(2) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (2) @@ -162,6 +170,96 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.allocatedHostToContainersMap.contains("host4") should be (false) } + test("decrease total requested executors") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (3) + + val container = createContainer("host1") + handler.handleAllocatedContainers(Array(container)) + + handler.getNumExecutorsRunning should be (1) + handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") + handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (1) + } + + test("decrease total requested executors to less than currently running") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (3) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.getNumExecutorsRunning should be (2) + + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (0) + handler.getNumExecutorsRunning should be (2) + } + + test("kill executors") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) + handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (1) + } + + test("lost executor removed from backend") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + handler.getNumExecutorsFailed should be (2) + handler.getNumUnexpectedContainerRelease should be (2) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + @@ -172,5 +270,4 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) } - } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 7165918e1bfc..5a4ea2ea2f4f 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,48 +18,43 @@ package org.apache.spark.deploy.yarn import java.io.File -import java.util.concurrent.TimeUnit +import java.net.URL -import scala.collection.JavaConversions._ - -import com.google.common.base.Charsets -import com.google.common.io.Files -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import scala.collection.mutable +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.Matchers -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +/** + * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN + * applications, and require the Spark assembly to be built before they can be successfully + * run. + */ +class YarnClusterSuite extends BaseYarnClusterSuite { - // log4j configuration for the Yarn containers, so that their output is collected - // by Yarn instead of trying to overwrite unit-tests.log. - private val LOG4J_CONF = """ - |log4j.rootCategory=DEBUG, console - |log4j.appender.console=org.apache.log4j.ConsoleAppender - |log4j.appender.console.target=System.err - |log4j.appender.console.layout=org.apache.log4j.PatternLayout - |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin + override def yarnConfig: YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ + |import mod1, mod2 |import sys |from operator import add | |from pyspark import SparkConf , SparkContext |if __name__ == "__main__": - | if len(sys.argv) != 3: - | print >> sys.stderr, "Usage: test.py [master] [result file]" + | if len(sys.argv) != 2: + | print >> sys.stderr, "Usage: test.py [result file]" | exit(-1) - | conf = SparkConf() - | conf.setMaster(sys.argv[1]).setAppName("python test in yarn cluster mode") - | sc = SparkContext(conf=conf) - | status = open(sys.argv[2],'w') + | sc = SparkContext(conf=SparkConf()) + | status = open(sys.argv[1],'w') | result = "failure" - | rdd = sc.parallelize(range(10)) + | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func()) | cnt = rdd.count() | if cnt == 10: | result = "success" @@ -68,165 +63,220 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit | sc.stop() """.stripMargin - private var yarnCluster: MiniYARNCluster = _ - private var tempDir: File = _ - private var fakeSparkJar: File = _ - private var oldConf: Map[String, String] = _ - - override def beforeAll() { - tempDir = Utils.createTempDir() - - val logConfDir = new File(tempDir, "log4j") - logConfDir.mkdir() - - val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8) - - val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator + - sys.props("java.class.path") - - oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap - - yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(new YarnConfiguration()) - yarnCluster.start() - - // There's a race in MiniYARNCluster in which start() may return before the RM has updated - // its address in the configuration. You can see this in the logs by noticing that when - // MiniYARNCluster prints the address, it still has port "0" assigned, although later the - // test works sometimes: - // - // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 - // - // That log message prints the contents of the RM_ADDRESS config variable. If you check it - // later on, it looks something like this: - // - // INFO YarnClusterSuite: RM address in configuration is blah:42631 - // - // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't - // done so in a timely manner (defined to be 10 seconds). - val config = yarnCluster.getConfig() - val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) - while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { - if (System.currentTimeMillis() > deadline) { - throw new IllegalStateException("Timed out waiting for RM to come up.") - } - logDebug("RM address still not set in configuration, waiting...") - TimeUnit.MILLISECONDS.sleep(100) - } + private val TEST_PYMODULE = """ + |def func(): + | return 42 + """.stripMargin - logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - config.foreach { e => - sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) - } + test("run Spark in yarn-client mode") { + testBasicYarnApp(true) + } - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - sys.props += ("spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - sys.props += ("spark.executorEnv.SPARK_HOME" -> sparkHome) - sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath())) - sys.props += ("spark.executor.instances" -> "1") - sys.props += ("spark.driver.extraClassPath" -> childClasspath) - sys.props += ("spark.executor.extraClassPath" -> childClasspath) + test("run Spark in yarn-cluster mode") { + testBasicYarnApp(false) + } - super.beforeAll() + test("run Spark in yarn-cluster mode unsuccessfully") { + // Don't provide arguments so the driver will fail. + val exception = intercept[SparkException] { + runSpark(false, mainClassName(YarnClusterDriver.getClass)) + fail("Spark application should have failed.") + } } - override def afterAll() { - yarnCluster.stop() - sys.props.retain { case (k, v) => !k.startsWith("spark.") } - sys.props ++= oldConf - super.afterAll() + test("run Python application in yarn-client mode") { + testPySpark(true) } - test("run Spark in yarn-client mode") { - var result = File.createTempFile("result", null, tempDir) - YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) - checkResult(result) + test("run Python application in yarn-cluster mode") { + testPySpark(false) } - test("run Spark in yarn-cluster mode") { - val main = YarnClusterDriver.getClass.getName().stripSuffix("$") - var result = File.createTempFile("result", null, tempDir) - - val args = Array("--class", main, - "--jar", "file:" + fakeSparkJar.getAbsolutePath(), - "--arg", "yarn-cluster", - "--arg", result.getAbsolutePath(), - "--num-executors", "1") - Client.main(args) - checkResult(result) + test("user class path first in client mode") { + testUseClassPathFirst(true) } - test("run Spark in yarn-cluster mode unsuccessfully") { - val main = YarnClusterDriver.getClass.getName().stripSuffix("$") + test("user class path first in cluster mode") { + testUseClassPathFirst(false) + } - // Use only one argument so the driver will fail - val args = Array("--class", main, - "--jar", "file:" + fakeSparkJar.getAbsolutePath(), - "--arg", "yarn-cluster", - "--num-executors", "1") - val exception = intercept[SparkException] { - Client.main(args) - } - assert(Utils.exceptionString(exception).contains("Application finished with failed status")) + private def testBasicYarnApp(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) } - test("run Python application in yarn-cluster mode") { + private def testPySpark(clientMode: Boolean): Unit = { val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, Charsets.UTF_8) - val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, Charsets.UTF_8) - var result = File.createTempFile("result", null, tempDir) - - val args = Array("--class", "org.apache.spark.deploy.PythonRunner", - "--primary-py-file", primaryPyFile.getAbsolutePath(), - "--py-files", pyFile.getAbsolutePath(), - "--arg", "yarn-cluster", - "--arg", result.getAbsolutePath(), - "--name", "python test in yarn-cluster mode", - "--num-executors", "1") - Client.main(args) + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFiles), + appArgs = Seq(result.getAbsolutePath())) checkResult(result) } - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - private def checkResult(result: File) = { - var resultString = Files.toString(result, Charsets.UTF_8) - resultString should be ("success") + private def testUseClassPathFirst(clientMode: Boolean): Unit = { + // Create a jar file that contains a different version of "test.resource". + val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) + val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) + val driverResult = File.createTempFile("driver", null, tempDir) + val executorResult = File.createTempFile("executor", null, tempDir) + runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), + extraClassPath = Seq(originalJar.getPath()), + extraJars = Seq("local:" + userJar.getPath()), + extraConf = Map( + "spark.driver.userClassPathFirst" -> "true", + "spark.executor.userClassPathFirst" -> "true")) + checkResult(driverResult, "OVERRIDDEN") + checkResult(executorResult, "OVERRIDDEN") + } + +} + +private[spark] class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + var driverLogs: Option[collection.Map[String, String]] = None + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo } + override def onApplicationStart(appStart: SparkListenerApplicationStart): Unit = { + driverLogs = appStart.driverLogs + } } private object YarnClusterDriver extends Logging with Matchers { - def main(args: Array[String]) = { - if (args.length != 2) { + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 1) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriver [master] [result file] + |Usage: YarnClusterDriver [result file] """.stripMargin) + // scalastyle:on println System.exit(1) } - val sc = new SparkContext(new SparkConf().setMaster(args(0)) + val sc = new SparkContext(new SparkConf() + .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) - val status = new File(args(1)) + val conf = sc.getConf + val status = new File(args(0)) var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) data should be (Set(1, 2, 3, 4)) result = "success" } finally { sc.stop() - Files.write(result, status, Charsets.UTF_8) + Files.write(result, status, UTF_8) + } + + // verify log urls are present + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) + val executorInfos = listener.addedExecutorInfos.values + assert(executorInfos.nonEmpty) + executorInfos.foreach { info => + assert(info.logUrlMap.nonEmpty) + } + + // If we are running in yarn-cluster mode, verify that driver logs links and present and are + // in the expected format. + if (conf.get("spark.master") == "yarn-cluster") { + assert(listener.driverLogs.nonEmpty) + val driverLogs = listener.driverLogs.get + assert(driverLogs.size === 2) + assert(driverLogs.contains("stderr")) + assert(driverLogs.contains("stdout")) + val urlStr = driverLogs("stderr") + // Ensure that this is a valid URL, else this will throw an exception + new URL(urlStr) + val containerId = YarnSparkHadoopUtil.get.getContainerId + val user = Utils.getCurrentUserName() + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) + } + } + +} + +private object YarnClasspathTest extends Logging { + + var exitCode = 0 + + def error(m: String, ex: Throwable = null): Unit = { + logError(m, ex) + // scalastyle:off println + System.out.println(m) + if (ex != null) { + ex.printStackTrace(System.out) + } + // scalastyle:on println + } + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + error( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: YarnClasspathTest [driver result file] [executor result file] + """.stripMargin) + // scalastyle:on println + } + + readResource(args(0)) + val sc = new SparkContext(new SparkConf()) + try { + sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) } + } finally { + sc.stop() + } + System.exit(exitCode) + } + + private def readResource(resultPath: String): Unit = { + var result = "failure" + try { + val ccl = Thread.currentThread().getContextClassLoader() + val resource = ccl.getResourceAsStream("test.resource") + val bytes = ByteStreams.toByteArray(resource) + result = new String(bytes, 0, bytes.length, UTF_8) + } catch { + case t: Throwable => + error(s"loading test.resource to $resultPath", t) + // set the exit code if not yet set + exitCode = 2 + } finally { + Files.write(result, new File(resultPath), UTF_8) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala new file mode 100644 index 000000000000..5e8238822b90 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -0,0 +1,109 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.scalatest.Matchers + +import org.apache.spark._ +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} + +/** + * Integration test for the external shuffle service with a yarn mini-cluster + */ +class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = { + val yarnConfig = new YarnConfiguration() + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig + } + + test("external shuffle service") { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + + val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) + + logInfo("Shuffle service port = " + shuffleServicePort) + val result = File.createTempFile("result", null, tempDir) + runSpark( + false, + mainClassName(YarnExternalShuffleDriver.getClass), + appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), + extraConf = Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString + ) + ) + checkResult(result) + assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) + } +} + +private object YarnExternalShuffleDriver extends Logging with Matchers { + + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: ExternalShuffleDriver [result file] [registed exec file] + """.stripMargin) + // scalastyle:on println + System.exit(1) + } + + val sc = new SparkContext(new SparkConf() + .setAppName("External Shuffle Test")) + val conf = sc.getConf + val status = new File(args(0)) + val registeredExecFile = new File(args(1)) + logInfo("shuffle service executor file = " + registeredExecFile) + var result = "failure" + val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup") + try { + val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }. + collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet) + result = "success" + // only one process can open a leveldb file at a time, so we copy the files + FileUtils.copyDirectory(registeredExecFile, execStateCopy) + assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty) + } finally { + sc.stop() + FileUtils.deleteDirectory(execStateCopy) + Files.write(result, status, UTF_8) + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index b5a2db8f6225..49bee0866dd4 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -20,18 +20,20 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try { @@ -46,11 +48,11 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { logWarning("Cannot execute bash, skipping bash tests.") } - def bashTest(name: String)(fn: => Unit) = + def bashTest(name: String)(fn: => Unit): Unit = if (hasBash) test(name)(fn) else ignore(name)(fn) bashTest("shell script escaping") { - val scriptFile = File.createTempFile("script.", ".sh") + val scriptFile = File.createTempFile("script.", ".sh", Utils.createTempDir()) val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") try { val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") @@ -173,4 +175,62 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { YarnSparkHadoopUtil.getClassPathSeparator() should be (":") } } + + test("check access nns empty") { + val sparkConf = new SparkConf() + val util = new YarnSparkHadoopUtil + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val util = new YarnSparkHadoopUtil + val renewer = util.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val util = new YarnSparkHadoopUtil + val caught = + intercept[SparkException] { + util.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } } diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala new file mode 100644 index 000000000000..aa46ec5100f0 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle + +import java.io.{IOException, File} +import java.util.concurrent.ConcurrentMap + +import com.google.common.annotations.VisibleForTesting +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.fusesource.leveldbjni.JniDBFactory +import org.iq80.leveldb.{DB, Options} + +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +/** + * just a cheat to get package-visible members in tests + */ +object ShuffleTestAccessor { + + def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = { + handler.blockManager + } + + def getExecutorInfo( + appId: ApplicationId, + execId: String, + resolver: ExternalShuffleBlockResolver + ): Option[ExecutorShuffleInfo] = { + val id = new AppExecId(appId.toString, execId) + Option(resolver.executors.get(id)) + } + + def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = { + resolver.registeredExecutorFile + } + + def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = { + resolver.db + } + + def reloadRegisteredExecutors( + file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + val options: Options = new Options + options.createIfMissing(true) + val factory = new JniDBFactory + val db = factory.open(file, options) + val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + db.close() + result + } + + def reloadRegisteredExecutors( + db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala new file mode 100644 index 000000000000..2f22cbdbeac3 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.annotation.tailrec + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration + + override def beforeEach(): Unit = { + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + + yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => + val d = new File(dir) + if (d.exists()) { + FileUtils.deleteDirectory(d) + } + FileUtils.forceMkdir(d) + logInfo(s"creating yarn.nodemanager.local-dirs: $d") + } + } + + var s1: YarnShuffleService = null + var s2: YarnShuffleService = null + var s3: YarnShuffleService = null + + override def afterEach(): Unit = { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } + + test("executor state kept across NM restart") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should + be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should + be (Some(shuffleInfo2)) + + if (!execStateFile.exists()) { + @tailrec def findExistingParent(file: File): File = { + if (file == null) file + else if (file.exists()) file + else findExistingParent(file.getParentFile()) + } + val existingParent = findExistingParent(execStateFile) + assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent") + } + assert(execStateFile.exists(), s"$execStateFile did not exist") + + // now we pretend the shuffle service goes down, and comes back up + s1.stop() + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped + // during the restart + s2.initializeApplication(app1Data) + s2.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None) + + // Act like the NM restarts one more time + s2.stop() + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + // app1 is still running + s3.initializeApplication(app1Data) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None) + s3.stop() + } + + test("removed applications should not be in registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + + val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + + s1.stopApplication(new ApplicationTerminationContext(app1Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + s1.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty + } + + test("shuffle service should be robust to corrupt registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + + val execStateFile = s1.registeredExecutorFile + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + + // now we pretend the shuffle service goes down, and comes back up. But we'll also + // make a corrupt registeredExecutor File + s1.stop() + + execStateFile.listFiles().foreach{_.delete()} + + val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT")) + out.writeInt(42) + out.close() + + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ... + s2.initializeApplication(app1Data) + // however, when we initialize a totally new app2, everything is still happy + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s2.initializeApplication(app2Data) + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2)) + s2.stop() + + // another stop & restart should be fine though (eg., we recover from previous corruption) + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + s3.initializeApplication(app2Data) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2)) + s3.stop() + + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala new file mode 100644 index 000000000000..db322cd18e15 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.File + +/** + * just a cheat to get package-visible members in tests + */ +object YarnTestAccessor { + def getShuffleServicePort: Int = { + YarnShuffleService.boundPort + } + + def getShuffleServiceInstance: YarnShuffleService = { + YarnShuffleService.instance + } + + def getRegisteredExecutorFile(service: YarnShuffleService): File = { + service.registeredExecutorFile + } + +}